Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 39 additions & 34 deletions roll/utils/functionals.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from roll.distributed.scheduler.protocol import DataProto
import enum
import traceback
import heapq
from typing import Dict, List, Optional, Tuple, Union
import inspect
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
Expand All @@ -21,6 +15,9 @@
from roll.utils.kl_controller import AdaptiveKLController
from roll.utils.logging import get_logger

if TYPE_CHECKING:
from roll.distributed.scheduler.protocol import DataProto

logger = get_logger()


Expand Down Expand Up @@ -225,8 +222,14 @@ def entropy_from_logits(logits: torch.Tensor):
return entropy


def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str, batch_num_tokens: int = None,
global_valid_samples: int = None, weights: Optional[torch.Tensor] = None):
def agg_loss(
loss_mat: torch.Tensor,
loss_mask: torch.Tensor,
loss_agg_mode: str,
batch_num_tokens: int = None,
global_valid_samples: int = None,
weights: Optional[torch.Tensor] = None,
):
"""
ref: https://github.com/volcengine/verl/blob/78532923368aeb058f62201489546d013df47710/verl/trainer/ppo/core_algos.py#L370
Aggregate the loss matrix into a scalar.
Expand All @@ -251,8 +254,9 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str
global_valid_samples = loss_mat.size(0)
if loss_agg_mode == "token-mean":
if weights is None:
weights = torch.ones(loss_mask.shape[0], device=loss_mask.device)
loss = (loss_mat * weights.unsqueeze(-1)).sum() / batch_num_tokens
weights = torch.ones(loss_mask.shape[0], dtype=loss_mat.dtype, device=loss_mask.device)
masked_loss = torch.where(loss_mask.bool(), loss_mat, torch.zeros_like(loss_mat))
loss = (masked_loss * weights.unsqueeze(-1)).sum() / batch_num_tokens
elif loss_agg_mode == "seq-mean-token-sum":
seq_losses = masked_sum(loss_mat, loss_mask, dim=-1) # token-sum
valid_samples = torch.any(loss_mask > 0, dim=-1).float()
Expand Down Expand Up @@ -288,6 +292,7 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = None) -> to
else:
return (tensor * mask).sum() / (mask.sum() + 1e-8)


def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = None) -> torch.Tensor:
if dim is not None:
mask_sum = mask.sum(axis=dim)
Expand Down Expand Up @@ -445,6 +450,7 @@ def _parse_aggregation_func(metric_name: str):

return metrics


def reduce_metrics_list(metrics_list: list, reduce_func=np.mean) -> dict:
if len(metrics_list) == 0:
return {}
Expand Down Expand Up @@ -574,7 +580,7 @@ def reward_norm(
reward_mean = reshape_reward.mean(dim=-1, keepdim=True)
elif norm_mean_type == "running":
reward_mean = running.mean
elif norm_mean_type == None:
elif norm_mean_type is None:
reward_mean = 0.0
# 标准差计算
if norm_std_type == "batch":
Expand All @@ -589,7 +595,7 @@ def reward_norm(
if norm_std_type is not None:
normalized_rewards = (rewards - reward_mean) / (reward_std + 1e-6)
else:
normalized_rewards = (rewards - reward_mean)
normalized_rewards = rewards - reward_mean

# 如果是对 group mean 归一化,需要恢复原始形状
if norm_mean_type == "group":
Expand Down Expand Up @@ -653,12 +659,12 @@ def reward_postprocess(data: "DataProto", pipeline_config: RLVRConfig, running_c
pipeline_config.norm_mean_type, pipeline_config.norm_std_type = "group", "group"

response_level_rewards = reward_norm(
response_level_rewards,
n_sample=pipeline_config.actor_infer.generating_args.num_return_sequences,
running_ctrl=running_ctrl,
norm_mean_type=pipeline_config.norm_mean_type,
norm_std_type=pipeline_config.norm_std_type
)
response_level_rewards,
n_sample=pipeline_config.actor_infer.generating_args.num_return_sequences,
running_ctrl=running_ctrl,
norm_mean_type=pipeline_config.norm_mean_type,
norm_std_type=pipeline_config.norm_std_type,
)

# 对reward进行clip
if pipeline_config.reward_clip:
Expand Down Expand Up @@ -798,7 +804,9 @@ def compute_advantage(
kld = None
if is_pure_opd or use_opd:
kld = compute_approx_kl(
log_probs=data.batch["old_log_probs"] if getattr(pipeline_config, "enable_old_logprobs_recompute", False) else data.batch["infer_logprobs"],
log_probs=data.batch["old_log_probs"]
if getattr(pipeline_config, "enable_old_logprobs_recompute", False)
else data.batch["infer_logprobs"],
log_probs_base=data.batch["ref_log_probs"],
action_mask=response_mask,
kl_penalty=getattr(pipeline_config, "kl_penalty", "kl"),
Expand Down Expand Up @@ -848,6 +856,7 @@ def compute_advantage(
data.batch["returns"] = returns
return data


def postprocess_generate(
prompts: "DataProto",
output: torch.Tensor,
Expand Down Expand Up @@ -991,6 +1000,7 @@ def separate_prompt_response(
response_ids = torch.where(response_mask_valid, input_ids, torch.full_like(input_ids, pad_id))
return prompt_ids, response_ids


def filter_func_args(func, forward_args):
signature = inspect.signature(func)
forward_params = signature.parameters.keys()
Expand Down Expand Up @@ -1119,9 +1129,9 @@ def adjust_sequence_length(sequence, target_length, origin_seq_len, pad_value=0)
return sequence[tuple(slices)]


def get_seqlen_balanced_partitions(seqlen_list: List[float],
k_partitions: int,
equal_size: bool = False) -> List[List[int]]:
def get_seqlen_balanced_partitions(
seqlen_list: List[float], k_partitions: int, equal_size: bool = False
) -> List[List[int]]:
"""
Reference: https://github.com/volcengine/verl/blob/468adf22c43b744348051fccd7a5d830c6c3c36a/verl/utils/seqlen_balancing.py

