[tx] [WIP] Add SkyRL-train backend #871
Conversation
Code Review
This pull request introduces a new SkyRL-train backend for supervised training. The changes include updating project dependencies in pyproject.toml and adding the new backend implementation in skyrl-tx/tx/tinker/backends/skyrl_train.py. While this is a good starting point for the new backend, my review has identified several issues that need to be addressed. The most critical issue is in the forward_backward method, which is currently a stub and does not perform a backward pass or return actual losses, preventing any training from occurring. Other significant issues include the use of hardcoded paths and hyperparameters, potentially incorrect token padding, and breaking encapsulation by accessing private members of a library class. Addressing these points will be crucial for the backend to be functional and maintainable.
```python
def forward_backward(
    self, prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
    if not prepared_batch.all_input_ids:
        return {}

    batch = self._to_training_batch(prepared_batch)
    output = self._actor_group.run_method("mesh", "forward", batch)

    results = {}
    for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
        loss_fn_outputs = []
        for i in range(start_idx, end_idx):
            seq_len = len(prepared_batch.all_input_ids[i])
            loss_fn_outputs.append({
                "elementwise_loss": {"data": [0.0] * seq_len, "dtype": "float32", "shape": [seq_len]},
                "logprobs": {"data": [0.0] * seq_len, "dtype": "float32", "shape": [seq_len]},
            })
        results[request_id] = types.ForwardBackwardOutput(
            loss_fn_output_type="scalar", loss_fn_outputs=loss_fn_outputs, metrics={},
        )
    return results
```
The `forward_backward` method is not correctly implemented and appears to be a stub. It calls `forward` on the worker (which likely only performs a forward pass and does not compute gradients), ignores the returned output, and then constructs a dummy result with zeroed-out losses. This is a critical issue as it prevents any actual gradient computation and backpropagation, meaning no training will occur. The implementation needs to be updated to:
- Call a worker method that performs both a forward and backward pass (e.g., `ppo_train` or `forward_backward` on the worker).
- Process the actual output from the workers to return correct loss values and other metrics.
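A minimal sketch of that direction, assuming the workers expose a combined `forward_backward` method and that its output contains per-sequence lists of token losses and logprobs; the method name, dispatch string, and output field names are assumptions about the skyrl-train worker API, not its actual interface:

```python
def forward_backward(
    self, prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
    if not prepared_batch.all_input_ids:
        return {}

    batch = self._to_training_batch(prepared_batch)
    # Hypothetical call: dispatch a worker method that runs the forward *and*
    # backward pass and returns per-token losses/logprobs.
    outputs = self._actor_group.run_method("mesh", "forward_backward", batch)

    results = {}
    for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
        loss_fn_outputs = []
        for i in range(start_idx, end_idx):
            seq_len = len(prepared_batch.all_input_ids[i])
            # Propagate the workers' actual values instead of zero-filled stubs.
            elementwise_loss = list(outputs["elementwise_loss"][i])[:seq_len]
            logprobs = list(outputs["logprobs"][i])[:seq_len]
            loss_fn_outputs.append({
                "elementwise_loss": {"data": elementwise_loss, "dtype": "float32", "shape": [seq_len]},
                "logprobs": {"data": logprobs, "dtype": "float32", "shape": [seq_len]},
            })
        results[request_id] = types.ForwardBackwardOutput(
            loss_fn_output_type="scalar", loss_fn_outputs=loss_fn_outputs, metrics={},
        )
    return results
```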
| "rank": lora_config.rank if lora_config else 0, | ||
| "alpha": lora_config.alpha if lora_config else 16, | ||
| "dropout": 0, | ||
| "lora_sync_path": "/tmp/skyrl_lora_sync", |
Hardcoding paths like /tmp/skyrl_lora_sync can lead to conflicts in multi-user environments and unexpected data loss if /tmp is cleared. This path should be made configurable, for example by adding it to SkyRLTrainBackendConfig or using Python's tempfile module to create a secure temporary directory.
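A minimal sketch of the suggested fix, assuming a hypothetical `lora_sync_path` field on `SkyRLTrainBackendConfig` with a `tempfile` fallback:

```python
import tempfile

# Use an explicit path from the backend config when given; otherwise create a
# unique per-process directory instead of a shared, hardcoded /tmp path.
lora_sync_path = getattr(config, "lora_sync_path", None) or tempfile.mkdtemp(prefix="skyrl_lora_sync_")
```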
| "micro_forward_batch_size_per_gpu": 4, | ||
| "policy_mini_batch_size": 256, | ||
| "flash_attn": True, "use_sample_packing": False, | ||
| "ckpt_path": "/tmp/skyrl_ckpts", "logger": "console", |
```python
for seq, weights in zip(prepared_batch.all_input_ids, prepared_batch.all_token_weights):
    pad_len = max_len - len(seq)
    sequences.append([0] * pad_len + list(seq))
```
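The review summary flags this padding as potentially incorrect: sequences are left-padded with token id 0, which may be a real vocabulary token, and the corresponding token weights are not visibly padded to match. A minimal sketch of one alternative, assuming a `pad_token_id` taken from the tokenizer and a parallel `weight_rows` list (both names are illustrative):

```python
for seq, weights in zip(prepared_batch.all_input_ids, prepared_batch.all_token_weights):
    pad_len = max_len - len(seq)
    # Pad with the tokenizer's pad token rather than a hardcoded 0, and give
    # padded positions zero weight so they do not contribute to the loss.
    sequences.append([pad_token_id] * pad_len + list(seq))
    weight_rows.append([0.0] * pad_len + list(weights))
```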
skyrl-tx/pyproject.toml (outdated):
```toml
# requires-python = ">=3.11"
requires-python = ">=3.12, <3.13"
```
```python
return OmegaConf.create({
    "trainer": {
        "placement": {"colocate_all": False, "policy_num_nodes": 1, "policy_num_gpus_per_node": config.num_gpus},
        "strategy": "fsdp2",
        "policy": {
            "model": {"path": base_model, "lora": lora_cfg},
            "optimizer_config": {
                "lr": 1e-5, "adam_betas": [0.9, 0.999], "weight_decay": 0.01,
                "max_grad_norm": 1.0, "offload_after_step": False, "num_warmup_steps": 0, "scheduler": "constant_with_warmup",
            },
            "fsdp_config": {"cpu_offload": False, "reshard_after_forward": True, "fsdp_size": -1},
            "sequence_parallel_size": 1, "use_torch_compile": False, "record_memory": False, "model_config_kwargs": {},
        },
        "algorithm": {
            "policy_loss_type": "regular", "loss_reduction": "token_mean",
            "eps_clip_low": 0.2, "eps_clip_high": 0.2,
            "use_kl_loss": False, "kl_loss_coef": 0.0, "use_entropy_loss": False, "entropy_loss_coef": 0.0,
        },
        "gradient_checkpointing": True, "gradient_checkpointing_use_reentrant": False,
        "seed": 42, "bf16": True,
        "micro_train_batch_size_per_gpu": config.micro_train_batch_size_per_gpu,
        "micro_forward_batch_size_per_gpu": 4,
        "policy_mini_batch_size": 256,
        "flash_attn": True, "use_sample_packing": False,
        "ckpt_path": "/tmp/skyrl_ckpts", "logger": "console",
    },
    "generator": {
        "n_samples_per_prompt": 1,
        "sampling_params": {"temperature": 1.0},
        "weight_transfer_threshold_cuda_ipc_GB": 1.0,
    },
})
```
Many training hyperparameters (e.g., learning rate, optimizer settings, loss configuration) are hardcoded within the _build_config function. This significantly reduces the flexibility of the backend, as users cannot easily tune these parameters for different experiments. It is recommended to expose these settings through the SkyRLTrainBackendConfig class.
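A minimal sketch of what exposing a few of these knobs could look like; the field names below are illustrative assumptions and do not reflect the current `SkyRLTrainBackendConfig` definition:

```python
from dataclasses import dataclass

@dataclass
class SkyRLTrainBackendConfig:
    # Fields already referenced by _build_config.
    num_gpus: int = 1
    micro_train_batch_size_per_gpu: int = 1
    # Hypothetical additions for values that are currently hardcoded.
    micro_forward_batch_size_per_gpu: int = 4
    policy_mini_batch_size: int = 256
    learning_rate: float = 1e-5
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    ckpt_path: str = "/tmp/skyrl_ckpts"
    lora_sync_path: str | None = None
```

`_build_config` would then read these from `config` (e.g. `"lr": config.learning_rate`) instead of literals.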
```python
        ray.get([actor.save_checkpoint.remote(output_path) for actor in self._actor_group._actor_handlers])

    def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
        if model_id != self._model_id:
            raise ValueError(f"Model {model_id} not found")
        ray.get([actor.load_checkpoint.remote(Path(checkpoint_path)) for actor in self._actor_group._actor_handlers])
```
Accessing the private member _actor_handlers of PPORayActorGroup breaks encapsulation and makes the code dependent on the internal implementation of the skyrl-train library. This could lead to breakages if the library is updated. It would be more robust to use a public API from PPORayActorGroup for this purpose, or request one if it doesn't exist.
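A minimal sketch of the alternative, assuming `run_method` (used for the `forward` call above) can dispatch arbitrary worker methods; whether it accepts `load_checkpoint` and which dispatch mode applies are assumptions about the skyrl-train API:

```python
def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
    if model_id != self._model_id:
        raise ValueError(f"Model {model_id} not found")
    # Dispatch through the group's public run_method instead of iterating the
    # private _actor_handlers list (dispatch mode and method name are assumptions).
    self._actor_group.run_method("mesh", "load_checkpoint", Path(checkpoint_path))
```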