Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
747769b
feat(fsdp2): add _broadcast_sharded_state_dict, _get_non_persistent_b…
kevssim Mar 18, 2026
a69fb6c
feat(fsdp2): enable cpu_ram_efficient_loading for both strategies; pa…
kevssim Mar 18, 2026
c015e13
refactor(fsdp2): use _non_persistent_buffers_set for precise non-pers…
kevssim Mar 18, 2026
5fa4c4b
Merge branch 'optimize_fsdp_init' of https://github.com/kevssim/twink…
kevssim Mar 18, 2026
587f001
wip
kevssim Mar 18, 2026
983cdbc
test(fsdp2): make tests platform-agnostic (cuda/npu) via Platform API
kevssim Mar 18, 2026
d5d2832
fix(test): pass inputs as List[InputFeature] to forward_backward
kevssim Mar 18, 2026
5973173
fix(test): add position_ids to e2e test batch
kevssim Mar 18, 2026
accd03b
fix(test): simplify e2e test to only verify wrap_model, avoid process…
kevssim Mar 18, 2026
2c72aa4
fix(fsdp2): handle non-DTensor params (e.g. tied weights) in _broadca…
kevssim Mar 19, 2026
bf0e155
fix(fsdp2): move remaining CPU/meta params to device after tie_weights
kevssim Mar 19, 2026
e35a3d9
debug: add verbose logging to e2e test to diagnose CPU param issue
kevssim Mar 19, 2026
60c4a3a
debug: add verbose logging to wrap_model to trace execution path
kevssim Mar 19, 2026
93ef7e4
debug: add verbose logging to _lazy_wrap_model
kevssim Mar 19, 2026
638a996
debug: add device_mesh check to e2e test
kevssim Mar 19, 2026
a61db10
debug: print mesh before TransformersModel init
kevssim Mar 19, 2026
a06e894
fix(test): call twinkle.initialize() before TransformersModel to pres…
kevssim Mar 19, 2026
f8def97
cleanup: remove all debug print statements from native_fsdp.py and tr…
kevssim Mar 19, 2026
13c1d5f
wip
kevssim Mar 19, 2026
0438d9e
lint
kevssim Mar 19, 2026
44bf3d4
fix
kevssim Mar 19, 2026
3b82d1c
wip
kevssim Mar 19, 2026
beaa4fd
wip
kevssim Mar 19, 2026
560eb23
wip
kevssim Mar 19, 2026
d8f39b1
wip
kevssim Mar 19, 2026
cbb6191
lint
kevssim Mar 19, 2026
38e75cd
lint
kevssim Mar 19, 2026
e482625
wip
kevssim Mar 19, 2026
00fd199
clean
kevssim Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/twinkle/model/transformers/strategy/accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ def __init__(
mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
ddp_config: Dict[str, Any] = None,
fsdp_config: Dict[str, Any] = None,
memory_efficient: bool = True,
):
from accelerate import Accelerator

self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
parallelism_config = self._parallelism_config_from_device_mesh(device_mesh)
fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config)
fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient)

kwargs_handlers = []
if ddp_config is not None:
Expand Down Expand Up @@ -69,7 +70,8 @@ def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh):

return parallelism_config

def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any]):
def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any],
memory_efficient: bool):
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import ShardingStrategy as FSDPShardingStrategy
Expand Down Expand Up @@ -107,11 +109,9 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di
activation_checkpointing=fsdp_config.pop('activation_checkpointing', False),
auto_wrap_policy=fsdp_config.pop('auto_wrap_policy', 'transformer_based_wrap'), # noqa
reshard_after_forward=fsdp_config.pop('reshard_after_forward', True),
cpu_ram_efficient_loading=fsdp_config.pop('cpu_ram_efficient_loading', memory_efficient),
**fsdp_config,
)
# Enable memory efficient model loading in transformers(see `is_fsdp_enabled` in transformers)
# os.environ['ACCELERATE_USE_FSDP'] = '1'
# os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = '1'
return fsdp_plugin

def wrap_model(self, model, *args):
Expand Down
111 changes: 100 additions & 11 deletions src/twinkle/model/transformers/strategy/native_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,27 @@ def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[
)
return ep_mesh.to_torch_device_mesh()

def wrap_model(self, model, optimizer=None):
def wrap_model(self, model, optimizer=None, memory_efficient=True):
if self.device_mesh is None:
return model, optimizer
fsdp_mesh = _build_fsdp_mesh(self.device_mesh)
if fsdp_mesh is not None:
ep_enabled = (self.enable_ep and self.ep_fsdp_device_mesh is not None)

# EP path requires experts on a real device, incompatible with meta-device flow.
use_meta = memory_efficient and not ep_enabled

original_sd = None
saved_buffers = None
if use_meta:
original_sd = model.state_dict()
saved_buffers = _get_non_persistent_buffers(model)
if optimizer is not None:
_unbind_optimizer_params(optimizer)
model = model.to(torch.device('meta'))
if hasattr(model, 'tie_weights'):
model.tie_weights()

