Multi-LoRA SFT support FSDP2 #155
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements FSDP2 support for MultiLoraTransformersModel by integrating it into the shared strategy and lazy-wrap lifecycle and introducing sharding-aware parameter access helpers. Review feedback identifies critical bugs in the distributed tensor handling: _write_param_tensor may incorrectly double-shard local data, set_state_dict risks shape mismatches when applying global state to local shards, and get_state_dict returns sharded tensors that could lead to corrupt checkpoints. Furthermore, the model's initialization should be refactored to properly use the parent class, and internal imports should be moved to the module level.
|
I'd love to have this feature! Just curious — why was this PR changed to draft? Any other plans in the works? |
cause working in progress, when finished, will merge into branch main |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request implements distributed training support for MultiLora, introducing helper methods for tensor sharding and updating model lifecycle methods to handle distributed contexts. The review identified critical issues where local shards are incorrectly processed as global tensors, potentially leading to corrupted weights during sharding and incomplete state dicts. Feedback emphasizes the need to gather tensors before saving or returning them to ensure compatibility with standard loaders and correct distributed behavior.
| return | ||
| value = value.detach().to(dtype=parameter.dtype) | ||
| if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'): | ||
| value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements) |
There was a problem hiding this comment.
The use of distribute_tensor here assumes that value is a global tensor that needs to be sharded according to the device_mesh and placements. However, in several call sites (like set_state_dict and _load_initial_weights), the value passed to _write_param_tensor is derived from _read_param_tensor, which returns a local shard. Calling distribute_tensor on a local shard will incorrectly attempt to shard the shard again, leading to incorrect parameter values in distributed training.
| target_tensor = self._read_param_tensor(parameter) | ||
| if target_tensor is None: | ||
| continue | ||
| target_tensor = target_tensor.clone() | ||
| src_tensor = state_dict[state_key].to(dtype=target_tensor.dtype, device=target_tensor.device) | ||
| self._copy_rank_tensor(name, target_tensor, src_tensor) | ||
| self._write_param_tensor(parameter, target_tensor) |
There was a problem hiding this comment.
There is a mismatch between local and global tensors here. target_tensor is a local shard (from _read_param_tensor), while src_tensor is a global tensor (from state_dict).
_copy_rank_tensor(line 640) will fail with a shape mismatch if the dimension being copied is sharded by FSDP (e.g.,num_embeddingsfor embeddings orout_featuresfor linear layers)._write_param_tensor(line 641) will then attempt to shard this local shard again as discussed in the previous comment.
To fix this, you should either shard src_tensor to match the local shard's placements before copying, or perform the copy on global tensors (on CPU) and then use _write_param_tensor to shard the result.
| _param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r) | ||
| if _param is None: | ||
| continue | ||
| name = name.replace(f'.{_lora.adapter_name}.', '.') | ||
| state_dict[name] = _param |
There was a problem hiding this comment.
In a distributed setting with FSDP2, _read_param_tensor returns a local shard. Consequently, get_state_dict returns a dictionary of sharded tensors. Since this method is decorated with @remote_function(collect='first') in the model class, only rank 0's local shards will be returned to the caller. This results in an incomplete and unusable state dict for the LoRA adapter. You should gather the shards into a global tensor before slicing and returning them.
| _param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r) | ||
| if _param is not None: | ||
| _param = _param.clone() | ||
| name = name.replace(f'.{_lora.adapter_name}.', '.') | ||
| return name, _param |
There was a problem hiding this comment.
Similar to get_state_dict, save_lora_converter is returning a local shard of the parameter. When saving the model, this will result in a checkpoint containing sharded LoRA weights, which is incompatible with standard PEFT loaders. LoRA weights are typically small enough to be gathered and saved as full tensors even in FSDP environments.
PR type
PR information
Multi-LoRA support FSDP2