Skip to content

Commit c69257f

Browse files
LookupTable makeInputContiguous only once
1 parent 8fa3ee9 commit c69257f

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

LookupTable.lua

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
local LookupTable, parent = torch.class('nn.LookupTable', 'nn.Module')
22

3-
LookupTable.__version = 3
3+
LookupTable.__version = 4
44

55
function LookupTable:__init(nIndex, nOutput)
66
parent.__init(self)
@@ -33,9 +33,11 @@ end
3333
function LookupTable:makeInputContiguous(input)
3434
-- make sure input is a contiguous torch.LongTensor
3535
if (not input:isContiguous()) or torch.type(input) ~= torch.type(self._input) then
36+
self.copiedInput = true
3637
self._input:resize(input:size()):copy(input)
3738
return self._input
3839
end
40+
self.copiedInput = false
3941
return input
4042
end
4143

@@ -53,8 +55,7 @@ function LookupTable:updateOutput(input)
5355
end
5456

5557
function LookupTable:accGradParameters(input, gradOutput, scale)
56-
input = self:makeInputContiguous(input)
57-
self.gradWeight.nn.LookupTable_accGradParameters(self, input, gradOutput, scale)
58+
self.gradWeight.nn.LookupTable_accGradParameters(self, self.copiedInput and self._input or input, gradOutput, scale)
5859
end
5960

6061
function LookupTable:type(type)

0 commit comments

Comments
 (0)