if ep_enabled:
_ensure_moe_patched_if_needed(model, self.ep_fsdp_device_mesh)
_place_ep_experts_on_local_device(model, self.ep_fsdp_device_mesh)
Expand All @@ -57,19 +72,16 @@ def wrap_model(self, model, optimizer=None):
if ep_enabled:
_ensure_ep_fsdp_supported(model)

# Collect experts map and expert params
experts_map = _collect_ep_experts_map(model) if ep_enabled else {}
expert_params = _collect_expert_params(model) if self.enable_ep else None

# Build layer_pairs: [(layer_mod, experts_mod_or_None)]
layers = _get_decoder_layers(model)
layer_pairs = []
if layers is not None:
for layer_mod in layers:
experts_mod = _find_experts_in_layer(layer_mod, experts_map)
layer_pairs.append((layer_mod, experts_mod))

# FSDP2 wrapping per layer
world_size = self.device_mesh.world_size
ep_fsdp_mesh_1d = self.ep_fsdp_device_mesh['ep_fsdp'] if ep_enabled else None

Expand All @@ -79,9 +91,6 @@ def wrap_model(self, model, optimizer=None):
if experts_mod is not None and ep_fsdp_mesh_1d is not None:
from torch.distributed.tensor import Shard

# PreMulSum (used by set_gradient_divide_factor) only supports
# float16/float32/float64; override reduce_dtype to float32
# when the base policy uses bfloat16.
ep_mp_policy = _build_ep_mp_policy(mp_policy)
fully_shard(
experts_mod,
Expand All @@ -90,7 +99,6 @@ def wrap_model(self, model, optimizer=None):
mp_policy=ep_mp_policy,
shard_placement_fn=lambda param: Shard(1),
)
# gradient_divide_factor = world_size
experts_mod.set_gradient_divide_factor(world_size)
layer_mod._fsdp_modules.append(experts_mod)

Expand All @@ -103,7 +111,6 @@ def wrap_model(self, model, optimizer=None):
)
layer_mod._fsdp_modules.append(layer_mod)

# Root model
fully_shard(
model,
mesh=fsdp_mesh,
Expand All @@ -112,11 +119,22 @@ def wrap_model(self, model, optimizer=None):
ignored_params=expert_params,
)

# Manual prefetch
if use_meta:
device_type = self.device_mesh.device_type or 'cuda'
is_rank0 = (dist.get_rank() == 0)
_broadcast_sharded_state_dict(
model,
original_sd if is_rank0 else {},
device_type=device_type,
)
target_device = torch.device(device_type)
_restore_non_persistent_buffers(model, saved_buffers, device=target_device)
if hasattr(model, 'tie_weights'):
model.tie_weights()

if ep_enabled and layer_pairs:
_setup_manual_prefetch([lp[0] for lp in layer_pairs])

# Rebuild groups after wrapping so grad clip sees the live Parameter objects.
if ep_enabled:
_rebuild_ep_param_groups(model)

Expand Down Expand Up @@ -398,3 +416,74 @@ def _rebind_optimizer(optimizer: torch.optim.Optimizer, model: nn.Module) -> tor
return optimizer
optimizer.param_groups[0]['params'] = list(model.parameters())
return optimizer


def _broadcast_sharded_state_dict(
    model: nn.Module,
    full_sd: dict,
    device_type: str = 'cuda',
) -> None:
    """Broadcast full state dict from rank 0 and materialise local shards via distribute_tensor.

    Args:
        model: FSDP2-wrapped model whose ``state_dict`` currently holds
            meta-device tensors (typically DTensors) describing the target
            shapes/dtypes/placements.
        full_sd: The full (unsharded) state dict. Only read on rank 0;
            other ranks may pass an empty dict.
        device_type: Device type the shards are materialised on.
    """
    from torch.distributed.tensor import DTensor, distribute_tensor

    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    is_rank0 = (dist.get_rank() == 0)

    for param_name, sharded_param in meta_sharded_sd.items():
        shape = sharded_param.size()
        dtype = sharded_param.dtype

        if is_rank0:
            full_param = full_sd[param_name]
            # Cast to the meta state dict's dtype as well as the device:
            # dist.broadcast requires identical dtypes on every rank, and the
            # non-rank0 receive buffers below are allocated with `dtype`.
            full_tensor = full_param.detach().to(device=device_type, dtype=dtype)
            if isinstance(full_tensor, DTensor):
                full_tensor = full_tensor.to_local()
        else:
            full_tensor = torch.empty(shape, device=device_type, dtype=dtype)

        dist.broadcast(full_tensor, src=0)
        torch_util.synchronize()

        if isinstance(sharded_param, DTensor):
            # Re-shard the broadcast full tensor according to the meta
            # DTensor's mesh/placements so load_state_dict(assign=True)
            # installs proper DTensors.
            device_mesh = sharded_param.device_mesh
            placements = sharded_param.placements
            sharded_sd[param_name] = distribute_tensor(full_tensor, device_mesh, placements)
        else:
            # Non-DTensor entries (e.g. tied weights or params ignored by
            # fully_shard) carry no mesh/placements; every rank keeps the
            # broadcast full tensor as-is.
            sharded_sd[param_name] = full_tensor

    model.load_state_dict(sharded_sd, assign=True)


