feat: support staleness-window in ReplayBufferNew#2458

Open
yuki-97 wants to merge 4 commits into main from yukih/staleness-sample

yuki-97 (Contributor) commented May 11, 2026

Part of RL-727. Stacks on #2448.

Implements ReplayBufferNew, a temporary replacement for ReplayBuffer until TQReplayBuffer is ready.

Motivation: ReplayBuffer.sample() requires target_weight_version == current_weight_version, which stalls training when the exact-match trajectories haven't arrived yet (buffer starvation). ReplayBufferNew fixes this by allowing slightly older trajectories to be used, with an importance-sampling correction.

Changes:

  • Add max_staleness config: trajectories with trainer_version - weight_version > max_staleness are evicted at the start of each sample() call.
  • sample() selects from the staleness window [trainer_version - max_staleness, trainer_version], removing the strict target_weight_version == current_weight_version gate.
  • Add sample_freshest_first flag (default True): when True, selects the highest-version trajectories first; when False, uses FIFO (insertion order).
  • target_weight_versions is intentionally unused in ReplayBufferNew — it gates generation on specific trainer steps, causing generation pauses. Will be removed when cleaning up after TQReplayBuffer lands.
  • Unit tests covering eviction, staleness-window sampling, freshest-first ordering, and FIFO ordering.
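
The behaviour listed above can be illustrated with a minimal, standalone sketch. This is not the PR's implementation: the class name `StalenessWindowBuffer`, the `add` method, and the choice to return `None` when too few trajectories are available are assumptions made for illustration. Only the `max_staleness` and `sample_freshest_first` options and the evict-then-sample order come from the description.

```python
from __future__ import annotations

from typing import Any


class StalenessWindowBuffer:
    """Toy stand-in for a staleness-window buffer (hypothetical, not the PR's class)."""

    def __init__(self, max_staleness: int, sample_freshest_first: bool = True) -> None:
        self.max_staleness = max_staleness
        self.sample_freshest_first = sample_freshest_first
        # Parallel lists that must stay index-aligned.
        self.trajectories: list[Any] = []
        self.trajectory_versions: list[int] = []

    def add(self, trajectory: Any, weight_version: int) -> None:
        self.trajectories.append(trajectory)
        self.trajectory_versions.append(weight_version)

    def _evict(self, trainer_version: int) -> None:
        # Drop everything with trainer_version - weight_version > max_staleness.
        min_valid = trainer_version - self.max_staleness
        keep = [i for i, v in enumerate(self.trajectory_versions) if v >= min_valid]
        self.trajectories = [self.trajectories[i] for i in keep]
        self.trajectory_versions = [self.trajectory_versions[i] for i in keep]

    def sample(self, num: int, trainer_version: int) -> list[Any] | None:
        self._evict(trainer_version)
        # Candidates lie in [trainer_version - max_staleness, trainer_version].
        candidates = [
            i for i, v in enumerate(self.trajectory_versions) if v <= trainer_version
        ]
        if len(candidates) < num:
            return None  # assumption: the caller retries later
        if self.sample_freshest_first:
            # Stable sort: equal versions keep their insertion order.
            candidates.sort(key=lambda i: self.trajectory_versions[i], reverse=True)
        selected = candidates[:num]  # unsorted candidates are already in FIFO order
        items = [self.trajectories[i] for i in selected]
        # Pop from the highest index down so the remaining indices stay valid.
        for idx in sorted(selected, reverse=True):
            self.trajectories.pop(idx)
            self.trajectory_versions.pop(idx)
        return items
```

With `max_staleness=2` and trajectories at versions 1, 3, 4, 5, a call to `sample(2, trainer_version=5)` first evicts the version-1 trajectory, then returns the version-5 and version-4 trajectories when freshest-first is enabled, or the version-3 and version-4 trajectories in FIFO mode.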

copy-pr-bot (Bot) commented May 11, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

yuki-97 marked this pull request as ready for review May 11, 2026 05:06
yuki-97 requested review from a team as code owners May 11, 2026 05:06
yuki-97 requested review from mehraakash and terrykong May 11, 2026 05:06
yuki-97 force-pushed the yukih/staleness-sample branch from 85e179f to 19334ad May 11, 2026 10:04

# limitations under the License.

import threading as _threading
from collections import Counter

👍

sampled_weights
)
sampled_items = [self.trajectories[i] for i in selected]
for idx in sorted(selected, reverse=True):

Could we refactor this into a separate function:

  def _remove_indices(self, indices: Iterable[int]) -> None:
      for idx in sorted(indices, reverse=True):
          self.trajectory_versions.pop(idx)
          self.target_weight_versions.pop(idx)
          self.trajectories.pop(idx)

We could then use it in both _evict and sample, passing a different Iterable in each?
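
To show why the helper pops in descending index order, here is a standalone sketch of the same idea as a free function. The name `remove_indices`, the `parallel_lists` parameter, and the de-duplication via `set` are assumptions for illustration, not part of the PR.

```python
from __future__ import annotations

from typing import Any, Iterable, Sequence


def remove_indices(parallel_lists: Sequence[list[Any]], indices: Iterable[int]) -> None:
    """Remove the same positions from every list, keeping them index-aligned."""
    # Popping the highest index first means earlier indices are not shifted
    # by the removals that have already happened.
    for idx in sorted(set(indices), reverse=True):
        for lst in parallel_lists:
            lst.pop(idx)
```

Popping in ascending order would shift every later element one position to the left after each removal, so the second and subsequent indices would point at the wrong entries.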

yuki-97 (Contributor, Author)

good idea! e743b7d

"""
min_valid = current_weight_version - self.max_staleness
stale = [i for i, v in enumerate(self.trajectory_versions) if v < min_valid]
for idx in sorted(stale, reverse=True):

See comment below on adding it in a function.

stale = [i for i, v in enumerate(self.trajectory_versions) if v < min_valid]
for idx in sorted(stale, reverse=True):
self.trajectory_versions.pop(idx)
self.trajectories.pop(idx)

I know we eventually want to get rid of target_weight_versions, but since we inherit from ReplayBufferImpl, that list will still be created. Should we either keep that state aligned or remove it?

yuki-97 (Contributor, Author)

Also covered in e743b7d.

@ray.remote # pragma: no cover
class ReplayBufferNew(ReplayBufferImpl):
pass
"""Staleness-window replay buffer.

I think we need a follow-up task here before wiring this in:

ReplayBufferNew removes exact target matching in sample() here, but the collector still enforces
target-version reservation and generation-limit pauses through last_target_weight_already_generated.

For end-to-end staleness-window sampling, the collector needs a mode that generates based on
current generation_weight_version and buffer/backpressure capacity, and not future target_weight_version slots.

We'll control generation using SingleController by:

  • Buffer Capacity
  • Inflight Semaphore
  • Refit pause
  • Any manual Pause
  • Dataloader availability

yuki-97 (Contributor, Author) commented May 12, 2026

> For end-to-end staleness-window sampling, the collector needs a mode that generates based on
> current generation_weight_version and buffer/backpressure capacity, and not future target_weight_version slots.

Yes, that's what we will do in SingleController, and this PR lays the groundwork for it. When implementing SingleController, we can build directly on this.

Currently the collector still has a generation limit, and updating it would require many changes. Given that we will eventually use SingleController to control this, I don't think we need to update the collector here.

I think we can get this PR in first (with unit tests to guard the new implementation). It won't be dead code, since we'll definitely use it later. I'm also fine with holding the PR until we implement SingleController. WDYT? @mehraakash @terrykong


Yeah, I just meant that we should add a note that end-to-end staleness-window sampling will be completed when ...

yuki-97 force-pushed the yukih/refactor-async-utils branch from 17cd7e2 to 7f30994 May 12, 2026 05:19
yuki-97 requested a review from a team as a code owner May 12, 2026 05:19
Base automatically changed from yukih/refactor-async-utils to main May 12, 2026 06:15
yuki-97 added 3 commits May 12, 2026 01:38
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com>
yuki-97 force-pushed the yukih/staleness-sample branch from c68a0f5 to ced01e3 May 12, 2026 08:40
Signed-off-by: Yuki Huang <yukih@nvidia.com>
yuki-97 added the CI:Lfast label (runs a fast test suite and re-uses the nightly `main` container, but syncs dependencies to the PR's version) May 12, 2026
yuki-97 (Contributor, Author) commented May 12, 2026

/ok to test e743b7d

mehraakash left a comment

LGTM

saumishr commented

@yuki-97 Please add a disclaimer on the new interfaces stating that they are a work in progress and subject to change. Users should not be using these WIP interfaces.
