Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 34 additions & 27 deletions roll/utils/functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

if TYPE_CHECKING:
from roll.distributed.scheduler.protocol import DataProto
import enum
import traceback

import heapq
from typing import Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -225,8 +224,14 @@ def entropy_from_logits(logits: torch.Tensor):
return entropy


def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str, batch_num_tokens: int = None,
global_valid_samples: int = None, weights: Optional[torch.Tensor] = None):
def agg_loss(
loss_mat: torch.Tensor,
loss_mask: torch.Tensor,
loss_agg_mode: str,
batch_num_tokens: int = None,
global_valid_samples: int = None,
weights: Optional[torch.Tensor] = None,
):
"""
ref: https://github.com/volcengine/verl/blob/78532923368aeb058f62201489546d013df47710/verl/trainer/ppo/core_algos.py#L370
Aggregate the loss matrix into a scalar.
Expand Down Expand Up @@ -288,6 +293,7 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = None) -> to
else:
return (tensor * mask).sum() / (mask.sum() + 1e-8)


def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = None) -> torch.Tensor:
if dim is not None:
mask_sum = mask.sum(axis=dim)
Expand Down Expand Up @@ -445,6 +451,7 @@ def _parse_aggregation_func(metric_name: str):

return metrics


def reduce_metrics_list(metrics_list: list, reduce_func=np.mean) -> dict:
if len(metrics_list) == 0:
return {}
Expand Down Expand Up @@ -574,7 +581,7 @@ def reward_norm(
reward_mean = reshape_reward.mean(dim=-1, keepdim=True)
elif norm_mean_type == "running":
reward_mean = running.mean
elif norm_mean_type == None:
elif norm_mean_type is None:
reward_mean = 0.0
# Compute the standard deviation
if norm_std_type == "batch":
Expand All @@ -589,7 +596,7 @@ def reward_norm(
if norm_std_type is not None:
normalized_rewards = (rewards - reward_mean) / (reward_std + 1e-6)
else:
normalized_rewards = (rewards - reward_mean)
normalized_rewards = rewards - reward_mean

# If normalizing by the group mean, restore the original shape
if norm_mean_type == "group":
Expand Down Expand Up @@ -653,12 +660,12 @@ def reward_postprocess(data: "DataProto", pipeline_config: RLVRConfig, running_c
pipeline_config.norm_mean_type, pipeline_config.norm_std_type = "group", "group"

response_level_rewards = reward_norm(
response_level_rewards,
n_sample=pipeline_config.actor_infer.generating_args.num_return_sequences,
running_ctrl=running_ctrl,
norm_mean_type=pipeline_config.norm_mean_type,
norm_std_type=pipeline_config.norm_std_type
)
response_level_rewards,
n_sample=pipeline_config.actor_infer.generating_args.num_return_sequences,
running_ctrl=running_ctrl,
norm_mean_type=pipeline_config.norm_mean_type,
norm_std_type=pipeline_config.norm_std_type,
)

# Clip the rewards
if pipeline_config.reward_clip:
Expand Down Expand Up @@ -798,7 +805,9 @@ def compute_advantage(
kld = None
if is_pure_opd or use_opd:
kld = compute_approx_kl(
log_probs=data.batch["old_log_probs"] if getattr(pipeline_config, "enable_old_logprobs_recompute", False) else data.batch["infer_logprobs"],
log_probs=data.batch["old_log_probs"]
if getattr(pipeline_config, "enable_old_logprobs_recompute", False)
else data.batch["infer_logprobs"],
log_probs_base=data.batch["ref_log_probs"],
action_mask=response_mask,
kl_penalty=getattr(pipeline_config, "kl_penalty", "kl"),
Expand All @@ -817,7 +826,8 @@ def compute_advantage(
data.batch["token_level_rewards"] = token_level_rewards
if adv_estimator == "gae":
values = data.batch["values"].float()
data.batch["values"] = values * response_mask
values = values * response_mask
data.batch["values"] = values
advantages, returns = compute_gae_advantage_return(
token_level_rewards=token_level_rewards, values=values, gamma=gamma, lambd=lambd
)
Expand Down Expand Up @@ -848,6 +858,7 @@ def compute_advantage(
data.batch["returns"] = returns
return data


def postprocess_generate(
prompts: "DataProto",
output: torch.Tensor,
Expand Down Expand Up @@ -991,6 +1002,7 @@ def separate_prompt_response(
response_ids = torch.where(response_mask_valid, input_ids, torch.full_like(input_ids, pad_id))
return prompt_ids, response_ids


def filter_func_args(func, forward_args):
signature = inspect.signature(func)
forward_params = signature.parameters.keys()
Expand Down Expand Up @@ -1119,9 +1131,9 @@ def adjust_sequence_length(sequence, target_length, origin_seq_len, pad_value=0)
return sequence[tuple(slices)]


def get_seqlen_balanced_partitions(seqlen_list: List[float],
k_partitions: int,
equal_size: bool = False) -> List[List[int]]:
def get_seqlen_balanced_partitions(
seqlen_list: List[float], k_partitions: int, equal_size: bool = False
) -> List[List[int]]:
"""
Reference: https://github.com/volcengine/verl/blob/468adf22c43b744348051fccd7a5d830c6c3c36a/verl/utils/seqlen_balancing.py

