Multi-LoRA SFT support FSDP2 #155
base: main
Changes from all commits
@@ -5,6 +5,7 @@
from dataclasses import dataclass, field
from peft import LoraConfig, PeftModel, get_peft_model
from peft.tuners.lora import Embedding, Linear, LoraLayer
from torch.distributed.tensor import distribute_tensor
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Union
@@ -42,6 +43,49 @@ def _get_available_lora(self) -> Optional[LoraTenant]:
                return _lora
        return None

    def _read_param_tensor(self, parameter):
        return torch_util.to_local_tensor(parameter)

    def _write_param_tensor(self, parameter, value):
        if value is None:
            return
        value = value.detach().to(dtype=parameter.dtype)
        if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'):
            value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements)
        else:
            value = value.to(parameter.device)
        parameter.data.copy_(value)

    @staticmethod
    def _slice_rank_tensor(name: str, tensor, rank: int):
        if tensor is None:
            return None
        if 'embedding_A' in name:
            return tensor[:, :rank]
        if 'embedding_B' in name:
            return tensor[:rank, :]
        if '_A' in name:
            return tensor[:rank, :]
        if '_B' in name:
            return tensor[:, :rank]
        return tensor

    @staticmethod
    def _copy_rank_tensor(name: str, target, value):
        if target is None or value is None:
            return None
        if 'embedding_A' in name:
            target[:, :value.shape[1]].copy_(value)
        elif 'embedding_B' in name:
            target[:value.shape[0], :].copy_(value)
        elif '_A' in name:
            target[:value.shape[0], :].copy_(value)
        elif '_B' in name:
            target[:, :value.shape[1]].copy_(value)
        else:
            target.copy_(value)
        return target

    def _count_available_loras(self):
        return len([_lora for _lora in self.loras if _lora.tenant_adapter_name is None])
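The two static helpers above encode the rank-slicing convention used throughout this PR: adapters are allocated at some maximum rank, and a tenant configured with a smaller rank `r` only ever reads or writes the leading `r` rank rows/columns. A minimal sketch of that round trip on plain tensors (the shapes and `max_r` below are illustrative assumptions, not values taken from the PR):

```python
import torch

max_r, r, in_features, out_features = 16, 8, 32, 64

# A LoRA "A" matrix allocated at max_r: the rank dimension is dim 0.
lora_A_full = torch.randn(max_r, in_features)
# A LoRA "B" matrix allocated at max_r: the rank dimension is dim 1.
lora_B_full = torch.zeros(out_features, max_r)

# Saving a tenant at rank r keeps only the leading r rank rows/columns
# (what _slice_rank_tensor does for '*_A' and '*_B' names).
saved_A = lora_A_full[:r, :].clone()
saved_B = lora_B_full[:, :r].clone()

# Loading writes the saved slice back into the max-rank buffer and leaves the
# remaining rows/columns untouched (what _copy_rank_tensor does).
lora_A_full[:saved_A.shape[0], :].copy_(saved_A)
lora_B_full[:, :saved_B.shape[1]].copy_(saved_B)
```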
|
|
|
|
@@ -472,7 +516,7 @@ def save_initial_weights(self):
        def _store_weights(_module):
            for name, parameter in _module.named_parameters():
                if pattern.search(name):
                    lora_tenant.lora_A_weights[name] = parameter.data.clone().to('cpu')
                    lora_tenant.lora_A_weights[name] = self._read_param_tensor(parameter).clone().to('cpu')

        if isinstance(self.module, list):
            for _module in self.module:
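`torch_util.to_local_tensor`, which backs `_read_param_tensor`, is project-internal and not shown in this diff. Based on how the surrounding code and the review comments treat its result, it presumably behaves roughly like the following sketch (a hypothetical stand-in, not the actual implementation):

```python
import torch
from torch.distributed.tensor import DTensor

def to_local_tensor_sketch(parameter):
    # Hypothetical stand-in for torch_util.to_local_tensor: under FSDP2 the
    # parameter data is a DTensor, so return this rank's local shard; for a
    # plain tensor, return it unchanged.
    tensor = parameter.data if isinstance(parameter, torch.nn.Parameter) else parameter
    if isinstance(tensor, DTensor):
        return tensor.to_local()
    return tensor
```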
|
|
@@ -572,17 +616,9 @@ def save_lora_converter(self, name, parameter, adapter_name):
        # patching makes the bridge skip non-target modules entirely), so we
        # only check the adapter-name / weight pattern here.
        if re.search(rf'\.lora_\w+\.({adapter_name}|weight)', name):
            _param = torch_util.to_local_tensor(parameter)
            if _param is None:
                pass
            elif 'embedding_A' in name:
                _param = _param[:, :_lora.tenant_config.r].clone()
            elif 'embedding_B' in name:
                _param = _param[:_lora.tenant_config.r, :].clone()
            elif '_A' in name:
                _param = _param[:_lora.tenant_config.r, :].clone()
            elif '_B' in name:
                _param = _param[:, :_lora.tenant_config.r].clone()
            _param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r)
            if _param is not None:
                _param = _param.clone()
            name = name.replace(f'.{_lora.adapter_name}.', '.')
            return name, _param
|
Contributor comment on lines +619 to 623: Similar to …

        else:
|
|
@@ -595,20 +631,14 @@ def set_state_dict(self, tenant_adapter_name, state_dict):
        def _load_weights(_module):
            for name, parameter in _module.named_parameters():
                if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
                    name = name.replace(f'.{_lora.adapter_name}.', '.')
                    src_tensor = state_dict[name]
                    if 'embedding_A' in name:
                        r_saved = src_tensor.shape[1]
                        parameter.data[:, :r_saved].copy_(src_tensor)
                    elif 'embedding_B' in name:
                        r_saved = src_tensor.shape[0]
                        parameter.data[:r_saved, :].copy_(src_tensor)
                    elif '_A' in name:
                        r_saved = src_tensor.shape[0]
                        parameter.data[:r_saved, :].copy_(src_tensor)
                    elif '_B' in name:
                        r_saved = src_tensor.shape[1]
                        parameter.data[:, :r_saved].copy_(src_tensor)
                    state_key = name.replace(f'.{_lora.adapter_name}.', '.')
                    target_tensor = self._read_param_tensor(parameter)
                    if target_tensor is None:
                        continue
                    target_tensor = target_tensor.clone()
                    src_tensor = state_dict[state_key].to(dtype=target_tensor.dtype, device=target_tensor.device)
                    self._copy_rank_tensor(name, target_tensor, src_tensor)
                    self._write_param_tensor(parameter, target_tensor)
|
Contributor comment on lines +635 to +641: There is a mismatch between local and global tensors here. To fix this, you should either shard …

        if isinstance(self.module, list):
            for _module in self.module:
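One way to read the reviewer's suggestion (an illustration only, not the fix adopted in this PR): assemble the full, max-rank tensor on every rank, copy the saved slice into it, and then shard it exactly once with the parameter's own mesh and placements. A hedged sketch, assuming the parameter is an FSDP2 `DTensor`; the helper name and `copy_rank_tensor` argument are hypothetical stand-ins:

```python
import torch
from torch.distributed.tensor import DTensor, distribute_tensor

def load_saved_slice_sketch(parameter, name, saved_tensor, copy_rank_tensor):
    # `copy_rank_tensor` stands in for _copy_rank_tensor; `saved_tensor` is the
    # tenant's saved weight slice from the state dict.
    data = parameter.data
    if isinstance(data, DTensor):
        # full_tensor() is a collective; every rank in the mesh must reach this line.
        full = data.full_tensor().clone()
        copy_rank_tensor(name, full, saved_tensor.to(dtype=full.dtype, device=full.device))
        # Shard the updated global tensor exactly once, then copy shard-to-shard.
        sharded = distribute_tensor(full, data.device_mesh, data.placements)
        data.to_local().copy_(sharded.to_local())
    else:
        copy_rank_tensor(name, data, saved_tensor.to(dtype=data.dtype, device=data.device))
```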
|
|
@@ -625,15 +655,9 @@ def _get_weights(_module):
            state_dict = {}
            for name, parameter in _module.named_parameters():
                if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
                    _param = torch_util.to_local_tensor(parameter)
                    if 'embedding_A' in name:
                        _param = _param[:, :_lora.tenant_config.r]
                    elif 'embedding_B' in name:
                        _param = _param[:_lora.tenant_config.r, :]
                    elif '_A' in name:
                        _param = _param[:_lora.tenant_config.r, :]
                    elif '_B' in name:
                        _param = _param[:, :_lora.tenant_config.r]
                    _param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r)
                    if _param is None:
                        continue
                    name = name.replace(f'.{_lora.adapter_name}.', '.')
                    state_dict[name] = _param
|
Contributor comment on lines +658 to 662: In a distributed setting with FSDP2, …

            return state_dict
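The truncated comment above points at the same local-vs-global issue: under FSDP2, `_read_param_tensor` yields only this rank's shard, so slicing it to the tenant's rank and saving it would export a partial weight. A minimal sketch of an export-side reader that gathers the full tensor first (an assumption about one possible remedy, not code from this PR):

```python
from torch.distributed.tensor import DTensor

def read_full_param_sketch(parameter):
    # Hypothetical export-side counterpart to _read_param_tensor: gather the
    # complete tensor before rank-slicing, instead of using the local shard.
    tensor = parameter.data
    if isinstance(tensor, DTensor):
        # full_tensor() is a collective: every rank in the mesh must call it.
        return tensor.full_tensor()
    return tensor
```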
|
|
@@ -653,9 +677,14 @@ def _load_initial_weights(self, origin_adapter_name):
        def _load_initial_weights(_module):
            for name, parameter in _module.named_parameters():
                if pattern_A.search(name):
                    parameter.data.copy_(_lora.lora_A_weights[name])
                    local_param = self._read_param_tensor(parameter)
                    if local_param is not None:
                        value = _lora.lora_A_weights[name].to(dtype=parameter.dtype, device=local_param.device)
                        self._write_param_tensor(parameter, value)
                if pattern_B.search(name):
                    parameter.data.copy_(torch.zeros_like(parameter.data).to(parameter.data.dtype))
                    local_param = self._read_param_tensor(parameter)
                    if local_param is not None:
                        self._write_param_tensor(parameter, torch.zeros_like(local_param))

        if isinstance(self.module, list):
            for _module in self.module:
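For context on why this reset only restores `lora_A` and zero-fills `lora_B`: in the usual LoRA setup the weight delta is `B @ A`, so zeroing `B` alone makes a freshly reassigned adapter contribute nothing until it is trained. A tiny worked example (shapes are illustrative assumptions):

```python
import torch

# LoRA delta is B @ A; with B == 0 the delta is zero regardless of A.
r, in_features, out_features = 8, 32, 64
lora_A = torch.randn(r, in_features)     # restored from the saved initial weights
lora_B = torch.zeros(out_features, r)    # reset to zeros, as in _load_initial_weights
delta_w = lora_B @ lora_A                # (out_features, in_features), all zeros
assert torch.count_nonzero(delta_w) == 0
```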
|
|
Review comment: The use of `distribute_tensor` here assumes that `value` is a global tensor that needs to be sharded according to the `device_mesh` and `placements`. However, in several call sites (like `set_state_dict` and `_load_initial_weights`), the `value` passed to `_write_param_tensor` is derived from `_read_param_tensor`, which returns a local shard. Calling `distribute_tensor` on a local shard will incorrectly attempt to shard the shard again, leading to incorrect parameter values in distributed training.
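To make the reviewer's distinction concrete, here is a hedged sketch of a writer that keeps the two cases separate: copy a local shard straight into the parameter's local storage, and call `distribute_tensor` only on a genuinely global tensor. The helper name and the `value_is_local` flag are hypothetical, not part of this PR:

```python
import torch
from torch.distributed.tensor import DTensor, distribute_tensor

def write_param_tensor_sketch(parameter, value, value_is_local=True):
    # Illustrative only: separates the "local shard" path from the "global
    # tensor" path that the review comment says must not be mixed.
    if value is None:
        return
    value = value.detach().to(dtype=parameter.dtype)
    data = parameter.data
    if not isinstance(data, DTensor):
        data.copy_(value.to(data.device))
        return
    if value_is_local:
        # `value` is already this rank's shard (e.g. it came straight from
        # _read_param_tensor): write it into the local shard, no re-sharding.
        local = data.to_local()
        local.copy_(value.to(local.device))
    else:
        # `value` is the full, unsharded tensor: shard it exactly once
        # according to the parameter's mesh and placements.
        data.copy_(distribute_tensor(value.to(data.device), data.device_mesh, data.placements))
```

Either convention works on its own; the failure the comment describes arises only when a local shard is fed through the global `distribute_tensor` path.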