def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]:
"""Return {fqn: tensor} for non-persistent buffers (lost on to('meta'))."""
non_persistent_fqns: Set[str] = set()
for fqn, module in model.named_modules():
for buf_name in getattr(module, '_non_persistent_buffers_set', set()):
full_fqn = f'{fqn}.{buf_name}' if fqn else buf_name
non_persistent_fqns.add(full_fqn)

return {k: v.clone() for k, v in model.named_buffers() if k in non_persistent_fqns}


def _unbind_optimizer_params(optimizer: torch.optim.Optimizer) -> None:
"""Drop optimizer param refs so model.to('meta') can free memory."""
for group in optimizer.param_groups:
for i in range(len(group['params'])):
group['params'][i] = torch.empty(1)


def _restore_non_persistent_buffers(
model: nn.Module,
saved_buffers: Dict[str, torch.Tensor],
device: torch.device,
) -> None:
"""Re-register non-persistent buffers saved before to('meta')."""
for fqn, buf_tensor in saved_buffers.items():
buf_tensor = buf_tensor.to(device)
if '.' in fqn:
parent_fqn, local_name = fqn.rsplit('.', 1)
parent = model.get_submodule(parent_fqn)
else:
local_name = fqn
parent = model
parent.register_buffer(local_name, buf_tensor, persistent=False)
37 changes: 32 additions & 5 deletions src/twinkle/model/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def __init__(
ddp_config: Dict[str, Any] = None,
fsdp_config: Dict[str, Any] = None,
grad_scaler_config: Dict[str, Any] = None,
memory_efficient_init: bool = True,
**kwargs):
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
self._try_init_process_group()
Expand All @@ -195,6 +196,7 @@ def __init__(
self.mixed_precision = mixed_precision
self._fsdp_config = dict(fsdp_config or {})
self._ddp_config = ddp_config or {}
self._memory_efficient_init = memory_efficient_init
self._decide_strategy(strategy)
self.grad_scaler_config = grad_scaler_config
if isinstance(model_cls, str):
Expand All @@ -203,8 +205,23 @@ def __init__(
self.model = model_cls.from_config(config, **kwargs)
else:
model_id = HubOperation.download_model(model_id)
self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
# Construct sequence-parallel strategy lazily during wrapping to reduce init-time side effects.
# Trigger transformers' FSDP-aware loading: meta-device init + rank-0-only weight load.
use_efficient_loading = (memory_efficient_init and self.device_mesh is not None)
_saved_env = {}
if use_efficient_loading:
_saved_env['ACCELERATE_USE_FSDP'] = os.environ.get('ACCELERATE_USE_FSDP')
_saved_env['FSDP_CPU_RAM_EFFICIENT_LOADING'] = os.environ.get('FSDP_CPU_RAM_EFFICIENT_LOADING')
os.environ['ACCELERATE_USE_FSDP'] = 'true'
os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = 'true'
try:
self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
finally:
if use_efficient_loading:
for key, old_val in _saved_env.items():
if old_val is None:
os.environ.pop(key, None)
else:
os.environ[key] = old_val
self.model.gradient_checkpointing_enable()
self.sp_strategy = None
self._model_wrapped = False
Expand Down Expand Up @@ -237,7 +254,8 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']):
mixed_precision=self.mixed_precision,
ddp_config=self._ddp_config,
fsdp_config=self._fsdp_config,
device_mesh=self.device_mesh)
device_mesh=self.device_mesh,
memory_efficient=self._memory_efficient_init)

# Sequence parallel ("ulysses") is derived from dp/fsdp ranks; it does not change world size.
# We construct `sp_strategy` after the underlying HF model is initialized (see __init__).
Expand Down Expand Up @@ -284,16 +302,25 @@ def _lazy_wrap_model(self):
self._ensure_sp_strategy()
if self.sp_strategy is not None:
self.sp_strategy.initialize()

extra_kwargs = {}
if isinstance(self.strategy, NativeFSDPStrategy):
extra_kwargs['memory_efficient'] = getattr(self, '_memory_efficient_init', True)

if len(optimizer_groups) == 1:
optimizer_group = optimizer_groups[0]
optimizer = optimizer_group.optimizer
assert optimizer is not None
self.model, optimizer = self.strategy.wrap_model(self.model, optimizer)
self.model, optimizer = self.strategy.wrap_model(self.model, optimizer, **extra_kwargs)
optimizer_group.optimizer = optimizer
self.register_mm_forward_hook(optimizer_group)
else:
# maybe forward_only, no optimizer_group available
self.model = self.strategy.wrap_model(self.model)
result = self.strategy.wrap_model(self.model, **extra_kwargs)
if isinstance(result, tuple):
self.model = result[0]
else:
self.model = result
self._model_wrapped = True

def register_mm_forward_hook(self, optimizer_group: OptimizerGroup):
Expand Down
Empty file added tests/strategy/__init__.py
Empty file.
Loading
Loading