
Commit 59e3f8c

NarrowTable
1 parent 514a093 commit 59e3f8c

File tree

5 files changed, +136 -0 lines

NarrowTable.lua

+41
@@ -0,0 +1,41 @@
local NarrowTable, parent = torch.class('nn.NarrowTable', 'nn.Module')

function NarrowTable:__init(offset, length)
   parent.__init(self)
   self.offset = offset
   self.length = length or 1
   if not offset then
      error('nn.NarrowTable(offset, length)')
   end

   self.output = {}
   self.gradInput = {}
end

function NarrowTable:updateOutput(input)
   -- clear any stale entries, then copy the requested slice of the input table
   for k,v in ipairs(self.output) do self.output[k] = nil end
   for i=1,self.length do
      self.output[i] = input[self.offset+i-1]
   end
   return self.output
end

function NarrowTable:updateGradInput(input, gradOutput)
   -- pass gradients through for the selected elements
   for i=1,#gradOutput do
      self.gradInput[self.offset+i-1] = gradOutput[i]
   end
   -- elements outside the narrowed range get zero gradients of matching size
   for i=1,#input do
      if (i < self.offset) or (i >= self.offset + self.length) then
         self.gradInput[i] = nn.utils.recursiveResizeAs(self.gradInput[i], input[i])
         nn.utils.recursiveFill(self.gradInput[i], 0)
      end
   end
   -- drop entries left over from a previous, longer input table
   for i=#input+1,#self.gradInput do self.gradInput[i] = nil end
   return self.gradInput
end

function NarrowTable:type(type, tensorCache)
   -- reset table state so output/gradInput are rebuilt with the new tensor type
   self.output = {}
   self.gradInput = {}
   return parent.type(self, type, tensorCache)
end
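The `updateGradInput` above returns zero tensors, shaped like the corresponding inputs, for every element outside the narrowed range. A quick sanity check of that behaviour (a minimal sketch; the table sizes and variable names are illustrative):

```lua
require 'nn'

-- Illustrative check: element 1 lies outside the slice selected by NarrowTable(2, 2)
local m = nn.NarrowTable(2, 2)                      -- keep elements 2 and 3
local input = {torch.randn(4), torch.randn(5), torch.randn(6)}
local output = m:forward(input)                     -- {input[2], input[3]}
local gradOutput = {torch.randn(5), torch.randn(6)}
local gradInput = m:backward(input, gradOutput)

print(gradInput[1]:sum())  -- 0: element 1 gets a zeroed gradient of size 4
print(gradInput[2])        -- identical to gradOutput[1]
```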

doc/table.md

+41
@@ -11,6 +11,7 @@ This allows one to build very rich architectures:
  * [`JoinTable`](#nn.JoinTable): joins a `table` of `Tensor`s into a `Tensor`;
  * [`MixtureTable`](#nn.MixtureTable): mixture of experts weighted by a gater;
  * [`SelectTable`](#nn.SelectTable): select one element from a `table`;
  * [`NarrowTable`](#nn.NarrowTable): select a slice of elements from a `table`;
  * [`FlattenTable`](#nn.FlattenTable): flattens a nested `table` hierarchy;
* Pair Modules compute a measure like distance or similarity from a pair (`table`) of input `Tensor`s:
  * [`PairwiseDistance`](#nn.PairwiseDistance): outputs the `p`-norm distance between inputs;
@@ -724,6 +725,46 @@ Example 2:

```

<a name="nn.NarrowTable"/>
## NarrowTable ##

`module` = `NarrowTable(offset [, length])`

Creates a module that takes a `table` as input and outputs the subtable
starting at index `offset` and having `length` elements (default: 1 element).
The elements can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).

The gradients of the elements not included in the subtable are zeroed `Tensor`s of the same size.
This holds regardless of the depth of the encapsulated `Tensor`s, as the function used internally to zero them is recursive.

Example:
```lua
> input = {torch.randn(2, 3), torch.randn(2, 1), torch.randn(1, 2)}
> =nn.NarrowTable(2,2):forward(input)
{
  1 : DoubleTensor - size: 2x1
  2 : DoubleTensor - size: 1x2
}

> =nn.NarrowTable(1):forward(input)
{
  1 : DoubleTensor - size: 2x3
}

> =table.unpack(nn.NarrowTable(1,2):backward(input, {torch.randn(2, 3), torch.randn(2, 1)}))
 1.9528 -0.1381  0.2023
 0.2297 -1.5169 -1.1871
[torch.DoubleTensor of size 2x3]

-1.2023
-0.4165
[torch.DoubleTensor of size 2x1]

 0 0
[torch.DoubleTensor of size 1x2]

```

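Because `NarrowTable` operates on tables rather than tensors, it composes naturally with `SplitTable` and `JoinTable`. The following is a minimal sketch of that pattern, mirroring the new test in `test.lua`; the sizes are arbitrary:

```lua
require 'nn'

-- Illustrative pipeline: narrow a 3x10x4 batch along its second dimension
-- by splitting it into a table, narrowing the table, and joining it back.
local seq = nn.Sequential()
seq:add(nn.SplitTable(1, 2))    -- with nInputDims=2, splits the size-10 dimension into 10 tensors of 3x4
seq:add(nn.NarrowTable(5, 3))   -- keep slices 5, 6 and 7
seq:add(nn.JoinTable(1, 1))     -- join them back, giving a 3x12 tensor

local x = torch.randn(3, 10, 4)
print(seq:forward(x):size())    -- 3x12: the same data nn.Narrow(2, 5, 3) selects, flattened over the last two dimensions
```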
<a name="nn.FlattenTable"/>
## FlattenTable ##

init.lua

+1
@@ -102,6 +102,7 @@ include('SelectTable.lua')
include('MixtureTable.lua')
include('CriterionTable.lua')
include('FlattenTable.lua')
include('NarrowTable.lua')
include('Identity.lua')

include('Criterion.lua')

test.lua

+39
@@ -3116,6 +3116,45 @@ function nntest.MixtureTable()
end
end

function nntest.NarrowTable()
   local input = torch.randn(3,10,4)
   local gradOutput = torch.randn(3,3,4)
   local nt = nn.NarrowTable(5,3)
   -- NarrowTable over the slices produced by SplitTable should match
   -- nn.Narrow applied directly to the tensor along the same dimension
   local seq = nn.Sequential()
   seq:add(nn.SplitTable(1,2))
   seq:add(nt)
   seq:add(nn.JoinTable(1,1))
   seq:add(nn.Reshape(3,3,4))
   local seq2 = nn.Narrow(2,5,3)
   local output = seq:forward(input)
   local gradInput = seq:backward(input, gradOutput)
   local output2 = seq2:forward(input)
   local gradInput2 = seq2:backward(input, gradOutput)
   mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output err")
   mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput err")

   -- now try it with a smaller input
   local input = input:narrow(2, 1, 8)
   local output = seq:forward(input)
   local gradInput = seq:backward(input, gradOutput)
   local output2 = seq2:forward(input)
   local gradInput2 = seq2:backward(input, gradOutput)
   mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable small output err")
   mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable small gradInput err")

   -- test type-cast
   local input = input:float()
   local gradOutput = gradOutput:float()
   seq:float()
   seq2:float()
   local output = seq:forward(input)
   local gradInput = seq:backward(input, gradOutput)
   local output2 = seq2:forward(input)
   local gradInput2 = seq2:backward(input, gradOutput)
   mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output float err")
   mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput float err")
end

function nntest.View()
   local input = torch.rand(10)
   local template = torch.rand(5,2)

utils.lua

+14
@@ -30,5 +30,19 @@ function nn.utils.recursiveResizeAs(t1,t2)
   return t1, t2
end

-- recursively fills every tensor in t2 (a tensor or nested table of tensors) with val
function nn.utils.recursiveFill(t2, val)
   if torch.type(t2) == 'table' then
      for key,_ in pairs(t2) do
         t2[key] = nn.utils.recursiveFill(t2[key], val)
      end
   elseif torch.isTensor(t2) then
      t2:fill(val)
   else
      error("expecting tensor or table thereof. Got "
           ..torch.type(t2).." instead")
   end
   return t2
end


table.unpack = table.unpack or unpack
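A small usage sketch of the new helper (the nested table below is illustrative):

```lua
require 'nn'

-- Illustrative nested table of tensors; recursiveFill fills each one in place
local t = {torch.randn(2, 3), {torch.randn(4), torch.randn(1, 2)}}
nn.utils.recursiveFill(t, 0)
print(t[1]:sum(), t[2][1]:sum(), t[2][2]:sum())  -- all zero
```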
