@@ -38,7 +38,9 @@ local function retable(t1, t2, f)
38
38
end
39
39
40
40
function ConcatTable :updateGradInput (input , gradOutput )
41
- if self .table or torch .type (input ) == ' table' then
41
+ local isTable = torch .type (input ) == ' table'
42
+ local wasTable = torch .type (self .gradInput ) == ' table'
43
+ if isTable then
42
44
for i ,module in ipairs (self .modules ) do
43
45
local currentGradInput = module :updateGradInput (input , gradOutput [i ])
44
46
if torch .type (currentGradInput ) ~= ' table' then
@@ -48,21 +50,10 @@ function ConcatTable:updateGradInput(input, gradOutput)
48
50
error (" table size mismatch: " ..# input .. " ~= " ..# currentGradInput )
49
51
end
50
52
if i == 1 then
51
- if not self .table then
52
- -- gradInput is also a table
53
- self .gradInput = {}
54
- local cloneFunc = function (t , k ,v )
55
- t [k ] = v :clone ()
56
- end
57
- retable (self .gradInput , input ,
58
- function (t , k ,v )
59
- t [k ] = v :clone ()
60
- end
61
- )
62
- self .table = true
63
- end
53
+ self .gradInput = wasTable and self .gradInput or {}
64
54
retable (self .gradInput , currentGradInput ,
65
55
function (t , k , v )
56
+ t [k ] = t [k ] or v :clone ()
66
57
t [k ]:resizeAs (v )
67
58
t [k ]:copy (v )
68
59
end
@@ -76,6 +67,7 @@ function ConcatTable:updateGradInput(input, gradOutput)
76
67
end
77
68
end
78
69
else
70
+ self .gradInput = (not wasTable ) and self .gradInput or input :clone ()
79
71
for i ,module in ipairs (self .modules ) do
80
72
local currentGradInput = module :updateGradInput (input , gradOutput [i ])
81
73
if i == 1 then
0 commit comments