import torch
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
+import torch.nn as nn
from enum import IntEnum


def stable_kl(logit, target, epsilon=1e-6, reduce=True):
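The diff only shows stable_kl's signature; its body lies outside these hunks. For orientation, a minimal sketch of an epsilon-floored KL in the same spirit (the helper name stable_kl_sketch and its exact formula are illustrative assumptions, not the repo's implementation):

import torch
import torch.nn.functional as F

def stable_kl_sketch(logit, target, epsilon=1e-6, reduce=True):
    # Flatten to (N, C) and compare the two softmax distributions.
    p = F.log_softmax(logit.view(-1, logit.size(-1)).float(), dim=-1).exp()
    y = F.log_softmax(target.view(-1, target.size(-1)).float(), dim=-1).exp()
    # The epsilon floor keeps the logs finite when probabilities underflow to 0.
    kl = (p * ((p + epsilon).log() - (y + epsilon).log())).sum(-1)
    return kl.mean() if reduce else kl.sum()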
@@ -49,6 +50,7 @@ def forward(self, input, target, weight=None, ignore_index=-1):
        loss = loss * self.alpha
        return loss

+
class SeqCeCriterion(CeCriterion):
    def __init__(self, alpha=1.0, name='Seq Cross Entropy Criterion'):
        super().__init__(alpha, name)
@@ -116,13 +118,13 @@ def __init__(self, alpha=1.0, name='KL Div Criterion'):
        self.alpha = alpha
        self.name = name

-    def forward(self, input, target, weight=None, ignore_index=-1):
+    def forward(self, input, target, weight=None, ignore_index=-1, reduction='batchmean'):
        """input/target: logits
        """
        input = input.float()
        target = target.float()
-        loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32), F.softmax(target.detach(), dim=-1, dtype=torch.float32), reduction='batchmean') + \
-            F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32), F.softmax(input.detach(), dim=-1, dtype=torch.float32), reduction='batchmean')
+        loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32), F.softmax(target.detach(), dim=-1, dtype=torch.float32), reduction=reduction) + \
+            F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32), F.softmax(input.detach(), dim=-1, dtype=torch.float32), reduction=reduction)
        loss = loss * self.alpha
        return loss
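This hunk threads a reduction argument through KlCriterion.forward instead of hard-coding 'batchmean'. The symmetric KL form can be exercised standalone; a minimal sketch (tensor shapes and values are illustrative assumptions):

import torch
import torch.nn.functional as F

student = torch.randn(4, 3, requires_grad=True)  # hypothetical model logits
teacher = torch.randn(4, 3, requires_grad=True)  # hypothetical second branch

# Each direction detaches the side acting as the fixed reference distribution,
# so gradients reach both logits only through the log_softmax terms.
loss = F.kl_div(F.log_softmax(student, dim=-1), F.softmax(teacher.detach(), dim=-1), reduction='batchmean') + \
       F.kl_div(F.log_softmax(teacher, dim=-1), F.softmax(student.detach(), dim=-1), reduction='batchmean')
loss.backward()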
@@ -142,6 +144,41 @@ def forward(self, input, target, weight=None, ignore_index=-1):
        loss = loss * self.alpha
        return loss

+class JSCriterion(Criterion):
+    def __init__(self, alpha=1.0, name='JS Div Criterion'):
+        super().__init__()
+        self.alpha = alpha
+        self.name = name
+
+    def forward(self, input, target, weight=None, ignore_index=-1, reduction='batchmean'):
+        """input/target: logits
+        """
+        input = input.float()
+        target = target.float()
+        m = F.softmax(target.detach(), dim=-1, dtype=torch.float32) + \
+            F.softmax(input.detach(), dim=-1, dtype=torch.float32)
+        m = 0.5 * m
+        loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32), m, reduction=reduction) + \
+            F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32), m, reduction=reduction)
+        loss = loss * self.alpha
+        return loss
+
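JSCriterion (moved up from later in the file; the old copy is deleted in a hunk below) is the Jensen-Shannon divergence against the detached mixture m = (P + Q) / 2. Note m is passed to F.kl_div as probabilities, matching its (log-input, probability-target) convention. A quick sanity check (shapes and values are illustrative assumptions): identical logits should yield a loss of roughly zero.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5)
p = F.softmax(logits, dim=-1)
m = 0.5 * (p + p)  # mixture of two identical distributions is the distribution itself

js = F.kl_div(torch.log(p), m, reduction='batchmean') + \
     F.kl_div(torch.log(p), m, reduction='batchmean')
assert js.abs().item() < 1e-5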
+class HLCriterion(Criterion):
+    def __init__(self, alpha=1.0, name='Hellinger Criterion'):
+        super().__init__()
+        self.alpha = alpha
+        self.name = name
+
+    def forward(self, input, target, weight=None, ignore_index=-1, reduction='batchmean'):
+        """input/target: logits
+        """
+        input = input.float()
+        target = target.float()
+        si = F.softmax(target.detach(), dim=-1, dtype=torch.float32).sqrt_()
+        st = F.softmax(input, dim=-1, dtype=torch.float32).sqrt_()
+        loss = F.mse_loss(si, st)
+        loss = loss * self.alpha
+        return loss
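HLCriterion approximates the squared Hellinger distance, H^2(P, Q) = 1/2 * ||sqrt(P) - sqrt(Q)||^2. Only the target branch is detached, so the gradient flows through the input's softmax; the si/st names are swapped relative to input/target, which is harmless because the distance is symmetric. F.mse_loss averages squared differences over all elements rather than summing per row, so it matches H^2 only up to a constant; a small sketch of the relationship (shapes are illustrative assumptions):

import torch
import torch.nn.functional as F

p = F.softmax(torch.randn(4, 3), dim=-1)
q = F.softmax(torch.randn(4, 3), dim=-1)

sp, sq = p.sqrt(), q.sqrt()
hellinger_sq = 0.5 * (sp - sq).pow(2).sum(dim=-1)  # per-row squared Hellinger
mse = F.mse_loss(sp, sq)                           # mean over all 4*3 elements

# Scaling by num_classes / 2 recovers the mean per-row squared Hellinger.
assert torch.allclose(mse * 3 / 2, hellinger_sq.mean())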


class RankCeCriterion(Criterion):
@@ -202,42 +239,6 @@ def forward(self, input, target, weight=None, ignore_index=-1):
        loss = loss * self.alpha
        return loss

-class JSCriterion(Criterion):
-    def __init__(self, alpha=1.0, name='JS Div Criterion'):
-        super().__init__()
-        self.alpha = alpha
-        self.name = name
-
-    def forward(self, input, target, weight=None, ignore_index=-1, reduction='batchmean'):
-        """input/target: logits
-        """
-        input = input.float()
-        target = target.float()
-        m = F.softmax(target.detach(), dim=-1, dtype=torch.float32) + \
-            F.softmax(input.detach(), dim=-1, dtype=torch.float32)
-        m = 0.5 * m
-        loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32), m, reduction=reduction) + \
-            F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32), m, reduction=reduction)
-        loss = loss * self.alpha
-        return loss
-
-class HLCriterion(Criterion):
-    def __init__(self, alpha=1.0, name='Hellinger Criterion'):
-        super().__init__()
-        self.alpha = alpha
-        self.name = name
-
-    def forward(self, input, target, weight=None, ignore_index=-1, reduction='batchmean'):
-        """input/target: logits
-        """
-        input = input.float()
-        target = target.float()
-        si = F.softmax(target.detach(), dim=-1, dtype=torch.float32).sqrt_()
-        st = F.softmax(input, dim=-1, dtype=torch.float32).sqrt_()
-        loss = F.mse_loss(si, st)
-        loss = loss * self.alpha
-        return loss
-

class LossCriterion(IntEnum):
    CeCriterion = 0
    MseCriterion = 1
@@ -252,6 +253,7 @@ class LossCriterion(IntEnum):
    JSCriterion = 10
    HLCriterion = 11

+
LOSS_REGISTRY = {
    LossCriterion.CeCriterion: CeCriterion,
    LossCriterion.MseCriterion: MseCriterion,
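The registry (truncated above) pairs each LossCriterion member with its class, so a numeric config value resolves to a criterion instance. A hedged usage sketch, assuming this module's definitions are in scope and LOSS_REGISTRY covers every enum member:

import torch

criterion = LOSS_REGISTRY[LossCriterion.JSCriterion](alpha=1.0)
logits = torch.randn(8, 3)          # hypothetical model outputs
teacher_logits = torch.randn(8, 3)  # hypothetical reference outputs
loss = criterion.forward(logits, teacher_logits)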