-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathutils.py
More file actions
42 lines (30 loc) · 1.32 KB
/
utils.py
File metadata and controls
42 lines (30 loc) · 1.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn.functional as F
def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
    """
    Filter a distribution of logits using nucleus (top-p) filtering
    https://github.com/OpenLMLab/MOSS/blob/e088f438d1a95d424c6dffef0d73134ebe62cb72/models_jittor/generation.py#L146
    """
    # NOTE(review): despite the name, the renormalization below assumes the
    # input rows are probabilities (non-negative, summing to ~1) — confirm.
    filtered = logits.clone()
    if topp > 0:
        sorted_vals, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        # Probability mass accumulated strictly *before* each sorted position;
        # positions where that mass already reaches `topp` get dropped.
        prior_mass = sorted_vals.cumsum(dim=-1) - sorted_vals
        drop_sorted = prior_mass >= topp
        # Always keep at least `min_topk` of the highest-probability entries.
        drop_sorted[:, :min_topk] = False
        # Scatter the sorted-order mask back to the original token positions.
        drop = torch.zeros_like(drop_sorted, dtype=torch.bool)
        drop.scatter_(dim=-1, index=sorted_idx, src=drop_sorted)
        filtered[drop] = filter_value
        # Renormalize the surviving mass to sum to one per row.
        filtered.div_(filtered.sum(dim=-1, keepdim=True))
    return filtered
def logprobs_from_logits(logits, labels):
    """
    Return the log-probability of each label under the given logits.

    See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Pick out the log-prob at each label index along the last dimension.
    picked = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
    return picked.squeeze(-1)
def denoise_logits(logits1, logits2):
row_sums = logits1.sum(dim=1)
noise_indices = (row_sums == 0.0).nonzero(as_tuple=True)[0]
new_logits = logits1.clone()
new_logits[noise_indices] = logits2[noise_indices]
return new_logits