Expand Down Expand Up @@ -1193,16 +1203,14 @@ def __lt__(self, other):
return self.spread > other.spread
return self.sets[0] > other.sets[0]

assert len(seqlen_list) >= k_partitions, \
f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"
assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"

# Sort by sequence length
sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)])
states_pq = []

if equal_size:
assert len(seqlen_list) % k_partitions == 0, \
f"{len(seqlen_list)} % {k_partitions} != 0"
assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0"
for offset in range(0, len(sorted_seqlen_list), k_partitions):
items = []
for i in range(k_partitions):
Expand Down Expand Up @@ -1262,7 +1270,7 @@ def log_seqlen_unbalance(seqlen_list: list[int], partitions: list[list[int]], pr

# Iterate over each batch of sequence lengths
for offset in range(0, len(seqlen_list), batch_size):
cur_sum_seqlen = sum(seqlen_list[offset: offset + batch_size])
cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size])
if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen:
min_sum_seqlen = cur_sum_seqlen
if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen:
Expand Down Expand Up @@ -1305,16 +1313,14 @@ def calculate_workload(seq_len_list):
global_partition_lst = [[] for _ in range(world_size)]
for i in range(minibatch_num):
rearrange_minibatch_lst = get_seqlen_balanced_partitions(
workload_lst[i * minibatch_size: (i + 1) * minibatch_size],
workload_lst[i * minibatch_size : (i + 1) * minibatch_size],
k_partitions=world_size,
equal_size=True,
)
for j, part in enumerate(rearrange_minibatch_lst):
global_partition_lst[j].extend([x + minibatch_size * i for x in part])
else:
global_partition_lst = get_seqlen_balanced_partitions(
workload_lst, k_partitions=world_size, equal_size=True
)
global_partition_lst = get_seqlen_balanced_partitions(workload_lst, k_partitions=world_size, equal_size=True)
# Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.
for idx, partition in enumerate(global_partition_lst):
partition.sort(key=lambda x: (workload_lst[x], x))
Expand All @@ -1329,4 +1335,3 @@ def calculate_workload(seq_len_list):
metrics = {}
metrics.update(global_balance_stats)
return metrics

27 changes: 26 additions & 1 deletion tests/utils/test_functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
import pytest
import torch

from roll.utils.functionals import traverse_obj, divide_by_chunk_size, pad_to_length
from roll.utils.functionals import (
agg_loss,
divide_by_chunk_size,
pad_to_length,
traverse_obj,
)


def visitor(obj: object, path: Tuple):
Expand All @@ -22,6 +27,7 @@ def __init__(self):
"nested_key1": torch.tensor([[1, 2], [3, 4]]),
"nested_key2": [torch.tensor(5), np.array([6, 7])],
}

class CustomObject:
def __init__(self):
self.attr1 = torch.tensor([1, 2, 3])
Expand Down Expand Up @@ -55,5 +61,24 @@ def test_pad_to_length():
print(padded_tensor)


def test_agg_loss_token_mean_masks_loss_numerator_and_gradients():
    """Check "token-mean" aggregation against a mask that hides one column.

    The second column of the loss matrix is masked out and carries large
    values (100.0 and -50.0), so the expected loss and gradient below only
    hold if those positions contribute nothing.
    """
    mask = torch.tensor([[1, 0], [1, 0]])
    per_token_loss = torch.tensor([[1.0, 100.0], [1.0, -50.0]], requires_grad=True)
    valid_token_count = int(mask.sum().item())

    aggregated = agg_loss(
        loss_mat=per_token_loss,
        loss_mask=mask,
        loss_agg_mode="token-mean",
        batch_num_tokens=valid_token_count,
    )
    aggregated.backward()

    # Two unmasked tokens with loss 1.0 each: the expected mean is 1.0, each
    # unmasked entry is expected to receive gradient 0.5, masked entries 0.
    expected_grad = torch.tensor([[0.5, 0.0], [0.5, 0.0]])
    assert aggregated.item() == pytest.approx(1.0)
    torch.testing.assert_close(per_token_loss.grad, expected_grad)


if __name__ == "__main__":
pytest.main()