Skip to content

Commit 955e496

Browse files
MarginRankingCriterion:type()
1 parent cedc0b9 commit 955e496

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

MarginRankingCriterion.lua

+6
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,9 @@ function MarginRankingCriterion:updateGradInput(input, y)
6969
end
7070
return self.gradInput
7171
end
72+
73+
function MarginRankingCriterion:type(type)
74+
self.gradInput[1] = self.gradInput[1]:type(type)
75+
self.gradInput[2] = self.gradInput[2]:type(type)
76+
return parent.type(self, type)
77+
end

test.lua

+17
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,23 @@ function nntest.MultiMarginCriterion()
794794
criterionJacobianTest1D(cri, input, target)
795795
end
796796

797+
function nntest.MarginRankingCriterion()
798+
local input = {torch.rand(1), torch.rand(1)}
799+
local mrc = nn.MarginRankingCriterion()
800+
local output = mrc:forward(input, 1)
801+
local gradInput = mrc:backward(input, 1)
802+
-- cast to float
803+
local input2 = {input[1]:float(), input[2]:float()}
804+
local mrc2 = mrc:clone():float()
805+
local output2 = mrc2:forward(input2, 1)
806+
local gradInput2 = mrc2:backward(input2, 1)
807+
mytester:assert(math.abs(output2 - output) < 0.00001, "MRC:type() forward error")
808+
mytester:assertTensorEq(gradInput[1]:float(), gradInput2[1], 0.00001, "MRC:type() backward error 1")
809+
mytester:assert(torch.type(gradInput2[1]) == 'torch.FloatTensor', "MRC:type() error 1")
810+
mytester:assertTensorEq(gradInput[2]:float(), gradInput2[2], 0.00001, "MRC:type() backward error 2")
811+
mytester:assert(torch.type(gradInput2[2]) == 'torch.FloatTensor', "MRC:type() error 2")
812+
end
813+
797814
function nntest.WeightedMSECriterion()
798815
local input = torch.rand(10)
799816
local target = input:clone():add(torch.rand(10))

0 commit comments

Comments
 (0)