|
| 1 | +from typing import TYPE_CHECKING |
| 2 | + |
| 3 | +import torch |
| 4 | +from pydantic import BaseModel, ConfigDict |
| 5 | + |
| 6 | +from art import dev |
| 7 | +from art.utils.group_aggregate import group_aggregate |
| 8 | + |
| 9 | +if TYPE_CHECKING: |
| 10 | + from art.unsloth.service import TrainInputs |
| 11 | + |
| 12 | + |
| 13 | +class Loss(BaseModel): |
| 14 | + model_config = ConfigDict(arbitrary_types_allowed=True) |
| 15 | + mean_policy_loss: torch.Tensor |
| 16 | + mean_kl: torch.Tensor |
| 17 | + mean_entropy: torch.Tensor | None |
| 18 | + probs_corr: torch.Tensor |
| 19 | + |
| 20 | + |
| 21 | +def loss_fn( |
| 22 | + inputs: "TrainInputs", |
| 23 | + new_logprobs: torch.Tensor, |
| 24 | + ref_logprobs: torch.Tensor | None, |
| 25 | + entropies: torch.Tensor | None, |
| 26 | + experimental_config: dev.TrainConfig, |
| 27 | +) -> Loss: |
| 28 | + old_logprobs = shift_tensor(inputs["logprobs"], float("nan")) |
| 29 | + advantages = shift_tensor(inputs["advantages"], 0.0) |
| 30 | + assistant_mask = shift_tensor(inputs["assistant_mask"], False).to( |
| 31 | + new_logprobs.dtype |
| 32 | + ) |
| 33 | + weights = shift_tensor(inputs["weights"], 0.0) |
| 34 | + old_logprobs_mask = ~torch.isnan(old_logprobs) |
| 35 | + probs_corr = torch.corrcoef( |
| 36 | + torch.stack( |
| 37 | + [ |
| 38 | + torch.exp(old_logprobs[old_logprobs_mask]), |
| 39 | + torch.exp(new_logprobs[old_logprobs_mask]), |
| 40 | + ] |
| 41 | + ) |
| 42 | + )[0, 1] |
| 43 | + # Assume missing old logprobs were sampled under the current policy |
| 44 | + old_logprobs = torch.where( |
| 45 | + torch.isnan(old_logprobs), |
| 46 | + new_logprobs.detach(), |
| 47 | + old_logprobs, |
| 48 | + ) |
| 49 | + logprob_diff = new_logprobs - old_logprobs |
| 50 | + importance_sampling_level = experimental_config.get( |
| 51 | + "importance_sampling_level", "token" |
| 52 | + ) |
| 53 | + prob_ratio = torch.exp(logprob_diff) |
| 54 | + if importance_sampling_level != "token": |
| 55 | + sequence_prob_ratio = torch.exp( |
| 56 | + group_aggregate( |
| 57 | + logprob_diff, |
| 58 | + by=shift_tensor(inputs["group_ids"], 0) * assistant_mask, |
| 59 | + reduce="mean", |
| 60 | + ) |
| 61 | + ) |
| 62 | + if importance_sampling_level == "sequence": |
| 63 | + prob_ratio = sequence_prob_ratio |
| 64 | + elif importance_sampling_level == "average": |
| 65 | + prob_ratio = (prob_ratio + sequence_prob_ratio) / 2 |
| 66 | + elif importance_sampling_level == "geometric_average": |
| 67 | + prob_ratio = (prob_ratio**0.5) * (sequence_prob_ratio**0.5) |
| 68 | + epsilon = experimental_config.get("epsilon", 0.2) |
| 69 | + epsilon_high = experimental_config.get("epsilon_high", epsilon) |
| 70 | + if epsilon_high is None: |
| 71 | + epsilon_high = epsilon |
| 72 | + if max_negative_advantage_importance_sampling_weight := experimental_config.get( |
| 73 | + "max_negative_advantage_importance_sampling_weight", None |
| 74 | + ): |
| 75 | + prob_ratio = torch.clamp( |
| 76 | + prob_ratio, max=max_negative_advantage_importance_sampling_weight |
| 77 | + ) |
| 78 | + if experimental_config.get("ppo", True): |
| 79 | + policy_loss = -torch.min( |
| 80 | + prob_ratio * advantages, |
| 81 | + torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages, |
| 82 | + ) |
| 83 | + else: |
| 84 | + # Modified REINFORCE or Clipped IS-weight Policy Optimization (CISPO) |
| 85 | + policy_loss = -( |
| 86 | + torch.clip(prob_ratio.detach(), 1 - epsilon, 1 + epsilon_high) |
| 87 | + * advantages |
| 88 | + * new_logprobs |
| 89 | + ) |
| 90 | + if upper_bound := experimental_config.get("truncated_importance_sampling", None): |
| 91 | + if "original_logprobs" in inputs: |
| 92 | + original_logprobs = shift_tensor(inputs["original_logprobs"], 0.0) |
| 93 | + original_logprobs = torch.where( |
| 94 | + torch.isnan(original_logprobs), |
| 95 | + new_logprobs.detach(), |
| 96 | + original_logprobs, |
| 97 | + ) |
| 98 | + logprob_diff = old_logprobs - original_logprobs |
| 99 | + prob_ratio = torch.exp(logprob_diff) |
| 100 | + policy_loss *= torch.clamp(prob_ratio, max=upper_bound).detach() |
| 101 | + if ref_logprobs is not None: |
| 102 | + kl_div = ( |
| 103 | + torch.exp(ref_logprobs - new_logprobs) - (ref_logprobs - new_logprobs) - 1.0 |
| 104 | + ) |
| 105 | + else: |
| 106 | + kl_div = torch.zeros_like(policy_loss) |
| 107 | + policy_loss = policy_loss * weights * assistant_mask |
| 108 | + kl_div = kl_div * weights * assistant_mask |
| 109 | + mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6) |
| 110 | + mean_kl = kl_div.sum() / (assistant_mask.sum() + 1e-6) |
| 111 | + # Compute mean entropy for the current step |
| 112 | + if entropies is not None: |
| 113 | + shifted_entropies = shift_tensor(entropies, 0.0) |
| 114 | + mean_entropy = (shifted_entropies * weights * assistant_mask).sum() / ( |
| 115 | + assistant_mask.sum() + 1e-6 |
| 116 | + ) |
| 117 | + else: |
| 118 | + mean_entropy = None |
| 119 | + return Loss( |
| 120 | + mean_policy_loss=mean_policy_loss, |
| 121 | + mean_kl=mean_kl, |
| 122 | + mean_entropy=mean_entropy, |
| 123 | + probs_corr=probs_corr, |
| 124 | + ) |
| 125 | + |
| 126 | + |
| 127 | +def shift_tensor(tensor: torch.Tensor, pad: int | float | bool) -> torch.Tensor: |
| 128 | + return torch.nn.functional.pad(tensor[:, 1:], (0, 1), value=pad) |
0 commit comments