Skip to content

Commit 4baec04

Browse files
ConcatTable can vary on any account
1 parent 8a57005 commit 4baec04

File tree

1 file changed

+6
-14
lines changed

1 file changed

+6
-14
lines changed

ConcatTable.lua

+6-14
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ local function retable(t1, t2, f)
3838
end
3939

4040
function ConcatTable:updateGradInput(input, gradOutput)
41-
if self.table or torch.type(input) == 'table' then
41+
local isTable = torch.type(input) == 'table'
42+
local wasTable = torch.type(self.gradInput) == 'table'
43+
if isTable then
4244
for i,module in ipairs(self.modules) do
4345
local currentGradInput = module:updateGradInput(input, gradOutput[i])
4446
if torch.type(currentGradInput) ~= 'table' then
@@ -48,21 +50,10 @@ function ConcatTable:updateGradInput(input, gradOutput)
4850
error("table size mismatch: "..#input.." ~= "..#currentGradInput)
4951
end
5052
if i == 1 then
51-
if not self.table then
52-
-- gradInput is also a table
53-
self.gradInput = {}
54-
local cloneFunc = function(t, k ,v)
55-
t[k] = v:clone()
56-
end
57-
retable(self.gradInput, input,
58-
function(t, k ,v)
59-
t[k] = v:clone()
60-
end
61-
)
62-
self.table = true
63-
end
53+
self.gradInput = wasTable and self.gradInput or {}
6454
retable(self.gradInput, currentGradInput,
6555
function(t, k, v)
56+
t[k] = t[k] or v:clone()
6657
t[k]:resizeAs(v)
6758
t[k]:copy(v)
6859
end
@@ -76,6 +67,7 @@ function ConcatTable:updateGradInput(input, gradOutput)
7667
end
7768
end
7869
else
70+
self.gradInput = (not wasTable) and self.gradInput or input:clone()
7971
for i,module in ipairs(self.modules) do
8072
local currentGradInput = module:updateGradInput(input, gradOutput[i])
8173
if i == 1 then

0 commit comments

Comments
 (0)