Commit b14ce3e

feat: Implement loss function and shift_tensor utility for training process
1 parent f2415e5 commit b14ce3e

3 files changed: +141 -77 lines changed


src/art/loss.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
from typing import TYPE_CHECKING

import torch
from pydantic import BaseModel, ConfigDict

from art import dev
from art.utils.group_aggregate import group_aggregate

if TYPE_CHECKING:
    from art.unsloth.service import TrainInputs


class Loss(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    mean_policy_loss: torch.Tensor
    mean_kl: torch.Tensor
    mean_entropy: torch.Tensor | None
    probs_corr: torch.Tensor


def loss_fn(
    inputs: "TrainInputs",
    new_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor | None,
    entropies: torch.Tensor | None,
    experimental_config: dev.TrainConfig,
) -> Loss:
    old_logprobs = shift_tensor(inputs["logprobs"], float("nan"))
    advantages = shift_tensor(inputs["advantages"], 0.0)
    assistant_mask = shift_tensor(inputs["assistant_mask"], False).to(
        new_logprobs.dtype
    )
    weights = shift_tensor(inputs["weights"], 0.0)
    old_logprobs_mask = ~torch.isnan(old_logprobs)
    probs_corr = torch.corrcoef(
        torch.stack(
            [
                torch.exp(old_logprobs[old_logprobs_mask]),
                torch.exp(new_logprobs[old_logprobs_mask]),
            ]
        )
    )[0, 1]
    # Assume missing old logprobs were sampled under the current policy
    old_logprobs = torch.where(
        torch.isnan(old_logprobs),
        new_logprobs.detach(),
        old_logprobs,
    )
    logprob_diff = new_logprobs - old_logprobs
    importance_sampling_level = experimental_config.get(
        "importance_sampling_level", "token"
    )
    prob_ratio = torch.exp(logprob_diff)
    if importance_sampling_level != "token":
        sequence_prob_ratio = torch.exp(
            group_aggregate(
                logprob_diff,
                by=shift_tensor(inputs["group_ids"], 0) * assistant_mask,
                reduce="mean",
            )
        )
        if importance_sampling_level == "sequence":
            prob_ratio = sequence_prob_ratio
        elif importance_sampling_level == "average":
            prob_ratio = (prob_ratio + sequence_prob_ratio) / 2
        elif importance_sampling_level == "geometric_average":
            prob_ratio = (prob_ratio**0.5) * (sequence_prob_ratio**0.5)
    epsilon = experimental_config.get("epsilon", 0.2)
    epsilon_high = experimental_config.get("epsilon_high", epsilon)
    if epsilon_high is None:
        epsilon_high = epsilon
    if max_negative_advantage_importance_sampling_weight := experimental_config.get(
        "max_negative_advantage_importance_sampling_weight", None
    ):
        prob_ratio = torch.clamp(
            prob_ratio, max=max_negative_advantage_importance_sampling_weight
        )
    if experimental_config.get("ppo", True):
        policy_loss = -torch.min(
            prob_ratio * advantages,
            torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages,
        )
    else:
        # Modified REINFORCE or Clipped IS-weight Policy Optimization (CISPO)
        policy_loss = -(
            torch.clip(prob_ratio.detach(), 1 - epsilon, 1 + epsilon_high)
            * advantages
            * new_logprobs
        )
    if upper_bound := experimental_config.get("truncated_importance_sampling", None):
        if "original_logprobs" in inputs:
            original_logprobs = shift_tensor(inputs["original_logprobs"], 0.0)
            original_logprobs = torch.where(
                torch.isnan(original_logprobs),
                new_logprobs.detach(),
                original_logprobs,
            )
            logprob_diff = old_logprobs - original_logprobs
            prob_ratio = torch.exp(logprob_diff)
            policy_loss *= torch.clamp(prob_ratio, max=upper_bound).detach()
    if ref_logprobs is not None:
        kl_div = (
            torch.exp(ref_logprobs - new_logprobs) - (ref_logprobs - new_logprobs) - 1.0
        )
    else:
        kl_div = torch.zeros_like(policy_loss)
    policy_loss = policy_loss * weights * assistant_mask
    kl_div = kl_div * weights * assistant_mask
    mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6)
    mean_kl = kl_div.sum() / (assistant_mask.sum() + 1e-6)
    # Compute mean entropy for the current step
    if entropies is not None:
        shifted_entropies = shift_tensor(entropies, 0.0)
        mean_entropy = (shifted_entropies * weights * assistant_mask).sum() / (
            assistant_mask.sum() + 1e-6
        )
    else:
        mean_entropy = None
    return Loss(
        mean_policy_loss=mean_policy_loss,
        mean_kl=mean_kl,
        mean_entropy=mean_entropy,
        probs_corr=probs_corr,
    )


def shift_tensor(tensor: torch.Tensor, pad: int | float | bool) -> torch.Tensor:
    return torch.nn.functional.pad(tensor[:, 1:], (0, 1), value=pad)
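Note: shift_tensor aligns the per-token labels (logprobs, advantages, masks, weights) with next-token predictions by dropping the first position along the sequence dimension and padding the final position with the given value. A minimal sketch of the expected behavior, not part of the commit; the import path assumes the new src/art/loss.py module above:

    import torch

    from art.loss import shift_tensor

    x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    # Each row shifts left by one; the last position is filled with the pad value.
    print(shift_tensor(x, 0.0))  # tensor([[2., 3., 0.], [5., 6., 0.]])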

src/art/torchtune/recipe.py

Lines changed: 1 addition & 1 deletion
@@ -707,7 +707,7 @@ def _loss_step(
         import torch
         from torch.nn.attention.flex_attention import BlockMask, create_block_mask

-        from ..unsloth.train import shift_tensor
+        from ..loss import shift_tensor

         def make_block_mask(
             group_ids: torch.Tensor,  # [B, S] int32/64

src/art/unsloth/train.py

Lines changed: 12 additions & 76 deletions
@@ -11,8 +11,8 @@
 from trl import GRPOTrainer

 from .. import dev
+from ..loss import loss_fn, shift_tensor
 from ..types import TrainConfig
-from ..utils.group_aggregate import group_aggregate

 if TYPE_CHECKING:
     from .service import TrainInputs
@@ -156,81 +156,21 @@ def compute_loss(
         ref_logprobs = None
         del attn_bias

-        # Shift inputs for loss calculation
-        old_logprobs = shift_tensor(inputs["logprobs"], 0.0)
-        advantages = shift_tensor(inputs["advantages"], 0.0)
-        assistant_mask = shift_tensor(inputs["assistant_mask"], False).to(
-            new_logprobs.dtype
-        )
-        weights = shift_tensor(inputs["weights"], 0.0)
-        # Assume missing old logprobs were sampled under the current policy
-        old_logprobs = torch.where(
-            torch.isnan(old_logprobs),
-            new_logprobs.detach(),
-            old_logprobs,
-        )
-        logprob_diff = new_logprobs - old_logprobs
-        if _config.get("importance_sampling_level", "token") == "sequence":
-            prob_ratio = torch.exp(
-                group_aggregate(
-                    logprob_diff,
-                    by=shift_tensor(inputs["group_ids"], 0) * assistant_mask,
-                    reduce="mean",
-                )
-            )
-        else:
-            prob_ratio = torch.exp(logprob_diff)
-        epsilon = _config.get("epsilon", 0.2)
-        epsilon_high = _config.get("epsilon_high", epsilon)
-        if epsilon_high is None:
-            epsilon_high = epsilon
-        if max_negative_advantage_importance_sampling_weight := _config.get(
-            "max_negative_advantage_importance_sampling_weight", None
-        ):
-            prob_ratio = torch.clamp(
-                prob_ratio, max=max_negative_advantage_importance_sampling_weight
-            )
-        policy_loss = -torch.min(
-            prob_ratio * advantages,
-            torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages,
-        )
-        if upper_bound := _config.get("truncated_importance_sampling", None):
-            if "original_logprobs" in inputs:
-                original_logprobs = shift_tensor(inputs["original_logprobs"], 0.0)
-                original_logprobs = torch.where(
-                    torch.isnan(original_logprobs),
-                    new_logprobs.detach(),
-                    original_logprobs,
-                )
-                logprob_diff = old_logprobs - original_logprobs
-                prob_ratio = torch.exp(logprob_diff)
-                policy_loss *= torch.clamp(prob_ratio, max=upper_bound).detach()
-        if ref_logprobs is not None:
-            kl_div = (
-                torch.exp(ref_logprobs - new_logprobs)
-                - (ref_logprobs - new_logprobs)
-                - 1.0
-            )
-        else:
-            kl_div = torch.zeros_like(policy_loss)
-
-        policy_loss = policy_loss * weights * assistant_mask
-        kl_div = kl_div * weights * assistant_mask
-        mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6)
-        mean_kl = kl_div.sum() / (assistant_mask.sum() + 1e-6)
-
-        # Compute mean entropy for the current step
-        shifted_entropies = shift_tensor(entropies, 0.0)
-        mean_entropy = (shifted_entropies * weights * assistant_mask).sum() / (
-            assistant_mask.sum() + 1e-6
+        loss = loss_fn(
+            inputs,
+            new_logprobs,
+            ref_logprobs,
+            entropies,
+            _config,
         )

         trainer._metrics["train"]["learning_rate"].append(config.learning_rate)
-        trainer._metrics["train"]["policy_loss"].append(mean_policy_loss.item())
-        trainer._metrics["train"]["entropy"].append(mean_entropy.item())  # type: ignore
+        trainer._metrics["train"]["policy_loss"].append(loss.mean_policy_loss.item())
+        if loss.mean_entropy is not None:
+            trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item())  # type: ignore
         if config.beta > 0.0:
-            trainer._metrics["train"]["kl_div"].append(mean_kl.item())
-        return mean_policy_loss + config.beta * mean_kl
+            trainer._metrics["train"]["kl_div"].append(loss.mean_kl.item())
+        return loss.mean_policy_loss + config.beta * loss.mean_kl  # type: ignore

     return compute_loss

@@ -395,9 +335,5 @@ def _calculate_logprobs(
     return log_probs, entropy


-def shift_tensor(tensor: torch.Tensor, pad: int | float | bool) -> torch.Tensor:
-    return torch.nn.functional.pad(tensor[:, 1:], (0, 1), value=pad)
-
-
 def gc_and_empty_cuda_cache(n: int = 3) -> None:
     [gc.collect() >= 0 and torch.cuda.empty_cache() for _ in range(n)]
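Note: loss_fn reads its tuning knobs from experimental_config (the _config that compute_loss forwards) via .get(), so every key is optional. The key names below come from the code in this commit; the concrete values are purely illustrative, not defaults shipped here:

    # Hypothetical example values; only the key names are taken from loss_fn above.
    experimental_config = {
        "importance_sampling_level": "sequence",  # "token" (default), "sequence", "average", or "geometric_average"
        "epsilon": 0.2,  # lower clip offset (default 0.2)
        "epsilon_high": 0.28,  # upper clip offset; falls back to epsilon when unset or None
        "max_negative_advantage_importance_sampling_weight": 2.0,  # optional cap applied to prob_ratio
        "ppo": True,  # False switches to the CISPO-style objective
        "truncated_importance_sampling": 2.0,  # optional ratio cap used when original_logprobs are present
    }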
