
Conversation

@pcmoritz
Collaborator

No description provided.

@pcmoritz pcmoritz added the tx label Jan 13, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new SkyRL-train backend for supervised training. The changes include updating project dependencies in pyproject.toml and adding the new backend implementation in skyrl-tx/tx/tinker/backends/skyrl_train.py. While this is a good starting point for the new backend, my review has identified several issues that need to be addressed. The most critical issue is in the forward_backward method, which is currently a stub and does not perform a backward pass or return actual losses, preventing any training from occurring. Other significant issues include the use of hardcoded paths and hyperparameters, potentially incorrect token padding, and breaking encapsulation by accessing private members of a library class. Addressing these points will be crucial for the backend to be functional and maintainable.

Comment on lines +147 to +168
def forward_backward(
    self, prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
    if not prepared_batch.all_input_ids:
        return {}

    batch = self._to_training_batch(prepared_batch)
    output = self._actor_group.run_method("mesh", "forward", batch)

    results = {}
    for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
        loss_fn_outputs = []
        for i in range(start_idx, end_idx):
            seq_len = len(prepared_batch.all_input_ids[i])
            loss_fn_outputs.append({
                "elementwise_loss": {"data": [0.0] * seq_len, "dtype": "float32", "shape": [seq_len]},
                "logprobs": {"data": [0.0] * seq_len, "dtype": "float32", "shape": [seq_len]},
            })
        results[request_id] = types.ForwardBackwardOutput(
            loss_fn_output_type="scalar", loss_fn_outputs=loss_fn_outputs, metrics={},
        )
    return results


critical

The forward_backward method is not correctly implemented and appears to be a stub. It calls forward on the worker (which likely only performs a forward pass and does not compute gradients), ignores the returned output, and then constructs a dummy result with zeroed-out losses. This is a critical issue as it prevents any actual gradient computation and backpropagation, meaning no training will occur. The implementation needs to be updated to:

  1. Call a worker method that performs both a forward and backward pass (e.g., ppo_train or forward_backward on the worker).
  2. Process the actual output from the workers to return correct loss values and other metrics (a rough sketch of this shape follows below).
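
A possible shape for the fix, assuming the worker exposes a combined forward/backward entry point reachable through run_method and returns per-sequence losses and logprobs. The "ppo_train" name and the output field names below are assumptions based on the suggestion above, not the verified skyrl-train API:

def forward_backward(
    self, prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
    if not prepared_batch.all_input_ids:
        return {}

    batch = self._to_training_batch(prepared_batch)
    # Run a worker method that computes gradients, not just logits; "ppo_train"
    # is the name suggested above and would need to be checked against skyrl-train.
    output = self._actor_group.run_method("mesh", "ppo_train", batch)

    results = {}
    for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
        loss_fn_outputs = []
        for i in range(start_idx, end_idx):
            seq_len = len(prepared_batch.all_input_ids[i])
            # Propagate the workers' actual per-token losses and logprobs instead of
            # zero-filled placeholders (the "per_token_loss"/"logprobs" keys are illustrative).
            loss_fn_outputs.append({
                "elementwise_loss": {"data": output["per_token_loss"][i][:seq_len], "dtype": "float32", "shape": [seq_len]},
                "logprobs": {"data": output["logprobs"][i][:seq_len], "dtype": "float32", "shape": [seq_len]},
            })
        results[request_id] = types.ForwardBackwardOutput(
            loss_fn_output_type="scalar",
            loss_fn_outputs=loss_fn_outputs,
            metrics={"loss": output.get("loss", 0.0)},
        )
    return results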

"rank": lora_config.rank if lora_config else 0,
"alpha": lora_config.alpha if lora_config else 16,
"dropout": 0,
"lora_sync_path": "/tmp/skyrl_lora_sync",


high

Hardcoding paths like /tmp/skyrl_lora_sync can lead to conflicts in multi-user environments and unexpected data loss if /tmp is cleared. This path should be made configurable, for example by adding it to SkyRLTrainBackendConfig or using Python's tempfile module to create a secure temporary directory.
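
A minimal sketch of the tempfile option (the directory prefix is illustrative; exposing the path on SkyRLTrainBackendConfig would work equally well):

import tempfile

# Create a unique, per-process sync directory instead of a fixed /tmp path.
lora_sync_path = tempfile.mkdtemp(prefix="skyrl_lora_sync_")

# ... which the LoRA config would then reference:
# "lora_sync_path": lora_sync_path,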

"micro_forward_batch_size_per_gpu": 4,
"policy_mini_batch_size": 256,
"flash_attn": True, "use_sample_packing": False,
"ckpt_path": "/tmp/skyrl_ckpts", "logger": "console",


high

Hardcoding the checkpoint path to /tmp/skyrl_ckpts can cause issues in multi-user systems and may lead to data being lost when the temporary directory is cleaned. It's better to make this path configurable, for instance, through SkyRLTrainBackendConfig.
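
As a stopgap until the option exists on the config, a user-scoped default avoids the /tmp problems; the path below is purely illustrative:

from pathlib import Path

# Hypothetical default: keep checkpoints under the user's home directory rather
# than /tmp, until a ckpt_path field is exposed on SkyRLTrainBackendConfig.
default_ckpt_path = str(Path.home() / ".cache" / "skyrl_train" / "ckpts")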


for seq, weights in zip(prepared_batch.all_input_ids, prepared_batch.all_token_weights):
    pad_len = max_len - len(seq)
    sequences.append([0] * pad_len + list(seq))


high

The padding value for sequences is hardcoded to 0. This may not be the correct pad_token_id for all models and could lead to incorrect behavior during training. This backend should have access to the model's tokenizer to use the correct pad_token_id for padding.
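
A sketch of what that could look like, assuming the backend holds a Hugging Face tokenizer for the base model (the self._tokenizer attribute is hypothetical):

# Use the model's pad token instead of a hardcoded 0; fall back to the EOS token
# when the tokenizer defines no pad token.
pad_id = self._tokenizer.pad_token_id
if pad_id is None:
    pad_id = self._tokenizer.eos_token_id

for seq, weights in zip(prepared_batch.all_input_ids, prepared_batch.all_token_weights):
    pad_len = max_len - len(seq)
    sequences.append([pad_id] * pad_len + list(seq))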

Comment on lines 10 to 11
# requires-python = ">=3.11"
requires-python = ">=3.12, <3.13"


medium

The commented-out requires-python line is now dead code and should be removed to keep the project configuration clean.

Suggested change
- # requires-python = ">=3.11"
- requires-python = ">=3.12, <3.13"
+ requires-python = ">=3.12, <3.13"

Comment on lines 44 to 75
return OmegaConf.create({
    "trainer": {
        "placement": {"colocate_all": False, "policy_num_nodes": 1, "policy_num_gpus_per_node": config.num_gpus},
        "strategy": "fsdp2",
        "policy": {
            "model": {"path": base_model, "lora": lora_cfg},
            "optimizer_config": {
                "lr": 1e-5, "adam_betas": [0.9, 0.999], "weight_decay": 0.01,
                "max_grad_norm": 1.0, "offload_after_step": False, "num_warmup_steps": 0, "scheduler": "constant_with_warmup",
            },
            "fsdp_config": {"cpu_offload": False, "reshard_after_forward": True, "fsdp_size": -1},
            "sequence_parallel_size": 1, "use_torch_compile": False, "record_memory": False, "model_config_kwargs": {},
        },
        "algorithm": {
            "policy_loss_type": "regular", "loss_reduction": "token_mean",
            "eps_clip_low": 0.2, "eps_clip_high": 0.2,
            "use_kl_loss": False, "kl_loss_coef": 0.0, "use_entropy_loss": False, "entropy_loss_coef": 0.0,
        },
        "gradient_checkpointing": True, "gradient_checkpointing_use_reentrant": False,
        "seed": 42, "bf16": True,
        "micro_train_batch_size_per_gpu": config.micro_train_batch_size_per_gpu,
        "micro_forward_batch_size_per_gpu": 4,
        "policy_mini_batch_size": 256,
        "flash_attn": True, "use_sample_packing": False,
        "ckpt_path": "/tmp/skyrl_ckpts", "logger": "console",
    },
    "generator": {
        "n_samples_per_prompt": 1,
        "sampling_params": {"temperature": 1.0},
        "weight_transfer_threshold_cuda_ipc_GB": 1.0,
    },
})


medium

Many training hyperparameters (e.g., learning rate, optimizer settings, loss configuration) are hardcoded within the _build_config function. This significantly reduces the flexibility of the backend, as users cannot easily tune these parameters for different experiments. It is recommended to expose these settings through the SkyRLTrainBackendConfig class.
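
A minimal sketch of an extended config that surfaces the currently hardcoded values; field names and defaults here are illustrative, not the PR's actual SkyRLTrainBackendConfig definition:

from dataclasses import dataclass

# Illustrative only: expose the tunables that _build_config currently hardcodes.
@dataclass
class SkyRLTrainBackendConfig:
    num_gpus: int = 1
    micro_train_batch_size_per_gpu: int = 1
    micro_forward_batch_size_per_gpu: int = 4
    policy_mini_batch_size: int = 256
    learning_rate: float = 1e-5
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    num_warmup_steps: int = 0
    lr_scheduler: str = "constant_with_warmup"
    seed: int = 42
    ckpt_path: str = "/tmp/skyrl_ckpts"

# _build_config would then read these instead of literals, e.g.
#   "lr": config.learning_rate, "weight_decay": config.weight_decay,
#   "ckpt_path": config.ckpt_path, ...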

Comment on lines +192 to +197
    ray.get([actor.save_checkpoint.remote(output_path) for actor in self._actor_group._actor_handlers])

def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
    if model_id != self._model_id:
        raise ValueError(f"Model {model_id} not found")
    ray.get([actor.load_checkpoint.remote(Path(checkpoint_path)) for actor in self._actor_group._actor_handlers])


medium

Accessing the private member _actor_handlers of PPORayActorGroup breaks encapsulation and makes the code dependent on the internal implementation of the skyrl-train library. This could lead to breakages if the library is updated. It would be more robust to use a public API from PPORayActorGroup for this purpose, or request one if it doesn't exist.
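
One possible direction is to reuse the run_method dispatch the backend already relies on for forward passes; whether run_method and the "mesh" mode are appropriate for checkpointing would need to be confirmed against skyrl-train, or a public API requested upstream:

def save_checkpoint(self, output_path, model_id: str) -> None:
    if model_id != self._model_id:
        raise ValueError(f"Model {model_id} not found")
    # Dispatch through the actor group's helper rather than its private
    # _actor_handlers list; the "mesh" mode is copied from the forward call
    # elsewhere in this backend and is an assumption here.
    self._actor_group.run_method("mesh", "save_checkpoint", output_path)

def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
    if model_id != self._model_id:
        raise ValueError(f"Model {model_id} not found")
    self._actor_group.run_method("mesh", "load_checkpoint", Path(checkpoint_path))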

