Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
159 commits
Select commit Hold shift + click to select a range
0e2d986
plan
ZhiyuLi-Nvidia May 1, 2026
03fd24e
plan: align Stage 4 with rl-arena/verl 1-hop pattern
ZhiyuLi-Nvidia May 2, 2026
7aedea8
feat(data-plane): TransferQueue integration for GRPO with driver-side…
ZhiyuLi-Nvidia May 4, 2026
9c17127
refactor(data-plane): extract driver-side balanced packing into presh…
ZhiyuLi-Nvidia May 5, 2026
f1a995b
feat(data-plane): AsyncTrajectoryCollector writes rollouts to TQ when…
ZhiyuLi-Nvidia May 5, 2026
ec7df8f
feat(data-plane): wire async-on-TQ end-to-end with driver-side balanc…
ZhiyuLi-Nvidia May 5, 2026
49db9bb
fix(data-plane): preserve sample order and FLOPs semantics on @dp_dis…
ZhiyuLi-Nvidia May 5, 2026
130f713
feat(data-plane): grpo_sync routes logprob/ref-logprob through @dp_di…
ZhiyuLi-Nvidia May 5, 2026
5e26441
refactor(data-plane): replace @dp_dispatch with TQPolicy subclass; ad…
ZhiyuLi-Nvidia May 5, 2026
bd714c8
fix(data-plane): VLM extras, async fan-out, cleanup-on-failure
ZhiyuLi-Nvidia May 5, 2026
f2a8ba3
docs(data-plane): add API lifecycle doc with verl comparison
ZhiyuLi-Nvidia May 7, 2026
680e5dd
feat(data-plane): sync 1-hop trajectory collector + per-sample key li…
ZhiyuLi-Nvidia May 7, 2026
8b297f8
refactor(data-plane): extract make_actor_runtime_env, fix N² list copy
ZhiyuLi-Nvidia May 7, 2026
941b54d
feat(data-plane): jagged tensors on TQ wire + naming/factory cleanup
ZhiyuLi-Nvidia May 7, 2026
975bd05
refactor(data-plane): KVBatchMeta.subset/slice/concat methods
ZhiyuLi-Nvidia May 7, 2026
186e792
Mooncake cpu backend
ZhiyuLi-Nvidia May 7, 2026
fabb9a0
Readability Refactor
ZhiyuLi-Nvidia May 8, 2026
eb643c0
wip test mooncake
ZhiyuLi-Nvidia May 8, 2026
b32ffa3
refactor(data-plane): drop dead set_wire_format/_PACK_JAGGED + adapte…
ZhiyuLi-Nvidia May 8, 2026
68454ff
refactor(ray.sub): drop NETWORK_INIT_CMDS — MC_TCP_BIND_ADDRESS suffices
ZhiyuLi-Nvidia May 8, 2026
2486160
docs(data-plane): consolidate README; drop stale plan/verl refs
ZhiyuLi-Nvidia May 8, 2026
c2c53d6
feat(data-plane): non-tensor object support on TQ wire
ZhiyuLi-Nvidia May 8, 2026
2ba6ef2
feat(grpo-sync): equivalency fixes + content via TQ object column
ZhiyuLi-Nvidia May 9, 2026
72835c6
style: fix ruff lint errors and apply ruff format
ZhiyuLi-Nvidia May 9, 2026
74ddeba
style: apply pre-commit auto-fixes (ruff)
ZhiyuLi-Nvidia May 9, 2026
3b9b827
chore(pyrefly): whitelist all new data_plane files + fix type errors
ZhiyuLi-Nvidia May 9, 2026
9dbca27
remove unnecessary script
ZhiyuLi-Nvidia May 9, 2026
3533d53
feat(data-plane): decompose message_log at wire boundary
ZhiyuLi-Nvidia May 12, 2026
4a8096c
refactor(data-plane): rename DataPlaneClient.get_meta → claim_meta
ZhiyuLi-Nvidia May 12, 2026
87414f8
docs(data-plane): tighten DataPlaneClient boundary docstring
ZhiyuLi-Nvidia May 12, 2026
9ca449f
fix(data-plane): treat DataPlaneConfig.enabled as required field
ZhiyuLi-Nvidia May 12, 2026
a123ffe
docs(data-plane): make build_data_plane_client docstring backend-agno…
ZhiyuLi-Nvidia May 12, 2026
92c8244
refactor(data-plane): promote codec imports to module top-level
ZhiyuLi-Nvidia May 12, 2026
efdd82c
refactor(data-plane): rename driver_io → column_io
ZhiyuLi-Nvidia May 12, 2026
0d92835
refactor(data-plane): validate dp_world at TQPolicy config time
ZhiyuLi-Nvidia May 12, 2026
861294f
refactor(data-plane): centralize packing-meta keys in schema.py
ZhiyuLi-Nvidia May 13, 2026
cb36ef6
refactor(data-plane): drop redundant dp_world assert in shard_meta_fo…
ZhiyuLi-Nvidia May 13, 2026
dbbfd19
refactor(data-plane): move DP_SEED_FIELDS to schema.py as DP_TRAIN_FI…
ZhiyuLi-Nvidia May 13, 2026
a646aeb
fix(data-plane): reject empty meta in shard_meta_for_dp
ZhiyuLi-Nvidia May 13, 2026
81cbcd7
refactor(data-plane): print_event → log_event via stdlib logging
ZhiyuLi-Nvidia May 13, 2026
808f165
style(data-plane): match repo logger naming convention
ZhiyuLi-Nvidia May 13, 2026
0a330d3
refactor(data-plane): convert DataPlaneStats to @dataclass
ZhiyuLi-Nvidia May 13, 2026
028dff8
refactor(data-plane): type DataPlaneEvent as TypedDict
ZhiyuLi-Nvidia May 13, 2026
b1de50f
refactor(data-plane): drop placeholder 0s from _run; make sizes kw-only
ZhiyuLi-Nvidia May 13, 2026
fd47991
fix(data-plane): route check_consumption_status through _run
ZhiyuLi-Nvidia May 13, 2026
fc7d6e5
fix(data-plane): route close() through _run
ZhiyuLi-Nvidia May 13, 2026
4d94024
perf(data-plane): single sync in to_nested_by_length
ZhiyuLi-Nvidia May 13, 2026
0f69257
docs(data-plane): convert codec.py docstrings to Google style
ZhiyuLi-Nvidia May 13, 2026
e72682e
refactor(data-plane): centralize Layout type alias in schema.py
ZhiyuLi-Nvidia May 13, 2026
5e3f2d3
fix(data-plane): validate pad_to_multiple >= 1 in materialize
ZhiyuLi-Nvidia May 13, 2026
3c6f7ca
fix(data-plane): fail fast on empty local IP at Mooncake bootstrap
ZhiyuLi-Nvidia May 13, 2026
d86d84b
fix(data-plane): surface chmod failure when mooncake_master is not exec
ZhiyuLi-Nvidia May 13, 2026
13ae181
refactor(data-plane): scope mooncake_cpu 1D workaround to TQDataPlane…
ZhiyuLi-Nvidia May 13, 2026
c39c580
docs(data-plane): clarify TQ module vs client access convention
ZhiyuLi-Nvidia May 13, 2026
3b1d196
docs(data-plane): note trust boundary at pack_object_array pickle site
ZhiyuLi-Nvidia May 13, 2026
4fda233
refactor(data-plane): drop codec pickle, use TQ-native NonTensorStack
ZhiyuLi-Nvidia May 13, 2026
5e36cd2
refactor(data-plane): drop dead object-array codec helpers
ZhiyuLi-Nvidia May 13, 2026
a11a8cd
refactor(data-plane): centralize _meta_idx sentinel in schema.py
ZhiyuLi-Nvidia May 13, 2026
5669bbb
docs(data-plane): convert interfaces.py docstrings to Google style
ZhiyuLi-Nvidia May 13, 2026
f602e71
refactor(data-plane): align schema constant names with their values
ZhiyuLi-Nvidia May 13, 2026
d2d6e98
docs(data-plane): tighten preshard.py docstring to Google style
ZhiyuLi-Nvidia May 13, 2026
cd022c0
docs(data-plane): convert column_io.py docstrings to Google style
ZhiyuLi-Nvidia May 13, 2026
a7b14a8
docs(data-plane): convert factory.py docstring to Google style
ZhiyuLi-Nvidia May 13, 2026
a28e116
docs(data-plane): add Args/Returns blocks to observability.py docstrings
ZhiyuLi-Nvidia May 13, 2026
941c084
docs(data-plane): tighten transfer_queue.py docstrings, add Args/Retu…
ZhiyuLi-Nvidia May 13, 2026
ac171dc
docs(data-plane): add Args/Returns to worker_mixin.py docstrings
ZhiyuLi-Nvidia May 13, 2026
4573641
docs(data-plane): add Args/Returns blocks to tq_policy.py docstrings
ZhiyuLi-Nvidia May 13, 2026
f3ac950
docs(data-plane): convert sync_rollout_actor.py docstrings to Google …
ZhiyuLi-Nvidia May 13, 2026
c69cbd0
docs(data-plane): add Args/Returns to grpo_sync.py dynamic-sampling h…
ZhiyuLi-Nvidia May 13, 2026
18ec172
refactor(data-plane): drop _to_wire's redundant promote_1d kwarg
ZhiyuLi-Nvidia May 13, 2026
d0e8fdb
fix(data-plane): survive TQ simple-backend NonTensorData wire-strip
ZhiyuLi-Nvidia May 14, 2026
9cd0c3a
build(data-plane): pin mooncake-transfer-engine-cuda13 wheel for cu13…
ZhiyuLi-Nvidia May 14, 2026
800f89b
chore: ruff auto-fix and ruff-format pass
ZhiyuLi-Nvidia May 14, 2026
1d7d0ee
chore(pyrefly): rename driver_io → column_io in whitelist
ZhiyuLi-Nvidia May 14, 2026
2a6285c
chore(pyrefly): silence 5 latent type errors with targeted ignore com…
ZhiyuLi-Nvidia May 14, 2026
e06076e
chore(pyrefly): whitelist nemo_rl/data_plane/schema.py
ZhiyuLi-Nvidia May 14, 2026
81a734f
fix(data-plane): preserve object-column identity through TQ wire
ZhiyuLi-Nvidia May 14, 2026
e6033a9
fix(data-plane): gate TQ write-back on TP×CP×PP leader to avoid dupli…
ZhiyuLi-Nvidia May 14, 2026
68110b9
chore: ruff auto-fix and D205 docstring fixes
ZhiyuLi-Nvidia May 14, 2026
06175ca
refactor(data-plane): drop async-grpo TQ scaffolding from sync PR
ZhiyuLi-Nvidia May 14, 2026
a592a0d
refactor(data-plane): consolidate producer codec, caller mints keys
ZhiyuLi-Nvidia May 14, 2026
b6227f1
test(data-plane): align codec tests with current contract
ZhiyuLi-Nvidia May 14, 2026
06fa8a3
refactor(grpo_sync): drop dead batch_cache; make TQPolicy attrs public
ZhiyuLi-Nvidia May 14, 2026
4d41c24
refactor(data-plane): extract calibration field filter into named sch…
ZhiyuLi-Nvidia May 15, 2026
2447264
refactor(data-plane): make kv_batch_get(select_fields) required
ZhiyuLi-Nvidia May 15, 2026
b5e4561
refactor(sync-rollout-actor): remove unused wrappers; document full l…
ZhiyuLi-Nvidia May 15, 2026
44e28d5
test(data-plane): move data_plane unit tests under tests/unit/ for CI…
ZhiyuLi-Nvidia May 15, 2026
d283cbb
test(data-plane): apply ruff --fix and import-sort to data_plane unit…
ZhiyuLi-Nvidia May 15, 2026
54b24b4
docs: fix broken nemo-gym Core Components link
ZhiyuLi-Nvidia May 15, 2026
8818e91
chore(grpo): drop stale mypy comments; rename TQPolicy ctor->actor
ZhiyuLi-Nvidia May 15, 2026
f6aaecf
fix(data-plane): reject loopback IP; resolve TQ runtime_env pin from …
ZhiyuLi-Nvidia May 15, 2026
efc0e27
docs(data-plane): rewrite README around sync flow + async proposal
ZhiyuLi-Nvidia May 15, 2026
8b535af
docs(data-plane): clarify partition scope and TQ mental model
ZhiyuLi-Nvidia May 15, 2026
f8b310d
refactor(data-plane): per-row tags on KVBatchMeta; rename slice → dri…
ZhiyuLi-Nvidia May 16, 2026
46b14a7
perf(sync-rollout-actor): subset driver_carry via carry_keys
ZhiyuLi-Nvidia May 16, 2026
1724362
refactor(grpo-sync): apply overlong filter post-dynamic-sampling
ZhiyuLi-Nvidia May 16, 2026
4b51983
refactor(grpo-sync): isolate TQ ops behind TQPolicy/KVBatchMeta façades
ZhiyuLi-Nvidia May 16, 2026
228f066
refactor(data-plane): YAML-only defaults for TQ config (terryk §9)
ZhiyuLi-Nvidia May 16, 2026
20f290e
docs(data-plane): refresh README around encapsulated TQ path
ZhiyuLi-Nvidia May 16, 2026
7f9f6ac
chore: ruff format + pyrefly ignore + underscore-md rename
ZhiyuLi-Nvidia May 16, 2026
52495d8
docs(data-plane): drop api-lifecycle doc; realistic concrete examples
ZhiyuLi-Nvidia May 16, 2026
9f88424
docs: align nemo-gym Core Components link with main
ZhiyuLi-Nvidia May 16, 2026
6c94851
fix(data-plane): close grad_norm collapse + NCCL desync in DP fsdp2 path
ZhiyuLi-Nvidia May 18, 2026
8100471
refactor(data-plane): drop _tq() lazy wrapper; fail-fast in check_con…
ZhiyuLi-Nvidia May 18, 2026
b51a4e4
refactor(grpo-sync): mint uids in rollout actor (verl-style per-promp…
ZhiyuLi-Nvidia May 18, 2026
c8ca43e
refactor(data-plane): rename KVBatchMeta.keys -> sample_ids (Phase A)
ZhiyuLi-Nvidia May 18, 2026
0f45f07
refactor(data-plane): rename DataPlaneClient kwarg keys -> sample_ids…
ZhiyuLi-Nvidia May 18, 2026
d68ad02
test(data-plane): update KVBatchMeta schema-pin to sample_ids
ZhiyuLi-Nvidia May 18, 2026
1ca91e8
refactor(data-plane): rename DataPlaneClient verbs kv_batch_* -> {put…
ZhiyuLi-Nvidia May 18, 2026
f047682
refactor(data-plane): tighten clear_samples(None) contract; warn on s…
ZhiyuLi-Nvidia May 18, 2026
aec314d
chore(data-plane): apply ruff format
ZhiyuLi-Nvidia May 18, 2026
65f8008
feat(data-plane): align seq-dim across DP ranks via meta-stamped glob…
ZhiyuLi-Nvidia May 18, 2026
ba3f2f8
test(data-plane): add missing DataPlaneConfig keys to test_seqpack_eq…
ZhiyuLi-Nvidia May 18, 2026
ac607de
refactor(data-plane): remove _PartitionRecord from TQ adapter
ZhiyuLi-Nvidia May 18, 2026
9b75e97
test(data-plane): remove empty tests/unit/data_plane/conftest.py
ZhiyuLi-Nvidia May 18, 2026
60a2872
revert(test): restore NUM_MINUTES=150 in prorlv2 recipe sh
ZhiyuLi-Nvidia May 18, 2026
8ca9e7a
test(data-plane): drop test_tq_multinode.py
ZhiyuLi-Nvidia May 18, 2026
55acd37
docs(data-plane): document DP-aligned forward pad seqlen in README
ZhiyuLi-Nvidia May 18, 2026
8289b5a
test(data-plane): drop stale import-isolation tests; merge codec_obje…
ZhiyuLi-Nvidia May 18, 2026
1ee5d2d
refactor(data-plane): drop drive-by edits from PR scope
ZhiyuLi-Nvidia May 19, 2026
2dd6ec0
test(data-plane): accept attribute-style data_plane access in invariant
ZhiyuLi-Nvidia May 19, 2026
bfab58a
refactor(data-plane): use attribute-style access on MasterConfig
ZhiyuLi-Nvidia May 19, 2026
ee4ce24
refactor(data-plane): replace run_grpo dispatch grep with behavioral …
ZhiyuLi-Nvidia May 19, 2026
bc46bf8
fix(data-plane): use attribute access for loss_fn KL penalty assert
ZhiyuLi-Nvidia May 19, 2026
f598ae6
fix(data-plane): pre-register fields to dodge TQ controller race
ZhiyuLi-Nvidia May 19, 2026
6225fbe
fix(configs): set truncated_importance_sampling_type=tis on recipes t…
ZhiyuLi-Nvidia May 19, 2026
c89bc43
refactor(data-plane): close four cross-boundary leaks
ZhiyuLi-Nvidia May 19, 2026
fd8b23a
chore(data-plane): apply ruff format to discard_samples
ZhiyuLi-Nvidia May 19, 2026
5c306ed
build: regenerate uv.lock (cu13 mooncake wheel needs requires-python …
ZhiyuLi-Nvidia May 19, 2026
ec13926
build: regenerate uv.lock against current HEAD
ZhiyuLi-Nvidia May 19, 2026
00ec2f5
test(data-plane): consolidate suite under tests/unit/data_plane
ZhiyuLi-Nvidia May 19, 2026
f61fec5
fix(data-plane): shrink mooncake_cpu segment defaults to fit CI runners
ZhiyuLi-Nvidia May 19, 2026
8528b7c
test(data-plane): update _apply_dynamic_sampling tests for policy= param
ZhiyuLi-Nvidia May 19, 2026
e0eb6cd
fix(data-plane): apply pad_to_seqlen to ALL 2D+ tensors in materialize
ZhiyuLi-Nvidia May 20, 2026
977a931
test(data-plane): add missing DataPlaneConfig keys to _TQ_CFG in chao…
ZhiyuLi-Nvidia May 20, 2026
511bd5b
test(data-plane): remove storage-actor-kill chaos test
ZhiyuLi-Nvidia May 20, 2026
8504cc8
fix(data-plane): exclude MESSAGE_LOG_BULK_FIELDS from FP8 calib request
ZhiyuLi-Nvidia May 20, 2026
6495ce8
test(data-plane): pin MESSAGE_LOG_BULK_FIELDS in DP_CALIB_EXCLUDED_FI…
ZhiyuLi-Nvidia May 20, 2026
a8d3816
test(data-plane): add missing DataPlaneConfig keys to tq_lifecycle fi…
ZhiyuLi-Nvidia May 20, 2026
b8b42d7
feat(data-plane): route FP8 KV scales through TQ (sync first cut)
ZhiyuLi-Nvidia May 20, 2026
7740ba0
Revert "feat(data-plane): route FP8 KV scales through TQ (sync first …
ZhiyuLi-Nvidia May 20, 2026
0f82cbf
refactor(data-plane): flip calib filter to positive include-list
ZhiyuLi-Nvidia May 20, 2026
b717610
test(data-plane): add realistic-shape rollout fixtures + cross-file d…
ZhiyuLi-Nvidia May 20, 2026
4b8c9fe
build: refresh uv.lock
ZhiyuLi-Nvidia May 20, 2026
7f5cfa9
chore(test): apply ruff isort + blank-line fixes
ZhiyuLi-Nvidia May 20, 2026
cacb612
fix(data-plane): override _is_writeback_leader in DTensor V1 worker
ZhiyuLi-Nvidia May 20, 2026
4d61ea3
test(data-plane): sync grpo_math_1B reference config buffer sizes
ZhiyuLi-Nvidia May 20, 2026
a99aa6a
test(data-plane): slim test_architecture_invariants to 2 behavioral t…
ZhiyuLi-Nvidia May 20, 2026
a3ec982
undo unnecessary change
ZhiyuLi-Nvidia May 20, 2026
dce67a4
build: resolve mooncake-transfer-engine-cuda13 from PyPI instead of G…
ZhiyuLi-Nvidia May 21, 2026
42be993
perf(data-plane): skip Ray return of per-token logprob tensors
ZhiyuLi-Nvidia May 21, 2026
17fbf5c
perf(data-plane): worker-side suppress per-token logprob Ray return
ZhiyuLi-Nvidia May 21, 2026
6b1dff8
refactor(data-plane): drop aggregator path now that logprob workers r…
ZhiyuLi-Nvidia May 21, 2026
24cb5bf
refactor(data-plane): make Ray worker_coords the single source of tru…
ZhiyuLi-Nvidia May 22, 2026
abcea99
Revert "refactor(data-plane): make Ray worker_coords the single sourc…
ZhiyuLi-Nvidia May 22, 2026
5815041
fix(data-plane): unify leader-gate on NamedSharding.is_axis_zero; fix…
ZhiyuLi-Nvidia May 22, 2026
fedc88e
chore: ruff auto-fix and ruff-format pass post-rebase
ZhiyuLi-Nvidia May 22, 2026
254984e
build: refresh uv.lock against post-rebase pyproject.toml
ZhiyuLi-Nvidia May 22, 2026
92499b3
undo unnecessary change
ZhiyuLi-Nvidia May 22, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,19 @@ logger:
cluster:
gpus_per_node: 1
num_nodes: 1

# TransferQueue-mediated data plane for sync GRPO.
# Off by default — the legacy grpo_train trainer never engages this.
# Flip enabled=true and run grpo_train_sync to use TQ-mediated bulk
# transfer between rollout and train. See nemo_rl/data_plane/README.md.
data_plane:
enabled: false
impl: transfer_queue
backend: "simple" # TQ storage backend ('simple' or 'mooncake_cpu')
storage_capacity: 1000000 # max samples retained per partition
num_storage_units: 2 # storage shards
claim_meta_poll_interval_s: 0.5 # blocking-claim poll cadence
global_segment_size: 549755813888 # 512 GiB — used when backend == "mooncake_cpu"
local_buffer_size: 68719476736 # 64 GiB — used when backend == "mooncake_cpu"
# observability: # NotRequired
# enabled: false
Comment thread
ZhiyuLi-Nvidia marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ loss_fn:
reference_policy_kl_penalty: 0.0
use_importance_sampling_correction: true
truncated_importance_sampling_ratio: 2
truncated_importance_sampling_type: tis
checkpointing:
checkpoint_dir: results/grpo-glm47-flash-4n8g-automodel
policy:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ loss_fn:
reference_policy_kl_penalty: 0.0
use_importance_sampling_correction: true
truncated_importance_sampling_ratio: 2
truncated_importance_sampling_type: tis
ratio_clip_max: 0.28
ratio_clip_c: 10
checkpointing:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ loss_fn:
reference_policy_kl_penalty: 0.0
use_importance_sampling_correction: true
truncated_importance_sampling_ratio: 2
truncated_importance_sampling_type: tis
checkpointing:
checkpoint_dir: results/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16
policy:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ loss_fn:
reference_policy_kl_penalty: 0.0
use_importance_sampling_correction: true
truncated_importance_sampling_ratio: 2
truncated_importance_sampling_type: tis
checkpointing:
checkpoint_dir: results/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-megatron-ep16
policy:
Expand Down
46 changes: 41 additions & 5 deletions examples/run_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@
from nemo_rl.utils.logger import get_next_experiment_dir


def _select_trainer(master_config: MasterConfig):
"""Pick the synchronous trainer based on ``data_plane.enabled``.

Factored out so test_architecture_invariants can verify dispatch
without the full setup() path.
"""
dp_cfg = master_config.data_plane or {}
if dp_cfg.get("enabled", False):
from nemo_rl.algorithms.grpo_sync import grpo_train_sync

print("🚀 Running synchronous GRPO training (TransferQueue)")
return grpo_train_sync
print("🚀 Running synchronous GRPO training (legacy)")
return grpo_train


def parse_args() -> tuple[argparse.Namespace, list[str]]:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="Run GRPO training with configuration")
Expand Down Expand Up @@ -100,6 +116,20 @@ def main() -> None:
val_task_to_env,
) = setup_response_data(tokenizer, config.data, config.env)

# Pick the policy factory at the launcher level so the legacy trainer
# stays data-plane-agnostic (architectural invariant — see
# tests/data_plane/unit/test_architecture_invariants.py).
_dp_cfg = config.data_plane or {}
if _dp_cfg.get("enabled", False):
from nemo_rl.models.policy.tq_policy import TQPolicy

def _make_policy(**kwargs):
return TQPolicy(**kwargs, dp_cfg=_dp_cfg)

_policy_factory = _make_policy
else:
_policy_factory = None # setup() defaults to plain Policy

(
policy,
policy_generation,
Expand All @@ -111,7 +141,13 @@ def main() -> None:
checkpointer,
grpo_state,
master_config,
) = setup(config, tokenizer, dataset, val_dataset)
) = setup(
config,
tokenizer,
dataset,
val_dataset,
policy_factory=_policy_factory,
)

