Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/liger_kernel/chunked_loss/dpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def forward(
loss_type="sigmoid",
label_smoothing=0.0,
discopop_tau=0.05,
alpha=1.0,
):
"""
Fused linear layer with DPO loss.
Expand All @@ -185,7 +186,7 @@ def forward(
ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
ignore_index (int): Index to ignore in loss computation
beta (float): Weight for the odds ratio loss
beta (float): Weight for the direct preference loss
compute_nll_loss (bool): Whether to compute the NLL loss
compiled (bool): Whether to use torch compile
use_ref_model (bool): Whether to use a reference model
Expand All @@ -194,6 +195,7 @@ def forward(
loss_type (str): Variant of DPO loss to compute.
label_smoothing (float): Label smoothing for "robust" / "exo_pair" / cDPO.
discopop_tau (float): Temperature for the DiscoPOP modulation term.
alpha (float): Weight for the NLL loss component
Returns:
torch.Tensor: Computed loss
"""
Expand All @@ -206,6 +208,7 @@ def forward(
bias=bias,
ignore_index=ignore_index,
beta=beta,
alpha=alpha,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
use_ref_model=use_ref_model,
Expand All @@ -222,7 +225,7 @@ def forward(
@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None, None, None, None, None, None, None, None, None
return *grads, None, None, None, None, None, None, None, None, None, None, None, None, None, None


class LigerFusedLinearDPOLoss(torch.nn.Module):
Expand All @@ -247,6 +250,7 @@ def __init__(
self,
ignore_index: int = -100,
beta: float = 0.1,
alpha: float = 1.0,
compute_nll_loss: bool = False,
compiled: bool = True,
use_ref_model: bool = True,
Expand All @@ -259,7 +263,8 @@ def __init__(
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
beta (float): Weight for the direct preference loss.
alpha (float): Weight for the NLL loss component.
compute_nll_loss (bool): Whether to compute the NLL loss.
compiled (bool): Whether to use the torch compiled kernel.
use_ref_model (bool): Whether to use a reference model for the DPO loss.
Expand All @@ -274,6 +279,7 @@ def __init__(
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.alpha = alpha
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.use_ref_model = use_ref_model
Expand Down Expand Up @@ -321,4 +327,5 @@ def forward(
self.loss_type,
self.label_smoothing,
self.discopop_tau,
self.alpha,
)
113 changes: 106 additions & 7 deletions test/chunked_loss/test_dpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,8 +869,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
ref_input,
ref_weight1,
ref_bias1,
-100,
0.1,
-100, # ignore_index
0.1, # beta
compute_nll_loss,
)
loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo(
Expand All @@ -881,8 +881,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
ref_input,
ref_weight2,
ref_bias2,
-100,
0.1,
-100, # ignore_index
0.1, # beta
compute_nll_loss,
)

Expand Down Expand Up @@ -1112,14 +1112,14 @@ def test_correctness_functional_apo_loss_types(
ref_input,
ref_weight1,
ref_bias1,
-100,
0.1,
-100, # ignore_index
0.1, # beta
compute_nll_loss,
True, # compiled
True, # use_ref_model
False, # average_log_prob
1, # chunk_size
loss_type, # loss_type
loss_type,
)

# For comparison, create a LigerFusedLinearDPOLoss with the loss_type
Expand Down Expand Up @@ -1315,6 +1315,105 @@ def test_invalid_loss_type():
assert loss_fn.loss_type == loss_type


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_alpha_scales_nll_loss(dtype):
"""
Verify that alpha is actually forwarded and scales the NLL component.
With compute_nll_loss=True, loss(alpha=2) should differ from loss(alpha=1).
"""
B, T, H, V = 4, 16, 32, 64
atol = 1e-4 if dtype == torch.float32 else 5e-2

_weight = torch.randn(V, H, device=device, dtype=dtype)
_ref_weight = torch.randn(V, H, device=device, dtype=dtype)
_input = torch.randn(B, T, H, device=device, dtype=dtype)
target = torch.randint(0, V, (B, T), device=device, dtype=torch.long)

def run(alpha):
inp = _input.detach().clone().requires_grad_(True)
w = _weight.detach().clone().requires_grad_(True)
rw = _ref_weight.detach().clone().requires_grad_(True)
loss_fn = LigerFusedLinearDPOLoss(
beta=0.1,
alpha=alpha,
compute_nll_loss=True,
use_ref_model=True,
average_log_prob=False,
)
loss, _ = loss_fn(w, inp, target, None, _input.detach(), rw, None)
return loss

loss_alpha1 = run(alpha=1.0)
loss_alpha2 = run(alpha=2.0)

assert not torch.allclose(loss_alpha1, loss_alpha2, atol=atol), (
f"Expected losses to differ when alpha changes, but got {loss_alpha1} vs {loss_alpha2}"
)


def test_functional_positional_arg_contract():
"""
Pin the positional-argument contract of the public functional alias.

`alpha` is appended at the *end* of `LigerFusedLinearDPOFunction.forward` (not
inserted mid-list) precisely so that existing positional `.apply()` /
`liger_fused_linear_dpo` callers don't silently shift every later argument by
one slot. This test exercises the pre-PR positional list (no `alpha`) and
asserts it produces the same result as the keyword-driven `nn.Module` wrapper.
If a future param insertion shifts the positional slots, this diverges.
"""
B, T, H, V = 4, 8, 16, 32
dtype = torch.float32

_input = torch.randn(B, T, H, device=device, dtype=dtype)
target = torch.randint(0, V, (B, T), device=device, dtype=torch.long)
_weight = torch.randn(V, H, device=device, dtype=dtype)
_ref_weight = torch.randn(V, H, device=device, dtype=dtype)
ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

# Pre-PR positional list: the args after `beta` are
# compute_nll_loss, compiled, use_ref_model, average_log_prob, chunk_size, loss_type.
loss_positional, _ = liger_fused_linear_dpo(
_input.detach().clone().requires_grad_(True),
_weight.detach().clone().requires_grad_(True),
target,
None, # bias
ref_input,
_ref_weight.detach().clone().requires_grad_(True),
None, # ref_bias
-100, # ignore_index
0.1, # beta
True, # compute_nll_loss
True, # compiled
True, # use_ref_model
False, # average_log_prob
1, # chunk_size
"sigmoid", # loss_type
)

loss_module, _ = LigerFusedLinearDPOLoss(
ignore_index=-100,
beta=0.1,
alpha=1.0,
compute_nll_loss=True,
compiled=True,
use_ref_model=True,
average_log_prob=False,
chunk_size=1,
loss_type="sigmoid",
)(
_weight.detach().clone().requires_grad_(True),
_input.detach().clone().requires_grad_(True),
target,
None, # bias
ref_input,
_ref_weight.detach().clone().requires_grad_(True),
None, # ref_bias
)

assert_verbose_allclose(loss_positional, loss_module, atol=1e-5, rtol=1e-4)


def test_label_smoothing_validation():
"""Test that invalid label_smoothing values raise ValueError for the relevant loss types."""
with pytest.raises(ValueError, match="label_smoothing must be > 0 for loss_type='exo_pair'"):
Expand Down
Loading