Thank you very much for your summary of loss functions in the field of NLP. I have a question about BinaryDSCLoss, and I sincerely hope you can take the time to answer my doubts.
This is your code:

```python
def forward(self, logits, targets):
    probs = torch.sigmoid(logits)
    probs = torch.gather(probs, dim=1, index=targets.unsqueeze(1))
    targets = targets.unsqueeze(dim=1)
    pos_mask = (targets == 1).float()
    neg_mask = (targets == 0).float()
    pos_weight = pos_mask * ((1 - probs) ** self.alpha) * probs
    pos_loss = 1 - (2 * pos_weight + self.smooth) / (pos_weight + 1 + self.smooth)
    neg_weight = neg_mask * ((1 - probs) ** self.alpha) * probs
    neg_loss = 1 - (2 * neg_weight + self.smooth) / (neg_weight + self.smooth)
    loss = pos_loss + neg_loss
    loss = loss.mean()
    return loss
```
From the code above, you compute the loss separately for positive and negative examples, using different denominators for the two cases. This split doesn't seem to appear in the original paper.
Is this your improvement?
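For comparison, here is my rough sketch of how I read the self-adjusting DSC loss in the original paper, where each example contributes a single term built from the positive-class probability and the 0/1 target. This is only my understanding, not your implementation; the function name and the assumption that `logits` has shape `(N, 2)` are mine, and `alpha`/`smooth` mirror the names in your code:

```python
import torch

def dsc_loss_paper_sketch(logits, targets, alpha=1.0, smooth=1.0):
    # logits: (N, 2) class scores; targets: (N,) with values in {0, 1}.
    # My reading of the paper's self-adjusting DSC, one term per example:
    #   DSC_i = 1 - (2 * (1 - p_i1)^alpha * p_i1 * y_i1 + smooth)
    #               / ((1 - p_i1)^alpha * p_i1 + y_i1 + smooth)
    probs = torch.softmax(logits, dim=1)[:, 1]  # p_i1: probability of the positive class
    y = targets.float()                         # y_i1: 1 for positive, 0 for negative
    w = ((1 - probs) ** alpha) * probs          # self-adjusting weight in place of p_i1
    loss = 1 - (2 * w * y + smooth) / (w + y + smooth)
    return loss.mean()
```

In this reading, the target `y` inside the numerator and denominator already distinguishes positives from negatives, so no explicit positive/negative masks are needed. That is why I am unsure where the separate `pos_loss`/`neg_loss` branches in your code come from.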