@@ -3116,6 +3116,45 @@ function nntest.MixtureTable()
3116
3116
end
3117
3117
end
3118
3118
3119
+ function nntest .NarrowTable ()
3120
+ local input = torch .randn (3 ,10 ,4 )
3121
+ local gradOutput = torch .randn (3 ,3 ,4 )
3122
+ local nt = nn .NarrowTable (5 ,3 )
3123
+ local seq = nn .Sequential ()
3124
+ seq :add (nn .SplitTable (1 ,2 ))
3125
+ seq :add (nt )
3126
+ seq :add (nn .JoinTable (1 ,1 ))
3127
+ seq :add (nn .Reshape (3 ,3 ,4 ))
3128
+ local seq2 = nn .Narrow (2 ,5 ,3 )
3129
+ local output = seq :forward (input )
3130
+ local gradInput = seq :backward (input , gradOutput )
3131
+ local output2 = seq2 :forward (input )
3132
+ local gradInput2 = seq2 :backward (input , gradOutput )
3133
+ mytester :assertTensorEq (output , output2 , 0.0000001 , " NarrowTable output err" )
3134
+ mytester :assertTensorEq (gradInput , gradInput2 , 0.00001 , " NarrowTable gradInput err" )
3135
+
3136
+ -- now try it with a smaller input
3137
+ local input = input :narrow (2 , 1 , 8 )
3138
+ local output = seq :forward (input )
3139
+ local gradInput = seq :backward (input , gradOutput )
3140
+ local output2 = seq2 :forward (input )
3141
+ local gradInput2 = seq2 :backward (input , gradOutput )
3142
+ mytester :assertTensorEq (output , output2 , 0.0000001 , " NarrowTable small output err" )
3143
+ mytester :assertTensorEq (gradInput , gradInput2 , 0.00001 , " NarrowTable small gradInput err" )
3144
+
3145
+ -- test type-cast
3146
+ local input = input :float ()
3147
+ local gradOutput = gradOutput :float ()
3148
+ seq :float ()
3149
+ seq2 :float ()
3150
+ local output = seq :forward (input )
3151
+ local gradInput = seq :backward (input , gradOutput )
3152
+ local output2 = seq2 :forward (input )
3153
+ local gradInput2 = seq2 :backward (input , gradOutput )
3154
+ mytester :assertTensorEq (output , output2 , 0.0000001 , " NarrowTable output float err" )
3155
+ mytester :assertTensorEq (gradInput , gradInput2 , 0.00001 , " NarrowTable gradInput float err" )
3156
+ end
3157
+
3119
3158
function nntest .View ()
3120
3159
local input = torch .rand (10 )
3121
3160
local template = torch .rand (5 ,2 )
0 commit comments