Expand Down Expand Up @@ -1193,16 +1205,14 @@ def __lt__(self, other):
return self.spread > other.spread
return self.sets[0] > other.sets[0]

assert len(seqlen_list) >= k_partitions, \
f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"
assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"

# Sort by sequence length
sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)])
states_pq = []

if equal_size:
assert len(seqlen_list) % k_partitions == 0, \
f"{len(seqlen_list)} % {k_partitions} != 0"
assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0"
for offset in range(0, len(sorted_seqlen_list), k_partitions):
items = []
for i in range(k_partitions):
Expand Down Expand Up @@ -1262,7 +1272,7 @@ def log_seqlen_unbalance(seqlen_list: list[int], partitions: list[list[int]], pr

# Iterate over each batch of sequence lengths
for offset in range(0, len(seqlen_list), batch_size):
cur_sum_seqlen = sum(seqlen_list[offset: offset + batch_size])
cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size])
if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen:
min_sum_seqlen = cur_sum_seqlen
if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen:
Expand Down Expand Up @@ -1305,16 +1315,14 @@ def calculate_workload(seq_len_list):
global_partition_lst = [[] for _ in range(world_size)]
for i in range(minibatch_num):
rearrange_minibatch_lst = get_seqlen_balanced_partitions(
workload_lst[i * minibatch_size: (i + 1) * minibatch_size],
workload_lst[i * minibatch_size : (i + 1) * minibatch_size],
k_partitions=world_size,
equal_size=True,
)
for j, part in enumerate(rearrange_minibatch_lst):
global_partition_lst[j].extend([x + minibatch_size * i for x in part])
else:
global_partition_lst = get_seqlen_balanced_partitions(
workload_lst, k_partitions=world_size, equal_size=True
)
global_partition_lst = get_seqlen_balanced_partitions(workload_lst, k_partitions=world_size, equal_size=True)
# Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.
for idx, partition in enumerate(global_partition_lst):
partition.sort(key=lambda x: (workload_lst[x], x))
Expand All @@ -1329,4 +1337,3 @@ def calculate_workload(seq_len_list):
metrics = {}
metrics.update(global_balance_stats)
return metrics

40 changes: 39 additions & 1 deletion tests/utils/test_functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@
import numpy as np
import pytest
import torch
from tensordict import TensorDict

from roll.utils.functionals import traverse_obj, divide_by_chunk_size, pad_to_length
from roll.distributed.scheduler.protocol import DataProto
from roll.utils.functionals import (
compute_advantage,
divide_by_chunk_size,
pad_to_length,
traverse_obj,
)


def visitor(obj: object, path: Tuple):
Expand All @@ -22,6 +29,7 @@ def __init__(self):
"nested_key1": torch.tensor([[1, 2], [3, 4]]),
"nested_key2": [torch.tensor(5), np.array([6, 7])],
}

class CustomObject:
def __init__(self):
self.attr1 = torch.tensor([1, 2, 3])
Expand Down Expand Up @@ -55,5 +63,35 @@ def test_pad_to_length():
print(padded_tensor)


def test_compute_advantage_masks_values_before_gae_bootstrap():
    """Check the GAE path of ``compute_advantage`` on a single three-token sample.

    The last position is masked out and carries a large value (100.0).
    The test asserts that the stored ``values`` equal the raw values
    multiplied by the response mask, and that both ``advantages`` and
    ``returns`` match the expected tensor.

    NOTE(review): presumably the large masked value is chosen so that any use
    of the unmasked values inside the GAE computation would change the
    expected advantages/returns -- confirm against
    ``compute_gae_advantage_return``.
    """
    mask = torch.tensor([[1.0, 1.0, 0.0]])
    rewards = torch.tensor([[0.0, 1.0, 0.0]])
    raw_values = torch.tensor([[0.0, 0.0, 100.0]])

    # Store clones in the batch so the local tensors used by the assertions
    # below are separate objects from the ones handed to compute_advantage.
    batch_fields = {
        "response_mask": mask.clone(),
        "token_level_rewards": rewards.clone(),
        "values": raw_values.clone(),
    }
    data = DataProto(
        batch=TensorDict(batch_fields, batch_size=[1]),
        meta_info={},
    )

    compute_advantage(
        data=data,
        gamma=torch.tensor(1.0),
        lambd=torch.tensor(0.95),
        adv_estimator="gae",
        response_mask=mask,
    )

    expected = torch.tensor([[0.95, 1.0, 0.0]])
    masked_values = raw_values * mask
    torch.testing.assert_close(data.batch["values"], masked_values)
    # Advantages and returns are both checked against the same expected tensor.
    for key in ("advantages", "returns"):
        torch.testing.assert_close(data.batch[key], expected)


# Invoke pytest's entry point when this file is executed directly as a script.
if __name__ == "__main__":
    pytest.main()