Skip to content
Merged
2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ jobs:
strategy:
fail-fast: false
matrix:
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
defaults:
run:
working-directory: ${{ github.workspace }}
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/pr-test.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@
'tests': [
{'test_file': 'test_megatron_argument_validation.py', 'num_gpus': 0},
{'test_file': 'test_dp_schedule.py', 'num_gpus': 0},
{'test_file': 'test_cp_utils.py', 'num_gpus': 0},
{'test_file': 'test_metric_report.py', 'num_gpus': 0},
{'test_file': 'test_metric_report_dist.py', 'num_gpus': 0},
{'test_file': 'test_loss_cp_invariance.py', 'num_gpus': 0},
{'test_file': 'plugin_contracts/test_plugin_rollout_contracts.py', 'num_gpus': 0},
{'test_file': 'plugin_contracts/test_plugin_runtime_hook_contracts.py', 'num_gpus': 0},
{'test_file': 'plugin_contracts/test_plugin_path_loading_contracts.py', 'num_gpus': 0},
Expand Down
21 changes: 17 additions & 4 deletions examples/multi_agent/agent_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,19 @@ async def run_agent_system(args, sample):
args = deepcopy(args) # Deep copy args because rollout_with_multi_agents mutates them.
args.sample = sample
args.results_dict = {"solver": [], "rewriter": [], "selector": []}
# Every sample emitted below is a training sample split out of this one
# rollout execution (the input ``sample``). Stamp the shared rollout id on
# every collected sample at each return point so the per-rollout loss
# reducer aggregates the solver / rewriter / selector siblings as one
# rollout instead of N, and the by-rollout step splitter keeps them in
# the same step. Captured here because ``sample`` gets shadowed by zip-
# loop variables further down.
input_rollout_id = sample.index

def _emit(samples_list):
    # Tag every collected sample with the rollout id captured from the input
    # sample, then hand the same list object back so the helper can wrap a
    # ``return`` expression directly.
    for collected in samples_list:
        setattr(collected, "rollout_id", input_rollout_id)
    return samples_list

problem_statement = sample.prompt
tasks = [solver_worker(args, problem_statement, worker_id) for worker_id in range(args.num_parallel)]
Expand All @@ -210,7 +223,7 @@ def reward_adjustment(samples, reward_weight):

if len(previous_solutions) == 0:
reward_adjustment(args.results_dict["solver"], args.incorrect_reward_weight)
return args.results_dict["solver"]
return _emit(args.results_dict["solver"])

# Rewriting
tasks = [
Expand All @@ -232,15 +245,15 @@ def reward_adjustment(samples, reward_weight):
if len(rewrited_solutions) == 0:
reward_adjustment(args.results_dict["solver"], args.incorrect_reward_weight)
reward_adjustment(args.results_dict["rewriter"], args.incorrect_reward_weight)
return args.results_dict["solver"] + args.results_dict["rewriter"]
return _emit(args.results_dict["solver"] + args.results_dict["rewriter"])

# Selection
selector = SelectorAgent()
response = await selector.select(args, problem_statement, rewrited_solutions)
if len(args.results_dict["selector"]) == 0:
reward_adjustment(args.results_dict["solver"], args.incorrect_reward_weight)
reward_adjustment(args.results_dict["rewriter"], args.incorrect_reward_weight)
return args.results_dict["solver"] + args.results_dict["rewriter"]
return _emit(args.results_dict["solver"] + args.results_dict["rewriter"])

assert (
len(args.results_dict["selector"]) == 1
Expand Down Expand Up @@ -269,4 +282,4 @@ def reward_adjustment(samples, reward_weight):
reward_adjustment(args.results_dict["rewriter"], args.incorrect_reward_weight)
reward_adjustment(args.results_dict["selector"], args.incorrect_reward_weight)

return args.results_dict["solver"] + args.results_dict["rewriter"] + args.results_dict["selector"]
return _emit(args.results_dict["solver"] + args.results_dict["rewriter"] + args.results_dict["selector"])
6 changes: 6 additions & 0 deletions slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:
rollout_data["loss_masks"] = [
torch.tensor(t, dtype=torch.int, device=torch.cuda.current_device()) for t in rollout_data["loss_masks"]
]
if "rollout_mask_sums" in rollout_data:
# Promote precomputed per-rollout mask totals to GPU tensors here
# (matching loss_masks) so the loss reducer can just divide.
rollout_data["rollout_mask_sums"] = torch.tensor(
rollout_data["rollout_mask_sums"], dtype=torch.float32, device=torch.cuda.current_device()
)
if "multimodal_train_inputs" in rollout_data:
# Move multimodal training tensors to GPU in advance
rollout_data["multimodal_train_inputs"] = [
Expand Down
136 changes: 130 additions & 6 deletions slime/backends/megatron_utils/cp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,37 @@ def get_sum_of_sample_mean(
total_lengths: list[int],
response_lengths: list[int],
loss_masks: list[torch.Tensor],
sample_denoms: list[torch.Tensor] | torch.Tensor | None = None,
calculate_per_token_loss: bool = False,
qkv_format: str = "thd",
max_seq_lens: list[int] | None = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Calculate correct sample mean for CP
Calculate correct sample mean for CP.

The default (``sample_denoms=None``) is the legacy per-sample mean: each
sample's denominator is its own ``loss_mask.sum()``. Callers that want a
per-rollout token-weighted mean pass pre-computed per-sample denominators
(already as GPU tensors — see actor side) where every sample in the same
rollout group carries the same value (the sum of that rollout's mask
totals across every sibling sample in the step). Pre-computing at the
step level rather than per-mb is required — otherwise a rollout whose
samples land in different micro-batches would get a partial denominator
on each side.
"""
if sample_denoms is None:
sample_denoms = [m.sum() for m in loss_masks]

cp_size = mpu.get_context_parallel_world_size()
if cp_size == 1:

def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
return sum(
[
(x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1)
for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)
(x_i * loss_mask_i).sum() / torch.clamp_min(denom, 1)
for x_i, loss_mask_i, denom in zip(
x.split(response_lengths, dim=0), loss_masks, sample_denoms, strict=False
)
]
)

Expand Down Expand Up @@ -100,9 +116,9 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor:
def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
return sum(
[
(x_i * chunked_loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)
for x_i, chunked_loss_mask, loss_mask in zip(
x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, loss_masks, strict=False
(x_i * chunked_loss_mask).sum() / torch.clamp_min(denom, 1)
for x_i, chunked_loss_mask, denom in zip(
x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, sample_denoms, strict=False
)
]
)
Expand All @@ -120,6 +136,114 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor:
return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token


def reduce_train_step_metrics(
losses_reduced: list[dict],
*,
calculate_per_token_loss: bool,
step_global_batch_size: int,
cp_size: int,
dp_with_cp_group,
) -> dict[str, float]:
"""Aggregate per-mb log dicts into the dict ``train_one_step`` reports.

Pipeline (1:1 with what the train loop used to do inline):
1. Sum each metric's per-mb ``values`` tensor locally on this rank.
2. All-reduce across the DP*CP group (``dp_with_cp_group``).
3. Apply the per-mode divisor / cp_factor:
- per-token-loss: divisor = ``values[0]`` = all-reduced ``num_tokens``,
CP-inflated by ``cp_size`` because every CP rank computes the same
num_tokens off the FULL (not chunked) masks; the
``cp_factor = cp_size`` multiplier cancels that inflation, leaving
the genuine per-token average.
- per-rollout-mean: divisor = constant ``step_global_batch_size`` from
the rollout side, never all-reduced, so no CP inflation to cancel
and ``cp_factor = 1``.

Tests pass a mock ``dp_with_cp_group`` and monkeypatch ``dist.all_reduce``
to a no-op, then pre-aggregate virtual ranks themselves — this exercises
the same call shape as production while staying single-process.
"""
keys = losses_reduced[0]["keys"]
values = None
for x in losses_reduced:
values = x["values"] if values is None else values + x["values"]
assert len(keys) + 1 == values.numel()
dist.all_reduce(values, group=dp_with_cp_group)
values = values.tolist()

if calculate_per_token_loss:
num_samples_or_tokens = values[0]
cp_factor = cp_size
else:
num_samples_or_tokens = step_global_batch_size
cp_factor = 1
return {key: value * cp_factor / num_samples_or_tokens for key, value in zip(keys, values[1:], strict=False)}


def rollout_log_metric_contribution(
    per_rank_reducer_sum: float,
    *,
    cp_size: int,
    num_rollouts_in_rollout: int,
    dp_size: int,
) -> tuple[float, float]:
    """Build this rank's ``(sum, count)`` pair for a per-rollout-mean metric.

    The pair is meant to be gathered across ranks and reduced as
    ``Σsum / Σcount`` (see :func:`gather_and_reduce_log_dict`).

    Args:
        per_rank_reducer_sum: The reducer output computed on this rank.
        cp_size: Context-parallel width; the sum is scaled up by it.
        num_rollouts_in_rollout: Total number of rollouts being logged.
        dp_size: Data-parallel width; the rollout count is split evenly
            over it. Per the original notes this is the DP width without CP.

    Returns:
        ``(cp_size * per_rank_reducer_sum, num_rollouts_in_rollout / dp_size)``.
    """
    return cp_size * per_rank_reducer_sum, num_rollouts_in_rollout / dp_size


def gather_and_reduce_log_dict(
log_dict: dict,
*,
dp_size: int,
dp_src_rank: int,
dp_group,
) -> dict | None:
"""``dist.gather_object`` per-rank log_dicts + per-key reduction.

Per key in the gathered dicts:
- ``(sum, count)`` tuple → ``Σsum / Σcount`` (per-rollout-mean shape;
pair with :func:`rollout_log_metric_contribution`).
- plain value → ``Σ / dp_size`` (legacy mean-across-ranks; the only
correct answer when ranks hold the same data).

Returns the reduced dict on ``dp_src_rank``, ``None`` elsewhere. The
caller adds whatever metric-name prefix / wandb plumbing it wants —
this helper stays free of side effects so CPU multi-process unit tests
can drive it directly with real ``torch.distributed``.
"""
if dist.get_rank() == dp_src_rank:
gathered = [None] * dp_size
dist.gather_object(log_dict, gathered, dst=dp_src_rank, group=dp_group)
reduced: dict = {}
for key in log_dict:
values = [d[key] for d in gathered]
first = values[0]
if isinstance(first, tuple) and len(first) == 2:
total_sum = sum(v[0] for v in values)
total_count = sum(v[1] for v in values)
reduced[key] = total_sum / total_count if total_count else 0.0
else:
reduced[key] = sum(values) / dp_size
return reduced
dist.gather_object(log_dict, None, dst=dp_src_rank, group=dp_group)
return None


def all_gather_with_cp(tensor: torch.Tensor, total_length: int, response_length: int) -> torch.Tensor:
"""
Gather tensors across all ranks in the context parallel group.
Expand Down
Loading
Loading