Skip to content

Commit

Permalink
Update legacy ClassNLLCriterion to add ignore_index.
Browse files Browse the repository at this point in the history
  • Loading branch information
gchanan authored and soumith committed Aug 1, 2017
1 parent 61c873c commit 9c1e9d8
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions torch/legacy/nn/ClassNLLCriterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

class ClassNLLCriterion(Criterion):

def __init__(self, weights=None, sizeAverage=True):
def __init__(self, weights=None, sizeAverage=True, ignore_index=-100):
super(ClassNLLCriterion, self).__init__()
self.sizeAverage = sizeAverage
self.ignore_index = ignore_index

if weights is not None:
assert weights.dim() == 1
Expand All @@ -25,7 +26,7 @@ def updateOutput(self, input, target):
self.sizeAverage,
self.weights,
self.total_weight_tensor,
-100
self.ignore_index
)
self.output = self.output_tensor[0]
return self.output
Expand All @@ -42,7 +43,7 @@ def updateGradInput(self, input, target):
self.sizeAverage,
self.weights,
self.total_weight_tensor,
-100
self.ignore_index
)

return self.gradInput

0 comments on commit 9c1e9d8

Please sign in to comment.