Skip to content

Commit cea3ab2

Browse files
fix missing gradWeight bug
1 parent 6577b55 commit cea3ab2

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

SpatialConvolution.lua

+3-3
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ local function backCompatibility(self)
4949
if self.weight:dim() == 2 then
5050
self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
5151
end
52-
if self.gradWeight:dim() == 2 then
52+
if self.gradWeight and self.gradWeight:dim() == 2 then
5353
self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
5454
end
5555
end
@@ -73,12 +73,12 @@ end
7373
-- function to re-view the weight layout in a way that would make the MM ops happy
7474
local function viewWeight(self)
7575
self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW)
76-
self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW)
76+
self.gradWeight = self.gradWeight and self.gradWeight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW)
7777
end
7878

7979
local function unviewWeight(self)
8080
self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
81-
self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
81+
self.gradWeight = self.gradWeight and self.gradWeight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
8282
end
8383

8484
function SpatialConvolution:updateOutput(input)

0 commit comments

Comments
 (0)