Merge pull request torch#135 from nicholas-leonard/parallel
Parallel, Container & cie
soumith committed Jan 7, 2015
2 parents 81d2c42 + 5b19816 commit 675507d
Showing 7 changed files with 94 additions and 183 deletions.
2 changes: 1 addition & 1 deletion Concat.lua
@@ -1,7 +1,7 @@
 local Concat, parent = torch.class('nn.Concat', 'nn.Container')

 function Concat:__init(dimension)
-   parent.__init(self, dimension)
+   parent.__init(self)
    self.size = torch.LongStorage()
    self.dimension = dimension
 end
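For reference, a minimal usage sketch of `nn.Concat` (the layer sizes are illustrative, not from this commit):

```lua
require 'nn'

-- Two branches see the same 5-vector; their outputs are
-- concatenated along dimension 1.
local net = nn.Concat(1)
net:add(nn.Linear(5, 3))
net:add(nn.Linear(5, 7))

local output = net:forward(torch.randn(5))
print(output:size(1))  -- 10, i.e. 3 + 7
```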
63 changes: 2 additions & 61 deletions ConcatTable.lua
@@ -1,24 +1,11 @@
-local ConcatTable, parent = torch.class('nn.ConcatTable', 'nn.Module')
+local ConcatTable, parent = torch.class('nn.ConcatTable', 'nn.Container')

 function ConcatTable:__init()
    parent.__init(self)
    self.modules = {}
    self.output = {}
 end

-function ConcatTable:add(module)
-   table.insert(self.modules, module)
-   return self
-end
-
-function ConcatTable:get(index)
-   return self.modules[index]
-end
-
-function ConcatTable:size()
-   return #self.modules
-end
-
 function ConcatTable:updateOutput(input)
    for i=1,#self.modules do
       self.output[i] = self.modules[i]:updateOutput(input)
@@ -99,52 +86,6 @@ function ConcatTable:zeroGradParameters()
    end
 end

-function ConcatTable:updateParameters(learningRate)
-   for _,module in ipairs(self.modules) do
-      module:updateParameters(learningRate)
-   end
-end
-
-function ConcatTable:training()
-   for i=1,#self.modules do
-      self.modules[i]:training()
-   end
-end
-
-function ConcatTable:evaluate()
-   for i=1,#self.modules do
-      self.modules[i]:evaluate()
-   end
-end
-
-function ConcatTable:share(mlp,...)
-   for i=1,#self.modules do
-      self.modules[i]:share(mlp.modules[i],...);
-   end
-end
-
-function ConcatTable:parameters()
-   local function tinsert(to, from)
-      if type(from) == 'table' then
-         for i=1,#from do
-            tinsert(to,from[i])
-         end
-      else
-         table.insert(to,from)
-      end
-   end
-   local w = {}
-   local gw = {}
-   for i=1,#self.modules do
-      local mw,mgw = self.modules[i]:parameters()
-      if mw then
-         tinsert(w,mw)
-         tinsert(gw,mgw)
-      end
-   end
-   return w,gw
-end
-
 function ConcatTable:type(type)
    parent.type(self, type)
    if torch.type(self.gradInput) == 'table' then
@@ -161,7 +102,7 @@ function ConcatTable:__tostring__()
    local ext = '  |    '
    local extlast = '       '
    local last = '   ... -> '
-   local str = 'nn.ConcatTable'
+   local str = torch.type(self)
    str = str .. ' {' .. line .. tab .. 'input'
    for i=1,#self.modules do
       if i == #self.modules then
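The removed `add`, `get`, and `size` methods are now inherited from `nn.Container`, so calling code is unchanged. A minimal sketch (module choices are illustrative):

```lua
require 'nn'

-- nn.ConcatTable feeds the same input to every child module
-- and returns a table of their outputs.
local ct = nn.ConcatTable()
ct:add(nn.Linear(5, 2))        -- add() now comes from nn.Container
ct:add(nn.Tanh())

print(ct:size())               -- 2 (inherited from nn.Container)
print(torch.type(ct:get(1)))   -- nn.Linear

local outputs = ct:forward(torch.randn(5))
print(#outputs)                -- 2: a 2-vector and a 5-vector
```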
3 changes: 1 addition & 2 deletions Container.lua
@@ -1,7 +1,6 @@
 -- This is code common to container modules, which are collections of
 -- smaller constituent modules like Parallel, Sequential, etc.
-local Container, parent =
-   torch.class('nn.Container', 'nn.Module')
+local Container, parent = torch.class('nn.Container', 'nn.Module')

 function Container:__init(...)
    parent.__init(self, ...)
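As a sketch of the pattern this file enables, a custom container only needs to inherit from `nn.Container` and supply its own forward/backward logic; the class name and body below are hypothetical, and the backward methods are omitted for brevity:

```lua
-- Hypothetical container: applies its children in sequence.
local MyChain, parent = torch.class('nn.MyChain', 'nn.Container')

function MyChain:__init()
   parent.__init(self)  -- assumed to set up self.modules = {}
end

function MyChain:updateOutput(input)
   local currentOutput = input
   for i = 1, #self.modules do
      currentOutput = self.modules[i]:updateOutput(currentOutput)
   end
   self.output = currentOutput
   return self.output
end
```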
76 changes: 29 additions & 47 deletions Parallel.lua
@@ -9,28 +9,29 @@ function Parallel:__init(inputDimension,outputDimension)
 end

 function Parallel:updateOutput(input)
-
-   local modules=input:size(self.inputDimension)
+   local nModule=input:size(self.inputDimension)
+   local outputs = {}

-   for i=1,modules do
-      local currentOutput =
-         self.modules[i]:updateOutput(input:select(self.inputDimension,i))
+   for i=1,nModule do
+      local currentInput = input:select(self.inputDimension,i)
+      local currentOutput = self.modules[i]:updateOutput(currentInput)
+      table.insert(outputs, currentOutput)
+      local outputSize = currentOutput:size(self.outputDimension)

       if i == 1 then
          self.size:resize(currentOutput:dim()):copy(currentOutput:size())
       else
-         self.size[self.outputDimension] = self.size[self.outputDimension]
-            + currentOutput:size(self.outputDimension)
+         self.size[self.outputDimension] = self.size[self.outputDimension] + outputSize
       end

    end
    self.output:resize(self.size)

    local offset = 1
-   for i=1,modules do
-      local currentOutput = self.modules[i]:updateOutput(input:select(self.inputDimension,i))
-
-      self.output:narrow(self.outputDimension, offset,
-         currentOutput:size(self.outputDimension)):copy(currentOutput)
+   for i=1,nModule do
+      local currentOutput = outputs[i]
+      local outputSize = currentOutput:size(self.outputDimension)
+      self.output:narrow(self.outputDimension, offset, outputSize):copy(currentOutput)
       offset = offset + currentOutput:size(self.outputDimension)
    end
    return self.output
@@ -42,15 +43,16 @@ function Parallel:updateGradInput(input, gradOutput)

    local offset = 1
    for i=1,nModule do
-      local module=self.modules[i];
+      local module=self.modules[i]
+      local currentInput = input:select(self.inputDimension,i)
       local currentOutput = module.output
-      local currentGradInput =
-         module:updateGradInput(input:select(self.inputDimension,i),
-            gradOutput:narrow(self.outputDimension,
-               offset, currentOutput:size(self.outputDimension)))
+      local outputSize = currentOutput:size(self.outputDimension)
+      local currentGradOutput = gradOutput:narrow(self.outputDimension, offset, outputSize)
+
+      local currentGradInput = module:updateGradInput(currentInput, currentGradOutput)

       self.gradInput:select(self.inputDimension,i):copy(currentGradInput)
-      offset = offset + currentOutput:size(self.outputDimension)
+      offset = offset + outputSize
    end
    return self.gradInput
 end
@@ -60,16 +62,17 @@ function Parallel:accGradParameters(input, gradOutput, scale)

    local offset = 1
    for i=1,nModule do
-      local module = self.modules[i];
+      local module = self.modules[i]
       local currentOutput = module.output
+      local outputSize = currentOutput:size(self.outputDimension)

       module:accGradParameters(
          input:select(self.inputDimension,i),
-         gradOutput:narrow(
-            self.outputDimension, offset,
-            currentOutput:size(self.outputDimension)),
-         scale)
+         gradOutput:narrow(self.outputDimension, offset, outputSize),
+         scale
+      )

-      offset = offset + currentOutput:size(self.outputDimension)
+      offset = offset + outputSize
    end
 end

@@ -81,6 +84,7 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr)
       local module = self.modules[i];
       local currentOutput = module.output
       module:accUpdateGradParameters(
+
          input:select(self.inputDimension,i),
          gradOutput:narrow(self.outputDimension, offset,
             currentOutput:size(self.outputDimension)),
@@ -89,28 +93,6 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr)
       offset = offset + currentOutput:size(self.outputDimension)
    end
 end

-function Parallel:parameters()
-   local function tinsert(to, from)
-      if type(from) == 'table' then
-         for i=1,#from do
-            tinsert(to,from[i])
-         end
-      else
-         table.insert(to,from)
-      end
-   end
-   local w = {}
-   local gw = {}
-   for i=1,#self.modules do
-      local mw,mgw = self.modules[i]:parameters()
-      if mw then
-         tinsert(w,mw)
-         tinsert(gw,mgw)
-      end
-   end
-   return w,gw
-end
-
 function Parallel:__tostring__()
    local tab = '  '
@@ -119,7 +101,7 @@ function Parallel:__tostring__()
    local ext = '  |    '
    local extlast = '       '
    local last = '   ... -> '
-   local str = 'nn.Parallel'
+   local str = torch.type(self)
    str = str .. ' {' .. line .. tab .. 'input'
    for i=1,#self.modules do
       if i == #self.modules then
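To make the slicing concrete, a usage sketch of `nn.Parallel` (sizes are illustrative): child `i` sees slice `i` of the input along `inputDimension`, and the outputs are joined along `outputDimension`.

```lua
require 'nn'

-- Slice a 10x2 input along dimension 2; child i gets column i
-- (a 10-vector); outputs are concatenated along dimension 1.
local net = nn.Parallel(2, 1)
net:add(nn.Linear(10, 3))
net:add(nn.Linear(10, 2))

local output = net:forward(torch.randn(10, 2))
print(output:size(1))  -- 5, i.e. 3 + 2
```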
70 changes: 2 additions & 68 deletions ParallelTable.lua
@@ -1,4 +1,4 @@
-local ParallelTable, parent = torch.class('nn.ParallelTable', 'nn.Module')
+local ParallelTable, parent = torch.class('nn.ParallelTable', 'nn.Container')

 function ParallelTable:__init()
    parent.__init(self)
@@ -7,27 +7,13 @@ function ParallelTable:__init()
    self.gradInput = {}
 end

-function ParallelTable:add(module)
-   table.insert(self.modules, module)
-   return self
-end
-
-function ParallelTable:get(index)
-   return self.modules[index]
-end
-
-function ParallelTable:size()
-   return #self.modules
-end
-
 function ParallelTable:updateOutput(input)
    for i=1,#self.modules do
       self.output[i] = self.modules[i]:updateOutput(input[i])
    end
    return self.output
 end
-

 function ParallelTable:updateGradInput(input, gradOutput)
    for i,module in ipairs(self.modules) do
       self.gradInput[i]= module:updateGradInput(input[i], gradOutput[i])
@@ -49,66 +35,14 @@ function ParallelTable:accUpdateGradParameters(input, gradOutput, lr)
    end
 end

-function ParallelTable:zeroGradParameters()
-   for _,module in ipairs(self.modules) do
-      module:zeroGradParameters()
-   end
-end
-
-function ParallelTable:updateParameters(learningRate)
-   for _,module in ipairs(self.modules) do
-      module:updateParameters(learningRate)
-   end
-end
-
-function ParallelTable:training()
-   for i=1,#self.modules do
-      self.modules[i]:training()
-   end
-end
-
-function ParallelTable:evaluate()
-   for i=1,#self.modules do
-      self.modules[i]:evaluate()
-   end
-end
-
-function ParallelTable:share(mlp,...)
-   for i=1,#self.modules do
-      self.modules[i]:share(mlp.modules[i],...);
-   end
-end
-
-function ParallelTable:parameters()
-   local function tinsert(to, from)
-      if type(from) == 'table' then
-         for i=1,#from do
-            tinsert(to,from[i])
-         end
-      else
-         table.insert(to,from)
-      end
-   end
-   local w = {}
-   local gw = {}
-   for i=1,#self.modules do
-      local mw,mgw = self.modules[i]:parameters()
-      if mw then
-         tinsert(w,mw)
-         tinsert(gw,mgw)
-      end
-   end
-   return w,gw
-end
-
 function ParallelTable:__tostring__()
    local tab = '  '
    local line = '\n'
    local next = '  |`-> '
    local ext = '  |    '
    local extlast = '       '
    local last = '   ... -> '
-   local str = 'nn.ParallelTable'
+   local str = torch.type(self)
    str = str .. ' {' .. line .. tab .. 'input'
    for i=1,#self.modules do
       if i == #self.modules then
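For comparison with `nn.Parallel`, `nn.ParallelTable` performs the same per-child dispatch over a table of tensors instead of slices of one tensor. A minimal sketch (sizes illustrative):

```lua
require 'nn'

-- Child i is applied to the i-th entry of the input table.
local pt = nn.ParallelTable()
pt:add(nn.Linear(10, 2))
pt:add(nn.Linear(5, 3))

local outputs = pt:forward({torch.randn(10), torch.randn(5)})
print(outputs[1]:size(1), outputs[2]:size(1))  -- 2  3
```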
29 changes: 25 additions & 4 deletions doc/containers.md
@@ -1,13 +1,34 @@
 <a name="nn.Containers"/>
 # Containers #
 Complex neural networks are easily built using container classes:
-  * [Sequential](#nn.Sequential) : plugs layers in a feed-forward fully connected manner ;
-  * [Parallel](#nn.Parallel) : applies its `ith` child module to the `ith` slice of the input Tensor ;
-  * [Concat](#nn.Concat) : concatenates in one layer several modules along dimension `dim` ;
-  * [DepthConcat](#nn.DepthConcat) : like Concat, but adds zero-padding when non-`dim` sizes don't match;
+  * [Container](#nn.Container) : abstract class inherited by containers ;
+  * [Sequential](#nn.Sequential) : plugs layers in a feed-forward fully connected manner ;
+  * [Parallel](#nn.Parallel) : applies its `ith` child module to the `ith` slice of the input Tensor ;
+  * [Concat](#nn.Concat) : concatenates in one layer several modules along dimension `dim` ;
+  * [DepthConcat](#nn.DepthConcat) : like Concat, but adds zero-padding when non-`dim` sizes don't match;

 See also the [Table Containers](#nn.TableContainers) for manipulating tables of [Tensors](https://github.com/torch/torch7/blob/master/doc/tensor.md).

+<a name="nn.Container"/>
+## Container ##
+
+This is an abstract [Module](module.md#nn.Module) class which declares methods defined in all containers.
+It reimplements many of the Module methods such that calls are propagated to the
+contained modules. For example, a call to [zeroGradParameters](module.md#nn.Module.zeroGradParameters)
+will be propagated to all contained modules.
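As an illustration of the propagation pattern (it mirrors the per-container implementations removed above; the actual `nn.Container` code may differ in detail):

```lua
-- Sketch: a container forwards the call to every child module.
function Container:zeroGradParameters()
   for _, module in ipairs(self.modules) do
      module:zeroGradParameters()
   end
end
```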

+<a name="nn.Container.add"/>
+### add(module) ###
+Adds the given `module` to the container. The order is important: modules are indexed in the order they are added.
+
+<a name="nn.Container.get"/>
+### get(index) ###
+Returns the contained module at index `index`.
+
+<a name="nn.Container.size"/>
+### size() ###
+Returns the number of contained modules.
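Taken together, a brief usage sketch of these accessors (using `nn.Sequential`, which inherits them; layer choices are illustrative):

```lua
local net = nn.Sequential()
net:add(nn.Linear(10, 5))
net:add(nn.Tanh())

print(net:size())              -- 2
print(torch.type(net:get(2)))  -- nn.Tanh
net:zeroGradParameters()       -- propagated to both children
```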

 <a name="nn.Sequential"/>
 ## Sequential ##
