Skip to content

Commit 2e01032

Browse files
added optional last argument batchMode to nn.Reshape
1 parent 590a775 commit 2e01032

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

Reshape.lua

+8-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ function Reshape:__init(...)
66

77
self.size = torch.LongStorage()
88
self.batchsize = torch.LongStorage()
9+
if torch.type(arg[#arg]) == 'boolean' then
10+
self.batchMode = arg[#arg]
11+
table.remove(arg, #arg)
12+
end
913
local n = #arg
1014
if n == 1 and torch.typename(arg[1]) == 'torch.LongStorage' then
1115
self.size:resize(#arg[1]):copy(arg[1])
@@ -35,7 +39,10 @@ function Reshape:updateOutput(input)
3539
input = self._input
3640
end
3741

38-
if input:nElement() == self.nelement then
42+
if (self.batchMode == false) or (
43+
(self.batchMode == nil) and
44+
(input:nElement() == self.nelement and input:size(1) ~= 1)
45+
) then
3946
self.output:view(input, self.size)
4047
else
4148
self.batchsize[1] = input:size(1)

0 commit comments

Comments
 (0)