1
- local ConcatTable , parent = torch .class (' nn.ConcatTable' , ' nn.Module ' )
1
+ local ConcatTable , parent = torch .class (' nn.ConcatTable' , ' nn.Container ' )
2
2
3
3
function ConcatTable :__init ()
4
4
parent .__init (self )
5
5
self .modules = {}
6
6
self .output = {}
7
7
end
8
8
9
- function ConcatTable :add (module )
10
- table.insert (self .modules , module )
11
- return self
12
- end
13
-
14
- function ConcatTable :get (index )
15
- return self .modules [index ]
16
- end
17
-
18
- function ConcatTable :size ()
19
- return # self .modules
20
- end
21
-
22
9
function ConcatTable :updateOutput (input )
23
10
for i = 1 ,# self .modules do
24
11
self .output [i ] = self .modules [i ]:updateOutput (input )
@@ -99,52 +86,6 @@ function ConcatTable:zeroGradParameters()
99
86
end
100
87
end
101
88
102
- function ConcatTable :updateParameters (learningRate )
103
- for _ ,module in ipairs (self .modules ) do
104
- module :updateParameters (learningRate )
105
- end
106
- end
107
-
108
- function ConcatTable :training ()
109
- for i = 1 ,# self .modules do
110
- self .modules [i ]:training ()
111
- end
112
- end
113
-
114
- function ConcatTable :evaluate ()
115
- for i = 1 ,# self .modules do
116
- self .modules [i ]:evaluate ()
117
- end
118
- end
119
-
120
- function ConcatTable :share (mlp ,...)
121
- for i = 1 ,# self .modules do
122
- self .modules [i ]:share (mlp .modules [i ],... );
123
- end
124
- end
125
-
126
- function ConcatTable :parameters ()
127
- local function tinsert (to , from )
128
- if type (from ) == ' table' then
129
- for i = 1 ,# from do
130
- tinsert (to ,from [i ])
131
- end
132
- else
133
- table.insert (to ,from )
134
- end
135
- end
136
- local w = {}
137
- local gw = {}
138
- for i = 1 ,# self .modules do
139
- local mw ,mgw = self .modules [i ]:parameters ()
140
- if mw then
141
- tinsert (w ,mw )
142
- tinsert (gw ,mgw )
143
- end
144
- end
145
- return w ,gw
146
- end
147
-
148
89
function ConcatTable :type (type )
149
90
parent .type (self , type )
150
91
if torch .type (self .gradInput ) == ' table' then
@@ -161,7 +102,7 @@ function ConcatTable:__tostring__()
161
102
local ext = ' | '
162
103
local extlast = ' '
163
104
local last = ' ... -> '
164
- local str = ' nn.ConcatTable '
105
+ local str = torch . type ( self )
165
106
str = str .. ' {' .. line .. tab .. ' input'
166
107
for i = 1 ,# self .modules do
167
108
if i == self .modules then
0 commit comments