[Feature] chunk actor logprob computation for memory saving#1555

Open
tina-wen wants to merge 1 commit into InternLM:main from tina-wen:rl_chunk_logprobs

Conversation

@tina-wen

Description

This PR adds chunking along the seq_len dimension when computing actor_logprob.
The log-probabilities are now computed chunk by chunk, significantly reducing peak memory usage.

Key Changes

  • Add chunking logic for seq_len dimension in actor_logprob computation
  • Process logprob calculation in chunks to trade compute for memory
  • Configurable via chunk_size in WorkerConfig.loss_cfg (BaseRLLossConfig, default: 1024)
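
The chunking described above can be sketched as follows. This is a minimal illustration, not the PR's actual implementation: the `gather_logprobs` body and the chunk loop are assumptions based on the description; only the names `gather_logprobs` and `chunk_size` come from the PR.

```python
import torch
import torch.nn.functional as F


def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-token log-probability of each label under the model's distribution."""
    logprobs = F.log_softmax(logits.float(), dim=-1)
    return logprobs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)


def chunked_gather_logprobs(logits: torch.Tensor, labels: torch.Tensor,
                            chunk_size: int = 1024) -> torch.Tensor:
    """Process the seq_len dimension in chunks, trading compute for memory.

    Only one (batch, chunk_size, vocab) float32 softmax buffer is alive at a
    time, instead of materializing the full (batch, seq_len, vocab) tensor.
    """
    chunks = []
    for start in range(0, logits.shape[1], chunk_size):
        chunks.append(gather_logprobs(logits[:, start:start + chunk_size],
                                      labels[:, start:start + chunk_size]))
    return torch.cat(chunks, dim=1)
```

Because `log_softmax` is independent per position, the chunked result is numerically equivalent to the full-sequence gather; the chunk loop only bounds the size of the temporary softmax buffer.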

@HAOCHENYE
Collaborator

@claude review

output = self._engine.forward_only(seq_ctx=seq_ctx)
old_logprobs = gather_logprobs(output["logits"], shifted_labels)
if self.logprob_chunk_size is not None:
    loss_ctx = LogProbContext(chunk_size=self.logprob_chunk_size, shifted_labels=shifted_labels)

Please add a LogProbContextConfig analogous to CELossConfig; loss_ctx should be built via loss_cfg.build.
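
The requested Config/Context split might look like the sketch below. This is hypothetical: the field names and `build` signature are modeled on the reviewer's reference to CELossConfig, not taken from the xtuner source.

```python
from dataclasses import dataclass

import torch


@dataclass
class LogProbContextConfig:
    """Config object; the worker holds this and calls build() per batch."""
    chunk_size: int = 1024

    def build(self, shifted_labels: torch.Tensor) -> "LogProbContext":
        # Mirrors the CELossConfig -> loss context pattern the reviewer cites.
        return LogProbContext(chunk_size=self.chunk_size,
                              shifted_labels=shifted_labels)


@dataclass
class LogProbContext:
    """Runtime context carrying per-batch state for chunked logprob gather."""
    chunk_size: int
    shifted_labels: torch.Tensor
```

The config is serializable and batch-independent, while the context carries per-batch tensors, which is why the reviewer wants construction routed through `loss_cfg.build`.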

from xtuner.v1.rl.utils import gather_logprobs


class LogProbContext:

Claude: Nit: Public class LogProbContext is missing a docstring (Google Style per CLAUDE.md). Should document its purpose, the chunk_size and shifted_labels parameters, and that forward() returns log-probabilities (not a traditional loss).
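
A docstring along the lines the nit asks for might read as follows; the wording is a suggestion, not the PR's text.

```python
class LogProbContext:
    """Loss context that computes per-token log-probabilities in chunks.

    Unlike a conventional loss context, ``forward`` returns gathered
    log-probabilities rather than a scalar training loss.

    Args:
        chunk_size (int): Number of positions along seq_len processed per
            chunk; smaller values lower peak memory at some compute cost.
        shifted_labels (torch.Tensor): Labels already shifted for next-token
            prediction, used to gather log-probabilities from the logits.
    """
```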

@claude

claude bot commented Mar 17, 2026

Claude:

Summary

Adds chunked log-probability computation along the seq_len dimension to reduce peak memory during compute_actor_logprobs. New LogProbContext class processes logits in configurable chunks via WorkerConfig.loss_cfg.chunk_size.

Issues

Critical

  • [worker.py:380-383] Fallback path broken when logprob_chunk_size is None: The else branch sets loss_ctx = None, but the code unconditionally reads output["loss"] which will be None. The original gather_logprobs(output["logits"], shifted_labels) call was removed without a replacement for the non-chunked path.
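
One way to repair that fallback is sketched below. The surrounding names (`engine.forward_only`, the `output` dict keys) follow the snippets quoted in this review, but the function wrapper and the `LogProbContext` stub are illustrative assumptions.

```python
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class LogProbContext:  # minimal stand-in for the PR's loss context
    chunk_size: int
    shifted_labels: torch.Tensor


def gather_logprobs(logits, labels):
    return F.log_softmax(logits.float(), dim=-1).gather(
        -1, labels.unsqueeze(-1)).squeeze(-1)


def compute_actor_logprobs(engine, seq_ctx, shifted_labels, chunk_size=None):
    """Read output["loss"] only on the chunked path; the eager path keeps
    the original full-sequence gather instead of dereferencing None."""
    if chunk_size is not None:
        loss_ctx = LogProbContext(chunk_size=chunk_size,
                                  shifted_labels=shifted_labels)
        output = engine.forward_only(seq_ctx=seq_ctx, loss_ctx=loss_ctx)
        return output["loss"]  # chunked logprobs produced via the context
    output = engine.forward_only(seq_ctx=seq_ctx)
    return gather_logprobs(output["logits"], shifted_labels)
```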

Warning

  • [rl_loss.py:7] LogProbContext does not inherit from BaseLossContext, breaking the type contract expected by LMHead and the model's __call__ signature. Should follow the established Config/Context pattern (as the other reviewer also noted).
  • [rl_loss.py:25] Return value (loss, (None, None)) doesn't match LMHead's return contract — inner tuple second element should be {} not None.
  • [rl_loss.py:19] Variable named loss actually holds log-probabilities — misleading.
  • [train_engine.py:173] Missing type annotation on loss_ctx parameter; should also default to None for backward compatibility.
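
The annotation/default fix for the last warning could look like this hypothetical signature (class and return shapes are placeholders, not the real train engine):

```python
from typing import Optional


class BaseLossContext:  # stand-in for xtuner's loss-context base class
    pass


class TrainEngine:
    def forward_only(self, seq_ctx,
                     loss_ctx: Optional[BaseLossContext] = None) -> dict:
        # Defaulting to None keeps existing callers working unchanged.
        if loss_ctx is None:
            return {"logits": "full-logits path"}
        return {"loss": "chunked-logprob path"}
```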

Nit

  • [rl_loss.py:7] Public class LogProbContext missing Google Style docstring.

Verdict

REQUEST_CHANGES

@tina-wen tina-wen force-pushed the rl_chunk_logprobs branch from 3368beb to 0a5b2d8 Compare March 18, 2026 07:30
    mode = "chunk"
else:
    mode = "eager"
loss_ctx = LogProbConfig(chunk_size=self.logprob_chunk_size, mode=mode).build(

Shouldn't the build method be called at init time instead?

@tina-wen tina-wen force-pushed the rl_chunk_logprobs branch 2 times, most recently from 2ff2436 to 947f826 Compare March 19, 2026 15:40
@tina-wen tina-wen force-pushed the rl_chunk_logprobs branch from 947f826 to c82f48f Compare March 19, 2026 15:57
