Skip to content

Commit

Permalink
fix inconsistency in loss function and training
Browse files Browse the repository at this point in the history
  • Loading branch information
coderbyr committed Aug 14, 2019
1 parent 3b33b0a commit 7a1b216
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
12 changes: 8 additions & 4 deletions model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ def forward(self, logits, target,
use_hierar=False,
is_multi=False,
*argvs):
device = logits.device
if use_hierar:
assert self.loss_type in [LossType.BCE_WITH_LOGITS,
LossType.SIGMOID_FOCAL_CROSS_ENTROPY]
device = logits.device
if not is_multi:
target = torch.eye(self.label_size)[target].to(device)
hierar_penalty, hierar_paras, hierar_relations = argvs[0:3]
Expand All @@ -132,9 +132,13 @@ def forward(self, logits, target,
hierar_relations,
device)
else:
if not is_multi:
device = logits.device
target = torch.eye(self.label_size)[target].to(device)
if is_multi:
assert self.loss_type in [LossType.BCE_WITH_LOGITS,
LossType.SIGMOID_FOCAL_CROSS_ENTROPY]
else:
if self.loss_type not in [LossType.SOFTMAX_CROSS_ENTROPY,
LossType.SOFTMAX_FOCAL_CROSS_ENTROPY]
target = torch.eye(self.label_size)[target].to(device)
return self.criterion(logits, target)

def cal_recursive_regularize(self, paras, hierar_relations, device="cpu"):
Expand Down
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def run(self, data_loader, model, optimizer, stage,
else: # flat classification
loss = self.loss_fn(
logits,
batch[ClassificationDataset.DOC_LABEL].to(self.conf.device))
batch[ClassificationDataset.DOC_LABEL].to(self.conf.device),
False,
is_multi)
if mode == ModeType.TRAIN:
optimizer.zero_grad()
loss.backward()
Expand Down

0 comments on commit 7a1b216

Please sign in to comment.