forked from vt-vl-lab/SDN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhloss.py
22 lines (19 loc) · 789 Bytes
/
hloss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
## code from https://discuss.pytorch.org/t/calculating-the-entropy-loss/14510
from torch import nn
import torch.nn.functional as F
class HLoss(nn.Module):
    """Entropy loss computed over the class dimension (dim=1) of logits.

    forward(x) computes the Shannon entropy H(p) of p = softmax(x, dim=1),
    summed over dim=1 (classes) and averaged over dim=0 (batch).

    Args:
        is_maximization (bool): when False (default), forward returns +H(p),
            so minimizing the loss minimizes entropy (sharpens predictions).
            When True, forward returns -H(p), so minimizing the loss
            maximizes entropy (flattens predictions).
    """

    def __init__(self, is_maximization=False):
        super().__init__()
        # The flag only flips the sign of the returned entropy.
        self.is_neg = is_maximization

    def forward(self, x):
        # Elementwise p * log(p); summing over dim=1 yields -H per sample.
        b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
        if self.is_neg:
            # Sum over classes, mean over the batch -> -H(p).
            return b.sum(dim=1).mean()
        # Negate so the result is the non-negative entropy +H(p).
        return -b.sum(dim=1).mean()