103 changes: 66 additions & 37 deletions src/twinkle/model/multi_lora.py
@@ -5,6 +5,7 @@
from dataclasses import dataclass, field
from peft import LoraConfig, PeftModel, get_peft_model
from peft.tuners.lora import Embedding, Linear, LoraLayer
from torch.distributed.tensor import distribute_tensor
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Union

@@ -42,6 +43,49 @@ def _get_available_lora(self) -> Optional[LoraTenant]:
return _lora
return None

def _read_param_tensor(self, parameter):
return torch_util.to_local_tensor(parameter)

def _write_param_tensor(self, parameter, value):
if value is None:
return
value = value.detach().to(dtype=parameter.dtype)
if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'):
value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements)
Contributor (high):

The use of distribute_tensor here assumes that value is a global tensor that needs to be sharded according to the device_mesh and placements. However, at several call sites (such as set_state_dict and _load_initial_weights), the value passed to _write_param_tensor is derived from _read_param_tensor, which returns a local shard. Calling distribute_tensor on a local shard will incorrectly shard that shard a second time, producing corrupted parameter values in distributed training.
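
A minimal sketch of one possible fix, assuming FSDP2-style DTensor parameters; the `is_global` flag and this exact signature are hypothetical additions, not part of this PR:

```python
from torch.distributed.tensor import DTensor, distribute_tensor

def _write_param_tensor(self, parameter, value, is_global=True):
    if value is None:
        return
    value = value.detach().to(dtype=parameter.dtype)
    if isinstance(parameter.data, DTensor):
        mesh = parameter.data.device_mesh
        placements = parameter.data.placements
        if is_global:
            # A full tensor: shard it across the mesh.
            value = distribute_tensor(value.to(parameter.device), mesh, placements)
        else:
            # Already this rank's local shard: wrap it without resharding.
            value = DTensor.from_local(value.to(parameter.device), mesh, placements)
    else:
        value = value.to(parameter.device)
    parameter.data.copy_(value)
```

Callers that pass a tensor read back from _read_param_tensor would set is_global=False.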

else:
value = value.to(parameter.device)
parameter.data.copy_(value)

@staticmethod
def _slice_rank_tensor(name: str, tensor, rank: int):
if tensor is None:
return None
if 'embedding_A' in name:
return tensor[:, :rank]
if 'embedding_B' in name:
return tensor[:rank, :]
if '_A' in name:
return tensor[:rank, :]
if '_B' in name:
return tensor[:, :rank]
return tensor

@staticmethod
def _copy_rank_tensor(name: str, target, value):
if target is None or value is None:
return None
if 'embedding_A' in name:
target[:, :value.shape[1]].copy_(value)
elif 'embedding_B' in name:
target[:value.shape[0], :].copy_(value)
elif '_A' in name:
target[:value.shape[0], :].copy_(value)
elif '_B' in name:
target[:, :value.shape[1]].copy_(value)
else:
target.copy_(value)
return target

def _count_available_loras(self):
return len([_lora for _lora in self.loras if _lora.tenant_adapter_name is None])

@@ -472,7 +516,7 @@ def save_initial_weights(self):
def _store_weights(_module):
for name, parameter in _module.named_parameters():
if pattern.search(name):
lora_tenant.lora_A_weights[name] = parameter.data.clone().to('cpu')
lora_tenant.lora_A_weights[name] = self._read_param_tensor(parameter).clone().to('cpu')

if isinstance(self.module, list):
for _module in self.module:
@@ -572,17 +616,9 @@ def save_lora_converter(self, name, parameter, adapter_name):
# patching makes the bridge skip non-target modules entirely), so we
# only check the adapter-name / weight pattern here.
if re.search(rf'\.lora_\w+\.({adapter_name}|weight)', name):
_param = torch_util.to_local_tensor(parameter)
if _param is None:
pass
elif 'embedding_A' in name:
_param = _param[:, :_lora.tenant_config.r].clone()
elif 'embedding_B' in name:
_param = _param[:_lora.tenant_config.r, :].clone()
elif '_A' in name:
_param = _param[:_lora.tenant_config.r, :].clone()
elif '_B' in name:
_param = _param[:, :_lora.tenant_config.r].clone()
_param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r)
if _param is not None:
_param = _param.clone()
name = name.replace(f'.{_lora.adapter_name}.', '.')
return name, _param
Contributor (high), on lines +619 to 623:

As in get_state_dict, save_lora_converter returns a local shard of the parameter. When saving the model, this produces a checkpoint containing sharded LoRA weights, which is incompatible with standard PEFT loaders. LoRA weights are typically small enough to be gathered and saved as full tensors even in FSDP environments.
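
One hedged way to do that gather; `_read_param_full` is a hypothetical helper and assumes the sharded parameters are DTensors:

```python
from torch.distributed.tensor import DTensor

def _read_param_full(parameter):
    # All-gather a DTensor parameter into a full tensor on every rank;
    # plain tensors pass through unchanged.
    data = parameter.data
    if isinstance(data, DTensor):
        return data.full_tensor()
    return data
```

Since LoRA matrices are only r x hidden_size, the all-gather is cheap relative to the base model.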

else:
@@ -595,20 +631,14 @@ def set_state_dict(self, tenant_adapter_name, state_dict):
def _load_weights(_module):
for name, parameter in _module.named_parameters():
if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
name = name.replace(f'.{_lora.adapter_name}.', '.')
src_tensor = state_dict[name]
if 'embedding_A' in name:
r_saved = src_tensor.shape[1]
parameter.data[:, :r_saved].copy_(src_tensor)
elif 'embedding_B' in name:
r_saved = src_tensor.shape[0]
parameter.data[:r_saved, :].copy_(src_tensor)
elif '_A' in name:
r_saved = src_tensor.shape[0]
parameter.data[:r_saved, :].copy_(src_tensor)
elif '_B' in name:
r_saved = src_tensor.shape[1]
parameter.data[:, :r_saved].copy_(src_tensor)
state_key = name.replace(f'.{_lora.adapter_name}.', '.')
target_tensor = self._read_param_tensor(parameter)
if target_tensor is None:
continue
target_tensor = target_tensor.clone()
src_tensor = state_dict[state_key].to(dtype=target_tensor.dtype, device=target_tensor.device)
self._copy_rank_tensor(name, target_tensor, src_tensor)
self._write_param_tensor(parameter, target_tensor)
Contributor (high), on lines +635 to +641:

There is a mismatch between local and global tensors here. target_tensor is a local shard (from _read_param_tensor), while src_tensor is a global tensor (from state_dict).

  1. _copy_rank_tensor (line 640) will fail with a shape mismatch if the dimension being copied is sharded by FSDP (e.g., num_embeddings for embeddings or out_features for linear layers).
  2. _write_param_tensor (line 641) will then attempt to shard this local shard again as discussed in the previous comment.

To fix this, you should either shard src_tensor to match the local shard's placements before copying, or perform the copy on global tensors (on CPU) and then use _write_param_tensor to shard the result.
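
A sketch of the second option (copy on global tensors, then reshard), again assuming DTensor parameters; `_load_one_lora_param` is a hypothetical helper name:

```python
from torch.distributed.tensor import DTensor

def _load_one_lora_param(self, name, parameter, src_tensor):
    data = parameter.data
    # Work on the full tensor so the rank-slice indices refer to global shapes.
    target = data.full_tensor() if isinstance(data, DTensor) else data.clone()
    src = src_tensor.to(dtype=target.dtype, device=target.device)
    self._copy_rank_tensor(name, target, src)
    # _write_param_tensor then re-shards the global result onto the mesh.
    self._write_param_tensor(parameter, target)
```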


if isinstance(self.module, list):
for _module in self.module:
@@ -625,15 +655,9 @@ def _get_weights(_module):
state_dict = {}
for name, parameter in _module.named_parameters():
if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
_param = torch_util.to_local_tensor(parameter)
if 'embedding_A' in name:
_param = _param[:, :_lora.tenant_config.r]
elif 'embedding_B' in name:
_param = _param[:_lora.tenant_config.r, :]
elif '_A' in name:
_param = _param[:_lora.tenant_config.r, :]
elif '_B' in name:
_param = _param[:, :_lora.tenant_config.r]
_param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r)
if _param is None:
continue
name = name.replace(f'.{_lora.adapter_name}.', '.')
state_dict[name] = _param
Contributor (high), on lines +658 to 662:

In a distributed setting with FSDP2, _read_param_tensor returns a local shard. Consequently, get_state_dict returns a dictionary of sharded tensors. Since this method is decorated with @remote_function(collect='first') in the model class, only rank 0's local shards will be returned to the caller. This results in an incomplete and unusable state dict for the LoRA adapter. You should gather the shards into a global tensor before slicing and returning them.
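
For illustration, `_get_weights` with the gather applied, reusing the hypothetical `_read_param_full` helper sketched above:

```python
def _get_weights(_module):
    state_dict = {}
    for name, parameter in _module.named_parameters():
        if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
            full = _read_param_full(parameter)  # full tensor on every rank
            _param = self._slice_rank_tensor(name, full, _lora.tenant_config.r)
            if _param is None:
                continue
            state_dict[name.replace(f'.{_lora.adapter_name}.', '.')] = _param.to('cpu')
    return state_dict
```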

return state_dict
@@ -653,9 +677,14 @@ def _load_initial_weights(self, origin_adapter_name):
def _load_initial_weights(_module):
for name, parameter in _module.named_parameters():
if pattern_A.search(name):
parameter.data.copy_(_lora.lora_A_weights[name])
local_param = self._read_param_tensor(parameter)
if local_param is not None:
value = _lora.lora_A_weights[name].to(dtype=parameter.dtype, device=local_param.device)
self._write_param_tensor(parameter, value)
if pattern_B.search(name):
parameter.data.copy_(torch.zeros_like(parameter.data).to(parameter.data.dtype))
local_param = self._read_param_tensor(parameter)
if local_param is not None:
self._write_param_tensor(parameter, torch.zeros_like(local_param))

if isinstance(self.module, list):
for _module in self.module:
42 changes: 27 additions & 15 deletions src/twinkle/model/transformers/multi_lora_transformers.py
@@ -1,5 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
import torch.distributed as dist
import transformers
from peft import LoraConfig, PeftConfig, PeftModel, load_peft_weights
from torch.optim import Optimizer
@@ -15,7 +16,6 @@
from twinkle.metric import Metric
from twinkle.processor import InputProcessor
from ..multi_lora import MultiLora
from .strategy import AccelerateStrategy
from .transformers import OptimizerGroup, TransformersModel


@@ -29,17 +29,28 @@ def __init__(
config: Optional[PretrainedConfig] = None,
device_mesh: Optional[DeviceMesh] = None,
mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
strategy: Literal['accelerate', 'native_fsdp'] = 'accelerate',
ddp_config: Dict[str, Any] = None,
fsdp_config: Dict[str, Any] = None,
grad_scaler_config: Dict[str, Any] = None,
memory_efficient_init: bool = False,
max_loras: int = 5,
max_r: int = 32,
max_length: int = 8192,
target_modules: Union[List[str], str] = 'all-linear',
**kwargs):
assert device_mesh.fsdp_world_size <= 0, f'MultiLora does not support FSDP, current is: {str(device_mesh)}'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
self._try_init_process_group()
super(PreTrainedModel, self).__init__()
model_id = HubOperation.download_model(model_id)
self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
self._fsdp_config = dict(fsdp_config or {})
self._ddp_config = ddp_config or {}
self._memory_efficient_init = memory_efficient_init
self._decide_strategy(strategy)
self.grad_scaler_config = grad_scaler_config
if model_id is not None:
model_id = HubOperation.download_model(model_id)
self.model_id = model_id
if config is None:
from transformers import AutoConfig
@@ -52,24 +63,20 @@
model_cls = AutoModelForCausalLM
if isinstance(model_cls, str):
model_cls = getattr(transformers, model_cls)
self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs)
self.model_id = model_id
if model_id is None:
self.model = model_cls.from_config(self.hf_config, **kwargs)
else:
with self.strategy.pretrained_load_context():
self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs)
self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
self.grad_scaler_config = grad_scaler_config
self._default_tokenizer = None
self._model_wrapped = False
self.sp_strategy = None
# Initialize expert parallel attributes (required by set_optimizer in TransformersModel)
self._expert_parallel_config = None
self._enable_expert_parallel = False
self._expert_parallel_applied = False
self.optimizer_group: Dict[str, OptimizerGroup] = {}
self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length)
self.model.gradient_checkpointing_enable()
self.model = self.multi_adapter.patch(self.model, target_modules=target_modules)
self.strategy = AccelerateStrategy(mixed_precision=mixed_precision, device_mesh=None)
self.model = self.strategy.wrap_model(self.model)
self.multi_adapter.save_initial_weights()
# Active group for compatibility with single adapter
self.active_group = None
@@ -100,7 +107,7 @@ def unregister_mm_forward_hook(self, optimizer_group: OptimizerGroup):
pass

def _lazy_wrap_model(self):
pass
return super()._lazy_wrap_model()

@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
@@ -232,7 +239,10 @@ def get_state_dict(self, **kwargs):
def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs):
self._check_adapter_valid(kwargs.get('adapter_name'))
with self.multi_adapter.save_context(kwargs.get('adapter_name')):
return super().save(name, output_dir, interval, **kwargs)
checkpoint_dir = super().save(name, output_dir, interval, **kwargs)
if dist.is_initialized():
dist.barrier()
return checkpoint_dir

@remote_function()
def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
@@ -256,6 +266,8 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):

if load_optimizer:
self._restore_training_state(checkpoint_dir, adapter_name=adapter_name)
if dist.is_initialized():
dist.barrier()

@remote_function()
def set_grad_scaler(self, **kwargs):
4 changes: 2 additions & 2 deletions src/twinkle/model/transformers/transformers.py
@@ -1106,7 +1106,7 @@ def _restore_training_state(self, checkpoint_dir, *, adapter_name=''):

return trainer_state

@remote_function()
@remote_function(dispatch='all', collect='first', sync=True)
def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs):
adapter_name = kwargs.get('adapter_name', '')

@@ -1286,7 +1286,7 @@ def _get_trainable_parameters_example(self, adapter_name, model):
trainable_param_names = '\n'.join(trainable_param_names)
return trainable_param_names

@remote_function(execute='first', lazy_collect=False)
@remote_function(dispatch='all', collect='first', lazy_collect=False)
def get_train_configs(self, **kwargs) -> str:
expr = ''
adapter_name = kwargs.pop('adapter_name', self._get_default_group())