fix(megatron): delegate packed CP slicing to MCore#2445

Open
zyzhou5 wants to merge 15 commits into
NVIDIA-NeMo:zhiyul/data_plane_plan from zyzhou5:cp-broadcast-r3-prep

Conversation


@zyzhou5 zyzhou5 commented May 8, 2026

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

ZhiyuLi-Nvidia and others added 15 commits May 1, 2026 14:40
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
Both rl-arena and verl converge on driver-balanced metadata + worker-side
direct fetch (1-hop). Plan updates:

- Header reframed: rl-arena and verl as co-references (same idea, different
  worker plumbing). NeMo-RL adopts verl's @tqbridge decorator.
- Stage 4: corrected LOC estimate (~150-250, not 400-600). shard_keys_by_seqlen
  uses sort-by-seqlen + stride (matches rl-arena's shard_for_dp and NeMo-RL's
  dynamic_batching_args branch). Single algorithm, no strategy parameter.
- Stage 4: TP/CP/PP guidance — broadcast inside the group, not per-sibling
  fetch. CP sequence-dim slicing happens in model forward, not data plane.
- Stage 3 lifecycle: corrected ordering (prev_lp + ref_lp + mask before
  advantage; KL-in-reward needs both logprobs).
- Stage-completion design: field-presence is the natural ready signal;
  mark_consumed dropped from public ABC (TQ advances inside get_meta(fetch)).
- KVBatchMeta mirrors transfer_queue.metadata.KVBatchMeta 1:1 (fields,
  not fields_available).
- ABC adds direct-by-key kv_batch_get / kv_batch_put / kv_clear.
- TQ pinned to 0.1.5 (matches local wheel); pyproject packages.find fix
  so nemo_rl.data_plane gets installed.
- New risks: R11 (dynamic sampling/DAPO), R12 (message_log Tier-1/3 split),
  R13 (stage completion / fault tolerance).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… balanced packing

Adds an optional data-plane layer that routes GRPO train data through
TransferQueue (Ray-actor-backed KV store) instead of Ray's in-memory
object store. Mirrors verl's main_ppo.py / main_ppo_sync.py split:
algorithms/grpo.py is unchanged; algorithms/grpo_sync.py is a TQ-only
sibling dispatched when data_plane.enabled=true.

Key pieces:
- nemo_rl/data_plane/: stable adapter boundary (DataPlaneClient ABC,
  KVBatchMeta), TQ adapter, codec, sharder, observability middleware.
- @dp_dispatch decorator: makes Policy methods polymorphic over
  BatchedDataDict (legacy) and KVBatchMeta / list[KVBatchMeta] (TQ).
- Driver-side balanced packing: when sequence packing or dynamic
  batching is on, shard_by_batch_size must be called once on the
  driver with shards=DP_world — bin_count_multiple=DP_world is what
  keeps per-DP n_microbatches uniform. Per-shard packing metadata
  rides in KVBatchMeta.extra_info; train_presharded reattaches it
  post-fetch and skips local repack. Without this, per-rank shards=1
  packing produced different n_microbatches across DP groups and
  Megatron deadlocked at the first cross-DP collective (10-min NCCL
  watchdog at step 4 in our 2-node qwen3-30b runs).
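The balanced layout above relies on sort-by-seqlen plus striding. A toy Python stand-in of that layout (illustrative only — the real `shard_by_batch_size` additionally enforces `bin_count_multiple=DP_world` and carries packing metadata, which this sketch does not):

```python
def shard_balanced(seqlens, dp_world):
    """Return dp_world lists of sample indices, roughly seqlen-balanced.

    Sort indices by sequence length, then stride: rank r takes every
    dp_world-th index, so long and short samples interleave across ranks.
    """
    order = sorted(range(len(seqlens)), key=lambda i: seqlens[i], reverse=True)
    return [order[r::dp_world] for r in range(dp_world)]

shards = shard_balanced([512, 8, 300, 64, 700, 128, 256, 32], dp_world=2)
# Each shard gets the same sample count, with long/short interleaved.
```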

Verification:
- Unit (5/5): dispatch decorator handles BatchedDataDict / KVBatchMeta /
  list[KVBatchMeta], rejects size mismatches, etc.
- Functional (3/3): legacy and TQ paths produce byte-identical sharded
  data + packing metadata for seqpack / dynbatch / no-packing — proves
  the data plane is a lossless transport, isolated from NCCL noise.
- E2E: qwen3-30b mcore GRPO 5/5 steps green for baseline-TQ, seqpack-TQ,
  and dynbatch-TQ on 2 nodes (16 GPUs).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ard helpers

Pulls the driver-side balanced-packing + per-rank fan-out block out of
grpo_sync.py:605-704 into nemo_rl/data_plane/preshard.py so the same
two operations can be reused by future async data-plane trainers
without duplicating the bin_count_multiple=DP_world incantation.

The original block had two distinct concerns inlined together:
  1. Compute pre-shards from train_data via shard_by_batch_size with
     packing args derived from policy_cfg (pure transform, no I/O).
  2. For each pre-shard, kv_batch_put seed fields and build a
     KVBatchMeta with packing metadata in extra_info (TQ I/O).

Split into:
  - driver_balanced_preshards(train_data, dp_world, policy_cfg)
      → list[BatchedDataDict]
  - fan_out_per_rank_metas(pre_shards, dp_client, partition_id,
                           task_name, key_prefix, seed_fields)
      → list[KVBatchMeta]

key_prefix is the only behavioural parameter: sync GRPO passes
f"step{total_steps}", future async path will pass
f"v{wv}_step{step}". Field iteration order, .detach().contiguous()
calls, and KVBatchMeta construction order are byte-identical to the
inline version — the refactor preserves the exact balanced-packing
semantics that prevent Megatron from deadlocking on the first
cross-DP collective when sequence packing / dynamic batching is on
(commit a085559 described the 10-min NCCL watchdog at step 4).
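A minimal skeleton of the two-helper split, separating the pure transform from the TQ I/O; `KVBatchMeta` here is a stand-in for the real type, and the per-rank key layout inside the loop is hypothetical (only `key_prefix` comes from the commit text):

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class KVBatchMeta:                      # stand-in for the real metadata type
    keys: list[str]
    extra_info: dict[str, Any] = field(default_factory=dict)

def driver_balanced_preshards(train_data, dp_world, policy_cfg):
    """Concern 1: pure transform, no I/O -- returns dp_world shards."""
    ...  # shard_by_batch_size with packing args derived from policy_cfg

def fan_out_per_rank_metas(pre_shards, dp_client, partition_id,
                           task_name, key_prefix, seed_fields):
    """Concern 2: TQ I/O -- put seed fields, return one meta per rank."""
    metas = []
    for rank, shard in enumerate(pre_shards):
        # Hypothetical key layout, for illustration only.
        keys = [f"{key_prefix}_r{rank}_s{i}" for i in range(len(shard))]
        # dp_client.kv_batch_put(...) would run here in the real helper.
        metas.append(KVBatchMeta(keys=keys))
    return metas
```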

Touches:
  - nemo_rl/data_plane/preshard.py (new, 162 lines): two helpers,
    distinct from sharding.py which is metadata-only sort-by-seqlen
    for the @dp_dispatch default fan-out.
  - nemo_rl/algorithms/grpo_sync.py (-113 / +21 net): inline block
    replaced with two helper calls; dead imports (asyncio,
    tensordict.TensorDict, KVBatchMeta) removed.
  - tests/data_plane/unit/test_architecture_invariants.py
    (R-C9 invariant): the regex check 'KVBatchMeta(' now accepts
    delegation via 'fan_out_per_rank_metas(' as well, with a
    chained check that the helper itself constructs KVBatchMeta so
    the dispatch chain to the TQ branch isn't silently broken.

Verification:
  - Tier 1 unit (data_plane): 56/56 passed (Python 3.13.13,
    nightly nemo-rl image).
  - Tier 2 functional (data_plane): 4 passed, 1 skipped — including
    test_seqpack_legacy_equals_tq, test_dynbatch_legacy_equals_tq,
    test_no_packing_legacy_equals_tq (all three byte-equality
    parity tests against the legacy inline path).
  - E2E: qwen3-30b mcore GRPO seqpack-TQ run past step 3 with no
    NCCL deadlock, validating the bin_count_multiple invariant
    survives the helper extraction.

Companion doc: research/data_plane_async_rl_limitations.md §5.4
explains why these helpers belong on the data-plane boundary rather
than in the algorithms layer (TQ I/O is data-plane concern, packing
is reused across sync and async trainers).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds optional ``dp_cfg`` to ``ReplayBuffer.__init__`` and a
consume-time ``kv_clear`` pass inside ``sample()``. When the buffer
holds ``KVBatchMeta`` references (planned async-on-TQ path, not yet
wired up in any caller), every popped trajectory has its TQ keys
cleared before ``sample()`` returns — same buffer lock as the pop, so
no fresh push can race a delete (Race 5 footgun).

Without this, the proposed async-on-TQ design has a memory leak in the
TQ controller at training throughput rate: ``num_prompts`` keys per
step never reclaimed, since (a) ``clear_partition`` is unconditional
(no consumer-tracking GC) and (b) the existing eviction-on-overflow
path in ``push_with_wait_signal`` only fires for *rejected* writes,
not consumed ones. See
research/data_plane_async_rl_limitations.md §5.9 Race 1 for the full
math (linear leak; not survivable).

Backward-compatible: the ``dp_cfg=None`` default preserves the
in-memory async path byte-for-byte. ``isinstance(item, KVBatchMeta)``
guards the new clear loop, so dict trajectories pass through
untouched. The lazy ``_ensure_dp_client()`` builder defers the
data-plane import until first use, keeping the in-memory path free of
data-plane dependencies.
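The clear-on-consume mechanics can be sketched as follows; `ReplayBufferSketch` and the dp-client call are simplified stand-ins for the real Ray actor and `kv_clear` API:

```python
import threading

class KVBatchMeta:                       # stand-in for the real metadata type
    def __init__(self, keys):
        self.keys = keys

class ReplayBufferSketch:
    def __init__(self, dp_client=None):  # dp_client stands in for dp_cfg wiring
        self._lock = threading.Lock()
        self._items = []
        self._dp_client = dp_client

    def push(self, item):
        with self._lock:
            self._items.append(item)

    def sample(self):
        # Pop and clear under the SAME lock, so no fresh push can race
        # the delete (the Race 5 footgun above).
        with self._lock:
            item = self._items.pop(0)
            if self._dp_client is not None and isinstance(item, KVBatchMeta):
                self._dp_client.kv_clear(item.keys)  # reclaim TQ keys
            return item
```

Dict trajectories fail the `isinstance` guard and pass through untouched, matching the backward-compat claim above.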

Wiring: callers must pass ``dp_cfg=master_config["data_plane"]`` when
constructing ``ReplayBuffer.options(...).remote(...)``. No call site
does this yet — that lands with the async-on-TQ trainer (PR 4).
``bootstrap=False`` is passed to ``build_data_plane_client`` so the
buffer attaches to the driver-bootstrapped controller rather than
trying to spin up a second named actor.

Stale-version GC (Race 2 in the same doc) is *not* part of this PR —
the existing ``sample()`` already raises ValueError for trajectories
older than min_valid_version (line 142-145 pre-edit), and silently
GCing them would suppress a legitimate error signal. Defer until the
TQ producer (PR 3) can actually generate stale metas under refit.

Verification: in-memory async path is unaffected by construction
(``dp_cfg=None`` short-circuits all new code paths). Existing tests in
tests/unit/algorithms/test_async_utils.py construct
``ReplayBuffer.remote(max_size=...)`` without ``dp_cfg``; all
selectors and assertions there continue to hold.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… enabled

Producer-side hook for the planned async-on-TQ path. When ``dp_cfg``
is set on ``AsyncTrajectoryCollector``, the rollout's ``final_batch``
is tensorized into the TQ partition ``rollouts`` and a
``KVBatchMeta`` reference is pushed onto the buffer instead of the
in-memory dict. Pairs with PR 2 (ReplayBuffer clears the meta's TQ
keys on consume) and the upcoming PR 4 (grpo_async_dp.py — trainer
materializes per consumed batch and fans out via preshard.py).

Mechanics:
  - Keys: f"v{wv}_p{prompt_idx}_g{i}" — versioned namespace so the
    same prompt at different weight versions can't collide; trainer
    can later filter by ``tag.version`` if needed.
  - Tags: ``[{"version": wv}] * n_samples`` for each put. The version
    is duplicated on every key in the batch but each tag dict is the
    same object reference; serializer dedupes.
  - Fields: every ``torch.Tensor`` leaf of ``final_batch_cpu`` is
    written. The trainer side picks which to fetch via
    ``select_fields`` rather than constraining what the producer
    writes — keeps the producer schema-agnostic.
  - extra_info: rollout_metrics + timestamp ride on the meta so the
    trainer's per-step bookkeeping survives the TQ round-trip without
    a side channel.
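The versioned key scheme is simple enough to pin down directly (the f-string is taken from the mechanics above; `rollout_keys` is an illustrative wrapper name):

```python
def rollout_keys(wv: int, prompt_idx: int, n_samples: int) -> list[str]:
    """Versioned TQ key namespace: weight version, prompt index, sample index."""
    return [f"v{wv}_p{prompt_idx}_g{i}" for i in range(n_samples)]

assert rollout_keys(3, 7, 2) == ["v3_p7_g0", "v3_p7_g1"]
# The same prompt at different weight versions occupies disjoint namespaces:
assert set(rollout_keys(3, 7, 4)).isdisjoint(rollout_keys(4, 7, 4))
```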

``asyncio.run(client.kv_batch_put(...))`` is safe here because
``_collection_loop`` is a worker thread without an enclosing event
loop (Race 3 in the limitations doc; the running-loop conflict only
fires when there's already an asyncio loop in the calling thread).

Backward-compat: ``dp_cfg=None`` default — the in-memory async path
is byte-for-byte unchanged. The ``client = self._ensure_dp_client()``
guard short-circuits all new code when the data plane isn't enabled.
``bootstrap=False`` so the collector attaches to the driver's
controller rather than spinning up a second named actor.

Producer-owned rollback (kv_clear when push_with_wait_signal returns
"full") is *not* part of this PR. The current loop retries with
exponential backoff on "full" rather than dropping — kv_clear in that
path would lose data we just wrote. The shutdown-with-pending-meta
edge case (cluster ends while a put is in-flight) is left as a known
leak for now; TQ partitions are ephemeral per cluster, so it doesn't
accumulate across runs.

No call site passes ``dp_cfg`` yet — the wiring at
``algorithms/grpo.py:2527`` (the trainer_collector.options(...).remote
construction) lands in PR 4 alongside the dispatch in run_grpo.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ed packing

Lights up async-on-TQ as a callable path:

  * ReplayBuffer.sample materializes any popped KVBatchMeta into the
    dict format ``async_grpo_train`` expects ({"batch", "rollout_metrics",
    "timestamp"}). Materialize+clear stays under the buffer lock —
    Race 5: keys are versioned so collisions are unlikely, but the lock
    is the cheap correctness invariant. Pairs with PR 2's clear-on-
    consume.

  * async_grpo_train reads master_config["data_plane"]; if enabled,
    bootstraps the TQ controller on the driver, captures the client
    handle (``_dp_client``), and threads ``dp_cfg`` to both
    ReplayBuffer and AsyncTrajectoryCollector at construction
    (bootstrap=False on the actor side).

  * At the policy.train call site, async_grpo_train now branches: when
    the client is set, drive the same balanced packing + per-rank
    fan-out as grpo_sync (driver_balanced_preshards +
    fan_out_per_rank_metas, key_prefix=f"v{wv}_s{step}"), call
    policy.train(list[KVBatchMeta]) — the @dp_dispatch list path with
    is_meta_list=True (dispatch.py:116-127), and kv_clear the train
    partition before the next step. This is the same bin_count_multiple
    invariant a085559 added for sync; without it, async + sequence
    packing would deadlock at the first cross-DP collective the same
    way sync did pre-a085559c.

  * Hoist DP_SEED_FIELDS from grpo_sync.py to nemo_rl/data_plane/
    preshard.py — both trainers now import the canonical schema. Test
    fixture in tests/data_plane/functional/test_seqpack_equivalence.py
    keeps its own copy on purpose (testing the wire schema as a
    contract, not the producer constant).

Why ``list[KVBatchMeta]`` and not single ``KVBatchMeta``:
    The single-meta path runs the @dp_dispatch sharder
    (shard_keys_by_seqlen) which sorts by seqlen and strides — that
    reorders samples vs. ``meta.keys`` order and skips the policy-aware
    sharding semantics (no GBS check, no FLOPs recording, no
    sequence-packing validation). The list-of-metas path skips the
    sharder entirely and uses the driver's pre-balanced layout.

Known gaps (NOT fixed here, follow-up):
  * FLOPs reporting is silently dropped on the @dp_dispatch list
    path. Lives in lm_policy.train's body (lm_policy.py:730-742) which
    the decorator skips when input is meta-shaped. Affects both
    grpo_sync (since a085559) and now the async-on-TQ path. Right fix
    is a _dp_post_train post-aggregator hook on the decorator —
    landing as a separate PR. ``policy.get_logprobs(KVBatchMeta)``
    has its own ordering bug (sharder reorders, aggregator concats in
    rank order) but async never goes through that path; flagged for
    documentation only.
  * Two TQ round-trips per async step (rollouts partition →
    materialize → train partition → workers). Necessary because the
    trainer needs the assembled BatchedDataDict for reward / advantage
    computation between the two TQ stages. Future optimization can
    fuse if reward/advantage move to the workers.

Backward-compat: when data_plane.enabled is unset/false, async path
behavior is byte-for-byte unchanged — _dp_client stays None, the new
branch isn't taken, ReplayBuffer / AsyncTrajectoryCollector get
dp_cfg=None and short-circuit all data-plane code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…patch TQ path

Closes Issues NVIDIA-NeMo#3 and NVIDIA-NeMo#4 raised in PR review of the data-plane stack.

Issue NVIDIA-NeMo#3 — single-``KVBatchMeta`` path returned rows in scrambled order.
``shard_keys_by_seqlen`` sorts by sequence length and strides
(``order[r::dp_world_size]``) to balance per-rank token totals. The
worker logprob aggregators (``_aggregate_logprob_results``) then
concatenate per-rank outputs in rank order via
``BatchedDataDict.from_batches`` — without inverting the seqlen-
strided permutation. Result: ``policy.get_logprobs(KVBatchMeta(...))``
returned rows in
[order[0], order[d], order[2d], …, order[1], order[1+d], …]
order, not the caller's ``meta.keys`` order. Silent correctness bug
(test_seqpack_legacy_equals_tq didn't catch it because the sync path
calls ``policy.get_logprobs(BatchedDataDict)`` — legacy passthrough,
no sharder).

Fix:
  * ``shard_keys_by_seqlen`` records ``_dp_original_indices`` per
    shard in ``extra_info`` (the ``idx`` list it computed).
  * ``dp_dispatch`` reconstructs the concat-position → input-index
    permutation from the shards' ``extra_info``, then applies the
    inverse via ``BatchedDataDict.reorder_data`` after ``aggregate``.
  * The reorder is gated on ``is_meta and not is_meta_list`` — for
    ``list[KVBatchMeta]`` the driver controls ordering (PR 0
    ``fan_out_per_rank_metas``) and the decorator must not undo it.
  * Skipped silently if the result isn't a BatchedDataDict (e.g.
    ``train`` returns a plain dict — order doesn't apply).
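A pure-Python stand-in for the bug and its fix — strided shards, rank-order concatenation, then the inverse permutation that the gated `reorder_data` step applies (function names here are illustrative):

```python
def strided_shards(seqlens, dp_world):
    """shard_keys_by_seqlen stand-in: sort by seqlen, then stride."""
    order = sorted(range(len(seqlens)), key=lambda i: seqlens[i])
    return [order[r::dp_world] for r in range(dp_world)]

def aggregate_then_restore(rows, seqlens, dp_world):
    shards = strided_shards(seqlens, dp_world)
    # Workers return per-shard outputs; the aggregator concats in rank
    # order -- this is the scrambled [order[0], order[d], ...] layout.
    concat_positions = [i for shard in shards for i in shard]
    scrambled = [rows[i] for i in concat_positions]
    # The fix: scrambled[p] belongs at input index concat_positions[p],
    # so applying the inverse permutation restores meta.keys order.
    restored = [None] * len(rows)
    for p, i in enumerate(concat_positions):
        restored[i] = scrambled[p]
    return restored

rows = ["a", "b", "c", "d", "e", "f"]
assert aggregate_then_restore(rows, [5, 1, 3, 2, 6, 4], dp_world=2) == rows
```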

Issue NVIDIA-NeMo#4 — TQ path silently dropped legacy training semantics.
The decorator's TQ branch returns ``aggregate(results)`` directly
and never enters ``Policy.train``'s body — so the FLOPs accumulation
at lm_policy.py around the ``flops_tracker`` block, plus the
``num_ranks`` and ``theoretical_tflops`` fields, were missing from
results when the trainer called ``policy.train(KVBatchMeta)`` or
``policy.train(list[KVBatchMeta])``. Same gap for the missing GBS /
DP divisibility assertion.

Fix (additive — no signature changes to the existing aggregate
callables):
  * ``dp_dispatch`` adds a basic divisibility assertion on the TQ path:
    ``total_meta_size % dp_size == 0`` (legacy path enforces this via
    ``shard_by_batch_size(batch_size=gbs)``; TQ path skips that call
    site).
  * ``dp_dispatch`` looks up ``self._dp_post_<method_name>`` after
    ``aggregate``. If defined, calls
    ``post(aggregated, raw_results, shards=shards)`` and uses its
    return value. Convention-based — opt-in per Policy method, no
    decorator boilerplate.
  * ``Policy._dp_post_train`` recovers FLOPs from ``meta.sequence_lengths``
    on each shard (driver-pre-balanced for ``list[KVBatchMeta]``,
    sharder-strided for single ``KVBatchMeta``), records ``total_flops``,
    ``num_ranks``, ``theoretical_tflops`` — same fields the legacy
    body produces.

Backward-compat: existing tests in tests/data_plane/unit/test_shard_parity.py
and test_dispatch.py don't check ``extra_info`` shape on sharder output
or assert on aggregate-method return type other than what's already
returned, so the additive fields and gated reorder are transparent.
The legacy ``policy.train(BatchedDataDict)`` path is unchanged — it
keeps building results inline and never enters the new hook.

Async-on-TQ (PR 4) and grpo_sync (PR 0) both use the
``list[KVBatchMeta]`` path, so they inherit the FLOPs fix automatically
via the post-hook. The reorder fix is only meaningful for callers
that pass single ``KVBatchMeta`` — primarily future logprob/reference-
logprob TQ wiring; flagged in commit message of NVIDIA-NeMo#3 above.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…spatch list path

Migrates ``policy.get_logprobs`` and ``policy.get_reference_policy_logprobs``
in ``grpo_sync.py`` from the legacy in-memory ``BatchedDataDict`` body
onto the @dp_dispatch ``list[KVBatchMeta]`` path that train (PR 0)
already uses. Activates the partition's pre-declared ``"prev_lp"`` /
``"ref_lp"`` consumer tasks (line 435) which until now were
reservations the original ``a085559c`` author left for future work.

Why this is safe (and why we don't need the bin_count_multiple
invariant the train path needed):

  Megatron's training step has cross-DP collectives per microbatch —
  gradient sync — so DP ranks lockstep on each microbatch. Different
  per-rank n_microbatches → first-finished rank hangs on the next
  collective (the step-4 NCCL deadlock from ``a085559c``).
  Logprob INFERENCE has no such collective: forward-only, no backward,
  no gradient sync. TP/PP collectives stay within (TP×PP) groups; DP
  ranks don't lockstep through microbatches. So per-rank packing
  variation is fine — slowest rank just takes longer, no deadlock.

  This is exactly why ``train_presharded`` reattaches
  ``meta.extra_info`` packing metadata (driver pre-balanced, must
  override worker's local re-pack) but ``get_logprobs_presharded`` does
  not (worker's local re-pack is fine). a085559's commit message
  documented this distinction; this commit relies on it.

So no worker-side changes are needed. The migration is purely driver-
side:

  before:
    train_data["prev_logprobs"] = policy.get_logprobs(
        BatchedDataDict({...}), timer=timer
    )["logprobs"]
  after:
    sharded, unsorted = logprob_data.shard_by_batch_size(
        dp_world, batch_size=None, sequence_packing_args=spa,
    )                       # policy-aware shard, same args as legacy
                            # body lines 426-450, with logprob_mb_tokens
    metas = fan_out_per_rank_metas(
        sharded, dp_client=..., partition_id="train",
        task_name="prev_lp", key_prefix=f"step{N}_lp",
        seed_fields=("input_ids", "input_lengths", "token_mask",
                     "sample_mask"),
    )                       # PR 0 helper, reused
    out = policy.get_logprobs(metas, timer=timer)
                            # @dp_dispatch is_meta_list=True — skips
                            # sharder, dispatches, aggregator concats.
    if seqpack or dynbatch:
        out.reorder_data(unsorted)
                            # mirrors legacy body line 478-479: the
                            # driver's shard_by_batch_size returned the
                            # same unsorted_data_indices it always has;
                            # we just apply it on the caller side.
    train_data["prev_logprobs"] = out["logprobs"]

Same flow for ``get_reference_policy_logprobs`` under a distinct
task_name + key_prefix so the per-rank fan-out keys don't collide with
the prev_lp fan-out's keys (or the train fan-out's later in the same
step). The single end-of-step ``kv_clear(keys=None, partition_id="train")``
(line ~967) wipes all three namespaces atomically — no GC plumbing
needed.

What this does NOT do:
  * No worker changes — ``get_logprobs_presharded`` and
    ``get_reference_policy_logprobs_presharded`` keep their existing
    bodies (``self._fetch(meta)`` then call legacy worker-internal
    method). Their local re-pack inside ``_fetch`` is correct for
    forward-only inference; see commit-message above.
  * No legacy ``Policy.get_logprobs(BatchedDataDict)`` body changes.
    The legacy passthrough is intact and unchanged for any other
    caller still passing BatchedDataDict.
  * No @dp_dispatch decorator changes. Reuses the existing list-path
    that train already exercises.
  * Multimodal data is dropped from the logprob input on the TQ path
    (P3 — tensor-only on the bus). Matches pre-existing behaviour of
    the train fan-out which already filters multimodal out of
    train_data via ``_DP_SEED_FIELDS``.

Verification: PR 0's qwen3-30b mcore seqpack run passing end-to-end is
the production signal. After this commit, every grpo_sync run with
seqpack/dynbatch on exercises the @dp_dispatch list path for prev_lp
*and* ref_lp every step — three distinct DP-balanced fan-outs per
step into the same TQ partition.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…d leader-broadcast fetch

Retire the @dp_dispatch decorator and migrate TQ-mediated dispatch into
a dedicated nemo_rl/models/policy/tq_policy.py:TQPolicy(Policy) subclass.
The legacy in-memory Policy and grpo.py are now untouched by data-plane
code; the TQ wiring (controller bootstrap, partition register, fan-out,
drain, close) is fully encapsulated in TQPolicy. examples/run_grpo.py
selects TQPolicy + grpo_train_sync when data_plane.enabled=True, legacy
Policy + grpo_train otherwise.

Adds leader-broadcast fetch policy in AbstractPolicyWorker._fetch:
- New default fetch_policy="auto" auto-detects via _get_replica_group():
  if CP > 1, leader of (TP×CP×PP) siblings fetches once and broadcasts
  the BatchedDataDict over NCCL; otherwise every rank fetches
  independently from TQ (TP=CP=PP=1, the cheapest path).
- _broadcast_batched_data_dict ships a shape descriptor via
  broadcast_object_list, then per-tensor broadcast on the group's
  backend device (NCCL → CUDA, gloo → CPU).
- _attach_or_repack_pack_metadata reattaches driver-side packing
  metadata (micro_batch_indices/micro_batch_lengths) for all three
  *_presharded entry points so seqpack TQ runs don't crash on
  data.micro_batch_indices[0].
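A hedged sketch of the descriptor-then-tensors broadcast, assuming a `torch.distributed` process group (gloo shown; NCCL would use CUDA tensors). The function and batch layout are illustrative, not the actual `AbstractPolicyWorker` code:

```python
import torch
import torch.distributed as dist

def broadcast_batch(batch, group_rank, src=0, group=None):
    """Leader-broadcast fetch: src fetched `batch`; others receive it."""
    # Ship a shape/dtype descriptor first so non-leaders can pre-allocate
    # receive buffers without knowing the schema in advance.
    desc = [None]
    if group_rank == src:
        desc = [{k: (tuple(v.shape), v.dtype) for k, v in batch.items()}]
    dist.broadcast_object_list(desc, src=src, group=group)
    out = {}
    for key, (shape, dtype) in desc[0].items():
        t = batch[key] if group_rank == src else torch.empty(shape, dtype=dtype)
        dist.broadcast(t, src=src, group=group)  # NCCL -> CUDA, gloo -> CPU
        out[key] = t
    return out
```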

Verified end-to-end:
- qwen3-30B-A3B mcore + seqpack + CP=1: 10/10 steps
- qwen3-30B-A3B mcore + seqpack + CP=2 + independent: 10/10 steps
- qwen3-30B-A3B mcore + seqpack + CP=2 + auto leader_broadcast: 10/10 steps,
  KL parity vs independent baseline within last-decimal jitter
- llama-3.1-8B DTensor + seqpack + CP=1: 10/10 steps

Architecture invariants tightened:
- legacy nemo_rl/algorithms/grpo.py has zero data_plane / TransferQueue /
  KVBatchMeta / dp_dispatch tokens (regex-checked)
- nemo_rl/algorithms/grpo_sync.py guards on hasattr(policy, "dp_cfg")
  rather than feature-gating on master_config
- 18/18 architecture invariant tests + 2 new leader_broadcast tests pass

Removed dead code: nemo_rl/data_plane/dispatch.py (the decorator),
nemo_rl/data_plane/sharding.py (its sharder), tests/data_plane/unit/
test_dispatch.py and test_shard_parity.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Apply 4 fixes from external review:

1. **Multimodal extras drop (correctness).** ``fan_out_per_rank_metas``
   now writes any tensor field present in the shard, not just those in
   ``seed_fields``. The legacy in-memory path passes the full
   BatchedDataDict; the TQ path was dropping VLM extras like
   ``pixel_values`` because the field filter was schema-restricted. The
   real TQ adapter creates partitions implicitly on first put (per
   adapter comment), so extras don't fight schema registration.

2. **Per-rank ``asyncio.run`` loop (scaling).** Replace the loop of
   per-shard ``asyncio.run(kv_batch_put(...))`` with a single
   ``asyncio.gather`` over all shards. Adds ``fan_out_per_rank_metas_async``
   and a sync façade. O(1) RTT instead of O(DP).

3. **Cleanup on worker failure.** Wrap ``TQPolicy.train``'s fan-out +
   dispatch in try/finally so the partition is drained even if a worker
   raises. Stale tensors no longer accumulate across failed steps.

4. **Schema consolidation.** Move ``_LP_SEED_FIELDS`` from
   ``tq_policy.py`` into ``preshard.py:LP_SEED_FIELDS`` next to
   ``DP_SEED_FIELDS``. Single source of truth for the canonical seed
   sets.
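Fixes 2 and 3 can be sketched together; `kv_batch_put`, `fan_out`, and `kv_clear` below are illustrative stand-ins, not the actual helpers:

```python
import asyncio

async def kv_batch_put(shard):
    await asyncio.sleep(0)          # placeholder for one TQ round trip
    return f"meta_{shard}"

def fan_out(shards):
    # Fix 2: one event loop with asyncio.gather over all shards -- puts
    # are in flight concurrently (O(1) RTT of latency) instead of a
    # fresh asyncio.run per shard (O(DP) sequential RTTs). gather keeps
    # result order aligned with shard order.
    async def _gather():
        return await asyncio.gather(*(kv_batch_put(s) for s in shards))
    return asyncio.run(_gather())

def train_step(shards, dispatch, kv_clear):
    # Fix 3: drain the partition in a finally block so stale tensors
    # don't accumulate when a worker raises mid-step.
    metas = fan_out(shards)
    try:
        return dispatch(metas)
    finally:
        kv_clear()                  # runs on success AND on worker failure
```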

Adds ``tests/data_plane/unit/test_preshard_extras.py`` covering: tensor
extras auto-included, non-tensor entries skipped, LP⊆DP invariant,
per-rank key namespacing.

Deferred to follow-up issues (out of this PR's scope):
- async TQ key collision risk in ``async_utils.py`` (pre-existing)
- partial ``kv_clear`` invalidates ``seen_keys`` in the TQ adapter
  (latent — only ``keys=None`` full-clear is exercised today)

Architecture invariants 18/18 still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Companion to data_plane_integration_plan.md: documents the runtime view
(call order, payloads, per-step RPC counts) of the sync 1-hop GRPO path,
and contrasts it with verl's main_ppo_sync.py at the integration-shape
level (per-prompt actors + ReplayBuffer vs batched actor + slice-only
driver).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…fecycle

Land the sync GRPO data-plane refactor end-to-end:

- New SyncTrajectoryCollector (algorithms/sync_utils.py) — sibling of
  AsyncTrajectoryCollector. Owns rollout + flatten/mask + prompt
  extraction + flat kv_batch_put. Driver receives only KVBatchMeta +
  small per-sample slice.
- rollout_to_tq helper colocated in sync_utils.py (single first-write
  primitive; mirrors verl main_ppo_sync.py:386-423).
- driver_io.read_columns / write_columns helpers for driver-side
  delta read/write on metas.
- Register SyncTrajectoryCollector under VLLM env tier so multinode
  Ray workers provision tensordict.
- grpo_sync.py rewires logprob/ref/train through shard_meta_for_dp
  per-DP fan-out + worker leader-only write-back; driver reads
  small slices only (advantages, log_data input_ids).

Validated e2e on mcore-1B + seqpack + CP=1 (job 11610072,
20/20 steps, +0.21 s/step vs legacy, bit-exact through step 7).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sync 1-hop simplify pass driven by /simplify review.

- nemo_rl/utils/venvs.py: add make_actor_runtime_env(fqn) — wraps the
  get_actor_python_env + create_local_venv_on_each_node + os.environ
  wiring that was duplicated 3× across grpo.py and grpo_sync.py.
  Touches only the new helper; legacy grpo.py inline blocks
  intentionally untouched (per "grpo.py is 100% backward compatible").
- nemo_rl/algorithms/grpo_sync.py: use the helper for SyncTrajectoryCollector
  runtime_env (~20 lines → ~3); switch _apply_dynamic_sampling's
  pending_unfiltered_rewards from O(N²) [*xs, y] to O(1) .append(y);
  drop rotted (grpo.py:878) line-ref comment; clean up orphan imports.

Tier-1 unit tests: 86/86 passing (job 11623540).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@zyzhou5 zyzhou5 requested review from a team as code owners May 8, 2026 18:27

copy-pr-bot Bot commented May 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.


@ZhiyuLi-Nvidia ZhiyuLi-Nvidia force-pushed the zhiyul/data_plane_plan branch from bff0471 to d20a6ed Compare May 9, 2026 01:15
@ZhiyuLi-Nvidia ZhiyuLi-Nvidia requested review from a team as code owners May 9, 2026 01:15
@ZhiyuLi-Nvidia ZhiyuLi-Nvidia force-pushed the zhiyul/data_plane_plan branch 5 times, most recently from 1596562 to abada7e Compare May 9, 2026 03:22