#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE.txt file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import nn as nn
from torch.nn import functional as F

from torchbiggraph.model import match_shape
from torchbiggraph.plugin import PluginRegistry
from torchbiggraph.types import FloatTensorType


class AbstractLossFunction(nn.Module, ABC):
    """Calculate weighted loss of scores for positive and negative pairs.

    The inputs are a 1-D tensor of size P containing scores for positive pairs
    of entities (i.e., those among which an edge exists) and a P x N tensor
    containing scores for negative pairs (i.e., where no edge should exist). The
    pairs of entities corresponding to pos_scores[i] and to neg_scores[i, j]
    have at least one endpoint in common. The output is the loss value these
    scores induce. If the method supports weighting (as is the case for the
    logistic loss), all positive scores are weighted by the same weight, and so
    are all the negative ones. A usage sketch for the concrete subclasses can
    be found at the bottom of this file.
    """

    def __init__(self, **kwargs):
        # Loss functions will by default ignore any kwargs, but can ask for any
        # specific kwargs of interest in their constructor.
        # FIXME: This is not ideal. Perhaps we should pass in the config
        # or a subconfig instead?
        super().__init__()

    @abstractmethod
    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
        weight: Optional[FloatTensorType],
    ) -> FloatTensorType:
        pass


LOSS_FUNCTIONS = PluginRegistry[AbstractLossFunction]()


@LOSS_FUNCTIONS.register_as("logistic")
class LogisticLossFunction(AbstractLossFunction):
    """Binary cross-entropy on the raw scores.

    Positive scores are pushed towards label 1 and negative scores towards
    label 0. The summed negative term is divided by the number of negatives
    per positive so that it is on the same scale as the positive term.
    """

    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
        weight: Optional[FloatTensorType],
    ) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)
        neg_weight = 1 / num_neg if num_neg > 0 else 0

        if weight is not None:
            match_shape(weight, num_pos)
        pos_loss = F.binary_cross_entropy_with_logits(
            pos_scores,
            pos_scores.new_ones(()).expand(num_pos),
            reduction="sum",
            weight=weight,
        )
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_scores,
            neg_scores.new_zeros(()).expand(num_pos, num_neg),
            reduction="sum",
            weight=weight.unsqueeze(-1) if weight is not None else None,
        )

        loss = pos_loss + neg_weight * neg_loss

        return loss


@LOSS_FUNCTIONS.register_as("ranking")
class RankingLossFunction(AbstractLossFunction):
    """Margin ranking loss: a hinge on the score difference.

    Each positive is compared against each of its negatives and contributes
    max(0, margin - pos_score + neg_score) to the total loss, i.e. a positive
    only incurs loss if it fails to beat a negative by at least the margin.
    """

    def __init__(self, *, margin, **kwargs):
        super().__init__()
        self.margin = margin

    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
        weight: Optional[FloatTensorType],
    ) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)

        # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15223.
        if num_pos == 0 or num_neg == 0:
            return torch.zeros((), device=pos_scores.device, requires_grad=True)

        if weight is not None:
            match_shape(weight, num_pos)
            loss_per_sample = F.margin_ranking_loss(
                neg_scores,
                pos_scores.unsqueeze(1),
                target=pos_scores.new_full((1, 1), -1, dtype=torch.float),
                margin=self.margin,
                reduction="none",
            )
            loss = (loss_per_sample * weight.unsqueeze(-1)).sum()
        else:
            # More memory-efficient path when there are no weights.
            loss = F.margin_ranking_loss(
                neg_scores,
                pos_scores.unsqueeze(1),
                target=pos_scores.new_full((1, 1), -1, dtype=torch.float),
                margin=self.margin,
                reduction="sum",
            )

        return loss


@LOSS_FUNCTIONS.register_as("softmax")
class SoftmaxLossFunction(AbstractLossFunction):
    """Cross-entropy over each positive and its negatives.

    For every positive pair, its score is grouped with the scores of its N
    negatives, and the loss is the negative log-probability of picking the
    positive under a softmax over those N + 1 scores. The negatives are folded
    into a single logit via logsumexp, which is equivalent and cheaper.
    """

    def forward(
        self,
        pos_scores: FloatTensorType,
        neg_scores: FloatTensorType,
        weight: Optional[FloatTensorType],
    ) -> FloatTensorType:
        num_pos = match_shape(pos_scores, -1)
        num_neg = match_shape(neg_scores, num_pos, -1)

        # FIXME Workaround for https://github.com/pytorch/pytorch/issues/15870
        # and https://github.com/pytorch/pytorch/issues/15223.
        if num_pos == 0 or num_neg == 0:
            return torch.zeros((), device=pos_scores.device, requires_grad=True)

        scores = torch.cat(
            [pos_scores.unsqueeze(1), neg_scores.logsumexp(dim=1, keepdim=True)], dim=1
        )
        if weight is not None:
            loss_per_sample = F.cross_entropy(
                scores,
                pos_scores.new_zeros((num_pos,), dtype=torch.long),
                reduction="none",
            )
            match_shape(weight, num_pos)
            loss_per_sample = loss_per_sample * weight
        else:
            loss_per_sample = F.cross_entropy(
                scores,
                pos_scores.new_zeros((num_pos,), dtype=torch.long),
                reduction="sum",
            )

        return loss_per_sample.sum()
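

# Minimal usage sketch: exercises each registered loss on synthetic scores,
# with and without per-positive weights. The batch sizes, margin, seed, and
# printed output below are arbitrary illustrative choices, not values taken
# from the library.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_pos, num_neg = 4, 10  # hypothetical batch: 4 positives, 10 negatives each
    pos_scores = torch.randn(num_pos)
    neg_scores = torch.randn(num_pos, num_neg)
    weight = torch.rand(num_pos)  # optional per-positive weights

    for name, loss_fn in [
        ("logistic", LogisticLossFunction()),
        ("ranking", RankingLossFunction(margin=0.1)),
        ("softmax", SoftmaxLossFunction()),
    ]:
        unweighted = loss_fn(pos_scores, neg_scores, None)
        weighted = loss_fn(pos_scores, neg_scores, weight)
        print(f"{name}: unweighted={unweighted.item():.4f} weighted={weighted.item():.4f}")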