Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions roll/distributed/executor/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def _create_workers(self):
if "ROLL_LOG_DIR" in os.environ:
env_vars["ROLL_LOG_DIR"] = os.environ["ROLL_LOG_DIR"]
env_vars.update(self.worker_config.system_envs)
if current_platform.is_npu():
env_vars["HCCL_HOST_SOCKET_PORT_RANGE"] = "auto"
env_vars["HCCL_NPU_SOCKET_PORT_RANGE"] = "auto"

runtime_env = RuntimeEnv(env_vars=env_vars)
self.worker_config.resource_placement_groups = pgs
Expand Down
28 changes: 23 additions & 5 deletions roll/third_party/deepspeed/model_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,26 @@ def _gather_weights(is_zero3, named_params):
return [(n, p.data) for n, p in named_params]


def gather_deepspeed_weights(model, ds_config, buffer_size):
def gather_deepspeed_weights(model, ds_config, buffer_size, is_lora=False):
is_zero3 = ds_config.is_zero3()
named_params = [(name, param) for name, param in model.named_parameters()]
if is_lora:
if not is_zero3:
from peft.utils import get_peft_model_state_dict
lora_state_dict = get_peft_model_state_dict(model)
named_params = [(name, param) for name, param in lora_state_dict.items()]
else:
adapter_name = "default"
state_dict = model.state_dict()
lora_state_dict = {k: state_dict[k] for k in state_dict if ("lora_" in k and adapter_name in k)}
named_params = []
for name, param in lora_state_dict.items():
clean_name = name.replace(f".{adapter_name}", "")
if clean_name.startswith("base_model.model."):
clean_name = clean_name[len("base_model.model."):]
named_params.append((clean_name, model.get_parameter(name)))
del lora_state_dict
else:
named_params = [(name, param) for name, param in model.named_parameters()]

waiting_params, waiting_params_size = [], 0
for name, param in named_params:
Expand Down Expand Up @@ -150,7 +167,7 @@ def _setup_broadcast_group(self):
def _colocated_model_update(self):
refs = []
for named_weights in gather_deepspeed_weights(
self.model, self.ds_config, buffer_size=self._model_update_buffer_size
self.model, self.ds_config, buffer_size=self._model_update_buffer_size, is_lora=self.is_lora
):
serialized_tensors = serialize_named_weights(
named_weights, infer_strategy=self.infer_worker_config.strategy_args.strategy_name
Expand All @@ -167,7 +184,7 @@ def _colocated_model_update(self):
ray.get(refs)
refs = []
if co_infer_rank == 0 and self._co_infer_worker is not None:
refs.append(self._co_infer_worker.update_parameter_in_bucket.remote(infer_parallel_tensors))
refs.append(self._co_infer_worker.update_parameter_in_bucket.remote(infer_parallel_tensors, is_lora=self.is_lora))
if self._broadcast_workers:
refs.extend(self._broadcast_to_infer_workers(named_weights))
if refs:
Expand All @@ -183,6 +200,7 @@ def _broadcast_to_infer_workers(self, named_weights) -> list[ray.ObjectRef]:
names=[n for n, _ in named_weights],
dtypes=[w.dtype for _, w in named_weights],
shapes=[w.shape for _, w in named_weights],
is_lora=self.is_lora,
)
for worker in self._broadcast_workers
]
Expand All @@ -198,7 +216,7 @@ def _broadcast_to_infer_workers(self, named_weights) -> list[ray.ObjectRef]:
def _separated_model_update(self):
logger.info(f"start broadcast model update {self.model_update_group_name}")
for named_weights in gather_deepspeed_weights(
self.model, self.ds_config, buffer_size=self._model_update_buffer_size
self.model, self.ds_config, buffer_size=self._model_update_buffer_size, is_lora=self.is_lora
):
refs = self._broadcast_to_infer_workers(named_weights)
ray.get(refs)
Expand Down