Skip to content

Commit 3c4f8c5

Browse files
nn.Container inheritance and torch.type(self) tostrings
1 parent 3dcecf7 commit 3c4f8c5

File tree

4 files changed

+5
-65
lines changed

4 files changed

+5
-65
lines changed

Concat.lua

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
local Concat, parent = torch.class('nn.Concat', 'nn.Container')
22

33
function Concat:__init(dimension)
4-
parent.__init(self, dimension)
4+
parent.__init(self)
55
self.size = torch.LongStorage()
66
self.dimension = dimension
77
end

ConcatTable.lua

+2-61
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,11 @@
1-
local ConcatTable, parent = torch.class('nn.ConcatTable', 'nn.Module')
1+
local ConcatTable, parent = torch.class('nn.ConcatTable', 'nn.Container')
22

33
function ConcatTable:__init()
44
parent.__init(self)
55
self.modules = {}
66
self.output = {}
77
end
88

9-
function ConcatTable:add(module)
10-
table.insert(self.modules, module)
11-
return self
12-
end
13-
14-
function ConcatTable:get(index)
15-
return self.modules[index]
16-
end
17-
18-
function ConcatTable:size()
19-
return #self.modules
20-
end
21-
229
function ConcatTable:updateOutput(input)
2310
for i=1,#self.modules do
2411
self.output[i] = self.modules[i]:updateOutput(input)
@@ -99,52 +86,6 @@ function ConcatTable:zeroGradParameters()
9986
end
10087
end
10188

102-
function ConcatTable:updateParameters(learningRate)
103-
for _,module in ipairs(self.modules) do
104-
module:updateParameters(learningRate)
105-
end
106-
end
107-
108-
function ConcatTable:training()
109-
for i=1,#self.modules do
110-
self.modules[i]:training()
111-
end
112-
end
113-
114-
function ConcatTable:evaluate()
115-
for i=1,#self.modules do
116-
self.modules[i]:evaluate()
117-
end
118-
end
119-
120-
function ConcatTable:share(mlp,...)
121-
for i=1,#self.modules do
122-
self.modules[i]:share(mlp.modules[i],...);
123-
end
124-
end
125-
126-
function ConcatTable:parameters()
127-
local function tinsert(to, from)
128-
if type(from) == 'table' then
129-
for i=1,#from do
130-
tinsert(to,from[i])
131-
end
132-
else
133-
table.insert(to,from)
134-
end
135-
end
136-
local w = {}
137-
local gw = {}
138-
for i=1,#self.modules do
139-
local mw,mgw = self.modules[i]:parameters()
140-
if mw then
141-
tinsert(w,mw)
142-
tinsert(gw,mgw)
143-
end
144-
end
145-
return w,gw
146-
end
147-
14889
function ConcatTable:type(type)
14990
parent.type(self, type)
15091
if torch.type(self.gradInput) == 'table' then
@@ -161,7 +102,7 @@ function ConcatTable:__tostring__()
161102
local ext = ' | '
162103
local extlast = ' '
163104
local last = ' ... -> '
164-
local str = 'nn.ConcatTable'
105+
local str = torch.type(self)
165106
str = str .. ' {' .. line .. tab .. 'input'
166107
for i=1,#self.modules do
167108
if i == self.modules then

Container.lua

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
-- This is code common to container modules, which are collections of
22
-- smaller constituent modules like Parallel, Sequential, etc.
3-
local Container, parent =
4-
torch.class('nn.Container', 'nn.Module')
3+
local Container, parent = torch.class('nn.Container', 'nn.Module')
54

65
function Container:__init(...)
76
parent.__init(self, ...)

Parallel.lua

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ function Parallel:__tostring__()
101101
local ext = ' | '
102102
local extlast = ' '
103103
local last = ' ... -> '
104-
local str = 'nn.Parallel'
104+
local str = torch.type(self)
105105
str = str .. ' {' .. line .. tab .. 'input'
106106
for i=1,#self.modules do
107107
if i == self.modules then

0 commit comments

Comments
 (0)