# Check if async mode is enabled
if "async_grpo" in config.grpo and config.grpo["async_grpo"]["enabled"]:
Expand Down Expand Up @@ -165,10 +201,10 @@ def main() -> None:
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
)
else:
print("🚀 Running synchronous GRPO training")

# Run standard GRPO training
grpo_train(
# Two parallel synchronous trainers (verl-style — main_ppo.py vs
# main_ppo_sync.py). data_plane.enabled selects which one runs.
trainer = _select_trainer(master_config)
trainer(
policy,
policy_generation,
dataloader,
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/algorithms/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def distillation_train(
student_generation = student_policy # type: ignore
NEED_REFIT = False
POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
assert student_generation is not None # for mypy type check
assert student_generation is not None

# common config/state items
current_epoch = distillation_save_state["current_epoch"] # current epoch
Expand Down
14 changes: 11 additions & 3 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast
from typing import Any, Callable, NotRequired, Optional, TypedDict, TypeVar, cast

import numpy as np
import ray
Expand Down Expand Up @@ -59,6 +59,7 @@
get_keys_from_message_log,
)
from nemo_rl.data.utils import extract_necessary_env_names, load_dataloader_state
from nemo_rl.data_plane.interfaces import DataPlaneConfig
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env
from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
Expand Down Expand Up @@ -207,6 +208,7 @@ class MasterConfig(BaseModel, extra="allow"):
logger: GRPOLoggerConfig
cluster: ClusterConfig
checkpointing: CheckpointingConfig
data_plane: Optional[DataPlaneConfig] = None


# ===============================================================================
Expand All @@ -220,6 +222,7 @@ def setup(
dataset: AllTaskProcessedDataset | dict[str, AllTaskProcessedDataset],
val_dataset: Optional[AllTaskProcessedDataset],
processor: Optional[AutoProcessor] = None,
policy_factory: Optional[Callable[..., ColocatablePolicyInterface]] = None,
) -> tuple[
ColocatablePolicyInterface,
Optional[GenerationInterface],
Expand Down Expand Up @@ -580,10 +583,15 @@ def init_train_dataloader(dataset, suffix: str = ""):
"(reference model is not loaded)."
)

# Caller-supplied factory lets the sync trainer swap in a TQ-mediated
# Policy subclass without this shared setup needing to know the data
# plane exists. Default is the plain Policy class — legacy behavior.
_make_policy = policy_factory if policy_factory is not None else Policy

def init_policy():
"""Initialize policy training workers."""
t0 = time.perf_counter()
p = Policy(
p = _make_policy(
cluster=train_cluster,
config=policy_config,
tokenizer=tokenizer,
Expand Down Expand Up @@ -1360,7 +1368,7 @@ def grpo_train(
policy_generation = policy # type: ignore
NEED_REFIT = False
POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
assert policy_generation is not None # for mypy type check
assert policy_generation is not None

# Check if we need to sync KV cache scales
# When fallback to policy as the policy_generation, we use getattr to check.
Expand Down
Loading
Loading