
Commit 682b9fb

update
1 parent 5bdccc8 commit 682b9fb

4 files changed, 44 insertions(+), 42 deletions(-)

mt_dnn/batcher.py (+1, -1)

@@ -257,7 +257,7 @@ def __init__(self,
                  max_seq_length=512,
                  max_predictions_per_seq=80,
                  printable=True):
-        data, tokenizer = self.load(path, is_train, maxlen, factor, task_def, bert_model, do_lower_case)
+        data, tokenizer = self.load(path, is_train, maxlen, factor, task_def, bert_model, do_lower_case, printable=printable)
         self._data = data
         self._tokenizer = tokenizer
         self._task_id = task_id

mt_dnn/loss.py (+41, -39)

@@ -4,6 +4,7 @@
 import torch
 from torch.nn.modules.loss import _Loss
 import torch.nn.functional as F
+import torch.nn as nn
 from enum import IntEnum

@@ -49,6 +50,7 @@ def forward(self, input, target, weight=None, ignore_index=-1):
         loss = loss * self.alpha
         return loss

+
 class SeqCeCriterion(CeCriterion):
     def __init__(self, alpha=1.0, name='Seq Cross Entropy Criterion'):
         super().__init__(alpha, name)

@@ -116,13 +118,13 @@ def __init__(self, alpha=1.0, name='KL Div Criterion'):
         self.alpha = alpha
         self.name = name

-    def forward(self, input, target, weight=None, ignore_index=-1):
+    def forward(self, input, target, weight=None, ignore_index=-1, reduction='batchmean'):
         """input/target: logits
         """
         input = input.float()
         target = target.float()
-        loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32), F.softmax(target.detach(), dim=-1, dtype=torch.float32), reduction='batchmean') + \
-            F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32), F.softmax(input.detach(), dim=-1, dtype=torch.float32), reduction='batchmean')
+        loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32), F.softmax(target.detach(), dim=-1, dtype=torch.float32), reduction=reduction) + \
+            F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32), F.softmax(input.detach(), dim=-1, dtype=torch.float32), reduction=reduction)
         loss = loss * self.alpha
         return loss

@@ -142,6 +144,41 @@ def forward(self, input, target, weight=None, ignore_index=-1):
         loss = loss * self.alpha
         return loss

+class JSCriterion(Criterion):
+    def __init__(self, alpha=1.0, name='JS Div Criterion'):
+        super().__init__()
+        self.alpha = alpha
+        self.name = name
+
+    def forward(self, input, target, weight=None, ignore_index=-1, reduction='batchmean'):
+        """input/target: logits
+        """
+        input = input.float()
+        target = target.float()
+        m = F.softmax(target.detach(), dim=-1, dtype=torch.float32) + \
+            F.softmax(input.detach(), dim=-1, dtype=torch.float32)
+        m = 0.5 * m
+        loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32), m, reduction=reduction) + \
+            F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32), m, reduction=reduction)
+        loss = loss * self.alpha
+        return loss
+
+
+class HLCriterion(Criterion):
+    def __init__(self, alpha=1.0, name='Hellinger Criterion'):
+        super().__init__()
+        self.alpha = alpha
+        self.name = name
+
+    def forward(self, input, target, weight=None, ignore_index=-1, reduction='batchmean'):
+        """input/target: logits
+        """
+        input = input.float()
+        target = target.float()
+        si = F.softmax(target.detach(), dim=-1, dtype=torch.float32).sqrt_()
+        st = F.softmax(input.detach(), dim=-1, dtype=torch.float32).sqrt_()
+        loss = F.mse_loss(si, st)
+        loss = loss * self.alpha
+        return loss


 class RankCeCriterion(Criterion):

@@ -202,42 +239,6 @@ def forward(self, input, target, weight=None, ignore_index=-1):
         loss = loss * self.alpha
         return loss

-class JSCriterion(Criterion):
-    def __init__(self, alpha=1.0, name='JS Div Criterion'):
-        super().__init__()
-        self.alpha = alpha
-        self.name = name
-
-    def forward(self, input, target, weight=None, ignore_index=-1, reduction='batchmean'):
-        """input/target: logits
-        """
-        input = input.float()
-        target = target.float()
-        m = F.softmax(target.detach(), dim=-1, dtype=torch.float32) + \
-            F.softmax(input.detach(), dim=-1, dtype=torch.float32)
-        m = 0.5 * m
-        loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32), m, reduction=reduction) + \
-            F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32), m, reduction=reduction)
-        loss = loss * self.alpha
-        return loss
-
-class HLCriterion(Criterion):
-    def __init__(self, alpha=1.0, name='Hellinger Criterion'):
-        super().__init__()
-        self.alpha = alpha
-        self.name = name
-
-    def forward(self, input, target, weight=None, ignore_index=-1, reduction='batchmean'):
-        """input/target: logits
-        """
-        input = input.float()
-        target = target.float()
-        si = F.softmax(target.detach(), dim=-1, dtype=torch.float32).sqrt_()
-        st = F.softmax(input.detach(), dim=-1, dtype=torch.float32).sqrt_()
-        loss = F.mse_loss(si, st)
-        loss = loss * self.alpha
-        return loss
-
 class LossCriterion(IntEnum):
     CeCriterion = 0
     MseCriterion = 1

@@ -252,6 +253,7 @@ class LossCriterion(IntEnum):
     JSCriterion = 10
     HLCriterion = 11

+
 LOSS_REGISTRY = {
     LossCriterion.CeCriterion: CeCriterion,
     LossCriterion.MseCriterion: MseCriterion,

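Note: JSCriterion and HLCriterion are moved, not newly written; the diff relocates them ahead of RankCeCriterion and threads a reduction argument through the KL-based criteria. A minimal standalone sketch of what the two relocated criteria compute (the tensors and shapes below are illustrative, not taken from the repo):

import torch
import torch.nn.functional as F

# Illustrative logits only: a batch of 4 examples over 3 classes (shapes made up).
input_logits = torch.randn(4, 3)
target_logits = torch.randn(4, 3)

# Jensen-Shannon-style term, mirroring JSCriterion.forward: each distribution is
# compared against the mixture m via KL divergence.
p = F.softmax(input_logits, dim=-1)
q = F.softmax(target_logits, dim=-1)
m = 0.5 * (p + q)
js = F.kl_div(F.log_softmax(input_logits, dim=-1), m, reduction='batchmean') + \
     F.kl_div(F.log_softmax(target_logits, dim=-1), m, reduction='batchmean')

# Hellinger-style term, mirroring HLCriterion.forward: MSE between the square
# roots of the two softmax distributions.
hl = F.mse_loss(p.sqrt(), q.sqrt())

print(js.item(), hl.item())
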
mt_dnn/perturbation.py (+1, -1)

@@ -88,7 +88,7 @@ def forward(self, model,
         if task_type == TaskType.Ranking:
             adv_logits = adv_logits.view(-1, pairwise)
         adv_loss = stable_kl(adv_logits, logits.detach(), reduce=False)
-        delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True)
+        delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True, retain_graph=False)
         norm = delta_grad.norm()
         if (torch.isnan(norm) or torch.isinf(norm)):
             return 0

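For context on the one-line change above: retain_graph=False is already the default for torch.autograd.grad, so spelling it out only documents that the graph built for adv_loss is freed after this single gradient computation. A toy illustration of the call pattern (not repo code; the tensors are made up):

import torch

# Made-up stand-ins for the noise tensor and the scalar adversarial loss.
noise = torch.zeros(2, 3, requires_grad=True)
adv_loss = ((noise + 1.0) ** 2).sum()

# Same call pattern as perturbation.py: gradient w.r.t. noise only, graph freed afterwards.
delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True, retain_graph=False)
print(delta_grad.norm())
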
train.py (+1, -1)

@@ -145,7 +145,7 @@ def train_config(parser):
     parser.add_argument('--adv_p_norm', default='inf', type=str)
     parser.add_argument('--adv_alpha', default=1, type=float)
     parser.add_argument('--adv_k', default=1, type=int)
-    parser.add_argument('--adv_step_size', default=1e-3, type=float)
+    parser.add_argument('--adv_step_size', default=1e-5, type=float)
     parser.add_argument('--adv_noise_var', default=1e-5, type=float)
     parser.add_argument('--adv_epsilon', default=1e-6, type=float)
     parser.add_argument('--encode_mode', action='store_true', help="only encode test data")

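The only behavioural change here is the default for --adv_step_size (1e-3 to 1e-5); runs that pass the flag explicitly are unaffected. A minimal argparse sketch of that distinction (not repo code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--adv_step_size', default=1e-5, type=float)

# Flag omitted: the new default applies.
print(parser.parse_args([]).adv_step_size)                            # 1e-05
# Flag given explicitly: the command-line value wins, as before.
print(parser.parse_args(['--adv_step_size', '1e-3']).adv_step_size)  # 0.001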