From d42feec6c8ed444859cb3674be97201e3220cf9a Mon Sep 17 00:00:00 2001 From: dding3 Date: Tue, 25 Jul 2017 22:17:34 -0700 Subject: [PATCH] Convlstm3d (#1361) * support convlstm3d --- docs/docs/APIdocs/Layers/Recurrent-Layers.md | 500 +++++++++++++++++- pyspark/bigdl/nn/layer.py | 38 +- .../com/intel/analytics/bigdl/nn/Cell.scala | 21 +- ...eephole.scala => ConvLSTMPeephole2D.scala} | 14 +- .../bigdl/nn/ConvLSTMPeephole3D.scala | 234 ++++++++ .../intel/analytics/bigdl/nn/Recurrent.scala | 33 +- .../bigdl/python/api/PythonBigDL.scala | 20 +- .../intel/analytics/bigdl/nn/CellSpec.scala | 2 +- ...pec.scala => ConvLSTMPeephole2DSpec.scala} | 16 +- .../bigdl/nn/ConvLSTMPeephole3DSpec.scala | 53 ++ 10 files changed, 874 insertions(+), 57 deletions(-) rename spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/{ConvLSTMPeephole.scala => ConvLSTMPeephole2D.scala} (95%) create mode 100644 spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole3D.scala rename spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/{ConvLSTMPeepholeSpec.scala => ConvLSTMPeephole2DSpec.scala} (98%) create mode 100644 spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole3DSpec.scala diff --git a/docs/docs/APIdocs/Layers/Recurrent-Layers.md b/docs/docs/APIdocs/Layers/Recurrent-Layers.md index c37450cccfa..b7f6dd99f09 100644 --- a/docs/docs/APIdocs/Layers/Recurrent-Layers.md +++ b/docs/docs/APIdocs/Layers/Recurrent-Layers.md @@ -582,11 +582,11 @@ gradient = model.backward(input, grad_output) [-0.32718879 0.32963118]]] ``` --- -## ConvLSTMPeephole ## +## ConvLSTMPeephole2D ## **Scala:** ```scala -val model = ConvLSTMPeephole( +val model = ConvLSTMPeephole2D( inputSize = 2, outputSize = 4, kernelI = 3, @@ -600,7 +600,7 @@ val model = ConvLSTMPeephole( **Python:** ```python -model = ConvLSTMPeephole( +model = ConvLSTMPeephole2D( input_size = 2, output_size = 4, kernel_i = 3, @@ -612,7 +612,7 @@ model = ConvLSTMPeephole( with_peephole = True) ``` -Convolution Long Short Term Memory architecture with peephole. +Convolution Long Short Term Memory architecture with peephole for 2 dimension images. Ref. 1. https://arxiv.org/abs/1506.04214 (blueprint for this module) @@ -650,7 +650,7 @@ val input = Tensor(Array(batchSize, seqLength, inputSize, 3, 3)).rand() val rec = Recurrent() val model = Sequential() .add(rec - .add(ConvLSTMPeephole(inputSize, outputSize, 3, 3, 1, withPeephole = false))) + .add(ConvLSTMPeephole2D(inputSize, outputSize, 3, 3, 1, withPeephole = false))) val output = model.forward(input).toTensor @@ -744,7 +744,7 @@ batch_size = 1 input = np.random.randn(batch_size, seq_len, input_size, 3, 3) rec = Recurrent() model = Sequential().add( - rec.add(ConvLSTMPeephole(input_size, output_size, 3, 3, 1, with_peephole = False))) + rec.add(ConvLSTMPeephole2D(input_size, output_size, 3, 3, 1, with_peephole = False))) output = model.forward(input) >>> print(input) @@ -808,6 +808,494 @@ output = model.forward(input) [-0.15700462 -0.17341313 -0.06551415]]]]] ``` +--- +## ConvLSTMPeephole3D ## + +**Scala:** +```scala +val model = ConvLSTMPeephole3D( + inputSize = 2, + outputSize = 4, + kernelI = 3, + kernelC = 3, + stride = 1, + wRegularizer = null, + uRegularizer = null, + bRegularizer = null, + withPeephole = true) +``` + +**Python:** +```python +model = ConvLSTMPeephole3D( + input_size = 2, + output_size = 4, + kernel_i = 3, + kernel_c = 3, + stride = 1, + wRegularizer=None, + uRegularizer=None, + bRegularizer=None, + with_peephole = True) +``` + +Similar to Convlstm2D, it's a Convolution Long Short Term Memory architecture with peephole but for 3 spatial dimension images. + +Parameters: + +* `inputSize` number of input planes in the image given into forward() +* `outputSize` number of output planes the convolution layer will produce +* `kernelI` convolutional filter size to convolve input +* `kernelC` convolutional filter size to convolve cell +* `stride` step of the convolution +* `wRegularizer` instance of [[Regularizer]] + (eg. L1 or L2 regularization), applied to the input weights matrices. +* `uRegularizer` instance [[Regularizer]] + (eg. L1 or L2 regularization), applied to the recurrent weights matrices. +* `bRegularizer` instance of [[Regularizer]] + applied to the bias. +* `withPeephole` whether use last cell status control a gate + +**Scala example:** +```scala +import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.NumericFloat +import com.intel.analytics.bigdl.nn._ +import com.intel.analytics.bigdl.tensor.Tensor +import com.intel.analytics.bigdl.utils.RandomGenerator._ + +val outputSize = 4 +val inputSize = 3 +val seqLength = 2 +val batchSize = 1 + +val input = Tensor(Array(batchSize, seqLength, inputSize, 3, 3, 3)).rand() + +val rec = Recurrent() + val model = Sequential() + .add(rec + .add(ConvLSTMPeephole3D(inputSize, outputSize, 3, 3, 1, withPeephole = false))) + +val output = model.forward(input).toTensor + +scala> print(input) +(1,1,1,1,.,.) = +0.42592695 0.32742274 0.7926296 +0.21923159 0.7427106 0.31764257 +0.121872835 0.54231954 0.32091624 + +(1,1,1,2,.,.) = +0.06762145 0.8054027 0.8297814 +0.95535785 0.20807801 0.46387103 +0.90996957 0.7849159 0.79179865 + +(1,1,1,3,.,.) = +0.22927228 0.29869995 0.1145133 +0.12646529 0.8917339 0.7545332 +0.8044227 0.5340327 0.9784876 + +(1,1,2,1,.,.) = +0.68444395 0.47932255 0.28224406 +0.5083046 0.9364489 0.27006733 +0.24699332 0.55712855 0.50037974 + +(1,1,2,2,.,.) = +0.46334672 0.10979338 0.6378528 +0.8557069 0.10780747 0.73767877 +0.12505454 0.72492164 0.5440267 + +(1,1,2,3,.,.) = +0.15598479 0.52033675 0.64091414 +0.15149859 0.64515823 0.6023936 +0.31461328 0.1901752 0.98015004 + +(1,1,3,1,.,.) = +0.9700778 0.24109624 0.23764393 +0.16602103 0.97310185 0.072756775 +0.849201 0.825025 0.2753475 + +(1,1,3,2,.,.) = +0.8621034 0.24596989 0.56645423 +0.004375741 0.9873366 0.89219636 +0.56948274 0.291723 0.5503815 + +(1,1,3,3,.,.) = +0.626368 0.9389012 0.8974684 +0.8553843 0.39709046 0.372683 +0.38087663 0.94703597 0.71530545 + +(1,2,1,1,.,.) = +0.74050623 0.39862877 0.57509166 +0.87832487 0.41345102 0.6262451 +0.665165 0.49570015 0.8304163 + +(1,2,1,2,.,.) = +0.30847755 0.51876235 0.10555197 +0.10103849 0.9479695 0.11847988 +0.60081536 0.003097216 0.22800316 + +(1,2,1,3,.,.) = +0.113101795 0.76638913 0.091707565 +0.30347276 0.029687135 0.37973404 +0.67719024 0.02180517 0.12747364 + +(1,2,2,1,.,.) = +0.12513511 0.74210113 0.82569206 +0.1406212 0.7400157 0.041633762 +0.26903376 0.6195371 0.618376 + +(1,2,2,2,.,.) = +0.068732955 0.09746146 0.15479624 +0.57418007 0.7181547 0.6494809 +0.29213288 0.35022008 0.15421997 + +(1,2,2,3,.,.) = +0.47196773 0.55650383 0.938309 +0.70717365 0.68351734 0.32646814 +0.99775004 0.2596666 0.6803594 + +(1,2,3,1,.,.) = +0.6320722 0.105437785 0.36752152 +0.8347324 0.38376364 0.641918 +0.40254018 0.5421287 0.792421 + +(1,2,3,2,.,.) = +0.2652298 0.6261154 0.21971565 +0.31418183 0.44987184 0.43880364 +0.76821107 0.17070894 0.47295105 + +(1,2,3,3,.,.) = +0.16514553 0.37016368 0.23397927 +0.19776458 0.07518195 0.48995376 +0.13584352 0.23562871 0.41726747 + +[com.intel.analytics.bigdl.tensor.DenseTensor$mcF$sp of size 1x2x3x3x3x3] + +scala> print(output) +(1,1,1,1,.,.) = +0.014528348 0.03160259 0.05313618 +-0.011796958 0.027994404 0.028153816 +-0.010374474 0.029486801 0.033610236 + +(1,1,1,2,.,.) = +0.07966786 0.041255455 0.09181337 +0.025984935 0.06594588 0.07572434 +0.019637575 0.0068716113 0.03775029 + +(1,1,1,3,.,.) = +0.07043511 0.044567406 0.08229201 +0.10589862 0.109124646 0.0888148 +0.018544039 0.04097363 0.09130414 + +(1,1,2,1,.,.) = +0.1032162 -0.01981514 -0.0016546922 +0.026028564 0.0100736385 0.009424217 +-0.048695907 -0.009172593 -0.029458746 + +(1,1,2,2,.,.) = +0.058081806 0.101963215 0.056670886 +0.09300327 0.035424378 0.02410931 +0.056604195 -0.0032351227 0.027961217 + +(1,1,2,3,.,.) = +0.11710516 0.09371774 -0.013825272 +0.02930173 0.06391968 0.04034334 +0.010447707 -0.004905071 0.011929871 + +(1,1,3,1,.,.) = +-0.020980358 0.08554982 -0.07644813 +0.06367171 -0.06037125 0.019925931 +0.0026421212 0.051610045 0.023478134 + +(1,1,3,2,.,.) = +-0.033074334 -0.0381583 -0.019341394 +-0.0625153 -0.06907081 -0.019746307 +-0.010362335 0.0062695937 0.054116223 + +(1,1,3,3,.,.) = +0.00461099 -0.03308314 -6.8137434E-4 +-0.075023845 -0.024970314 0.008133534 +0.019836657 0.051302493 0.043689556 + +(1,1,4,1,.,.) = +0.027088374 0.008537832 -0.020948375 +0.021569671 0.016515112 -0.019221392 +-0.0074050943 -0.03274501 0.003256779 + +(1,1,4,2,.,.) = +8.967657E-4 0.019020535 -0.05990117 +0.06226491 -0.017516658 -0.028854925 +0.048010994 0.031080479 -4.8373322E-4 + +(1,1,4,3,.,.) = +0.03253352 -0.023469497 -0.047273926 +-0.03765316 0.011091222 0.0036612307 +0.050733108 0.01736545 0.0061482657 + +(1,2,1,1,.,.) = +-0.0037416879 0.03895818 0.102294624 +0.011019588 0.03201482 0.07654998 +-0.015550408 0.009587483 0.027655594 + +(1,2,1,2,.,.) = +0.089279816 0.03306113 0.11713534 +0.07299529 0.057692382 0.11090511 +-0.0031341386 0.091527686 0.07210587 + +(1,2,1,3,.,.) = +0.080724075 0.07707712 0.07624206 +0.06552311 0.104010254 0.09213451 +0.07030998 0.0022800618 0.12461836 + +(1,2,2,1,.,.) = +0.10180804 0.020320226 -0.0025817656 +0.016294254 -0.024293585 -0.004399727 +-0.032854877 1.1120379E-4 -0.02109197 + +(1,2,2,2,.,.) = +0.0968586 0.07098973 0.07648221 +0.0918679 0.10268471 0.056947876 +0.027774762 -0.03927014 0.04663368 + +(1,2,2,3,.,.) = +0.10225944 0.08460646 -8.393754E-4 +0.051307157 0.011988232 0.037762236 +0.029469138 0.023369621 0.037675448 + +(1,2,3,1,.,.) = +-0.017874755 0.08561468 -0.066132575 +0.010558257 -0.01448278 0.0073027355 +-0.007930762 0.052643955 0.008378773 + +(1,2,3,2,.,.) = +-0.009250246 -0.06543376 -0.025082456 +-0.093004115 -0.08637037 -0.063408665 +-0.06941878 0.010163672 0.07595171 + +(1,2,3,3,.,.) = +0.014756428 -0.040423956 -0.011537984 +-0.046337806 -0.008416044 0.068246834 +3.5782385E-4 0.056929104 0.052956138 + +(1,2,4,1,.,.) = +0.033539586 0.013915413 -0.024538055 +0.042590756 0.034134552 0.021031722 +-0.026687687 0.0012957935 -0.0053077694 + +(1,2,4,2,.,.) = +0.0033482902 -0.037335612 -0.0956953 +0.007350738 -0.05237038 -0.08849126 +0.016356941 0.032067236 -0.0012172575 + +(1,2,4,3,.,.) = +-0.020006038 -0.030038685 -0.054900024 +-0.014171911 0.01270077 -0.004130667 +0.04607582 0.040028486 0.011846061 + +[com.intel.analytics.bigdl.tensor.DenseTensor of size 1x2x4x3x3x3] + +``` +**Python example:** +```python +from bigdl.nn.layer import * +import numpy as np + +output_size = 4 +input_size= 3 +seq_len = 2 +batch_size = 1 + +input = np.random.randn(batch_size, seq_len, input_size, 3, 3, 3) +rec = Recurrent() +model = Sequential().add( + rec.add(ConvLSTMPeephole3D(input_size, output_size, 3, 3, 1, with_peephole = False))) +output = model.forward(input) + +>>> print(input) +[[[[[[ -8.92954769e-02 -9.77685543e-03 1.97566296e+00] + [ -5.76910662e-01 -9.08404346e-01 -4.70799006e-01] + [ -9.86229768e-01 7.87303916e-01 2.29691167e+00]] + + [[ -7.48240036e-01 4.12766483e-01 -3.88947296e-01] + [ -1.39879028e+00 2.43984720e+00 -2.43947000e-01] + [ 1.86468980e-01 1.34599111e+00 -6.97932324e-01]] + + [[ 1.23278710e+00 -4.02661913e-01 8.50721265e-01] + [ -1.79452089e-01 -5.58813385e-01 1.10060751e+00] + [ -6.27181580e-01 -2.69531726e-01 -1.07857962e-01]]] + + + [[[ -1.01462355e+00 5.47520811e-02 3.06976674e-01] + [ 9.64871158e-01 -1.16953916e+00 1.41880629e+00] + [ 1.19127007e+00 1.71403439e-01 -1.30787798e+00]] + + [[ -6.44313121e-01 -8.45131087e-01 6.99275525e-02] + [ -3.07656855e-01 1.25746926e+00 3.89980508e-02] + [ -2.59853355e-01 8.78915612e-01 -9.37204072e-02]] + + [[ 7.69958423e-02 -3.22523203e-01 -7.31295167e-01] + [ 1.46184856e+00 1.88641278e+00 1.46645372e-01] + [ 4.38390570e-01 -2.85102515e-01 -1.81269541e+00]]] + + + [[[ 2.95126419e-01 -1.13715815e+00 9.36848777e-01] + [ -1.62071909e+00 -1.06018926e+00 1.88416944e+00] + [ -5.81248254e-01 1.05162543e+00 -3.58790528e-01]] + + [[ -7.54710826e-01 2.29994522e+00 7.24276828e-01] + [ 5.77031441e-01 7.36132125e-01 2.24719266e+00] + [ -4.53710071e-05 1.98478259e-01 -2.62825655e-01]] + + [[ 1.68124733e+00 -9.97417864e-01 -3.73490116e-01] + [ -1.12558844e+00 2.60032255e-01 9.67994680e-01] + [ 1.78486852e+00 1.17514142e+00 -1.96871551e-01]]]] + + + + [[[[ 4.43156770e-01 -4.42279658e-01 8.00893010e-01] + [ -2.04817319e-01 -3.89658940e-01 -1.10950351e+00] + [ 6.61008455e-01 -4.07251176e-01 1.14871901e+00]] + + [[ -2.07785815e-01 -8.92450022e-01 -4.23830113e-02] + [ -5.26555807e-01 3.76671145e-02 -2.17877979e-01] + [ -7.68371469e-01 1.53052409e-01 1.02405949e+00]] + + [[ 5.75018628e-01 -9.47162716e-01 6.47917376e-01] + [ 4.66967303e-01 1.00917068e-01 -1.60894238e+00] + [ -1.46491032e-01 3.17782758e+00 1.12581079e-01]]] + + + [[[ 9.32343396e-01 -1.03853742e+00 5.67577254e-02] + [ 1.25266813e+00 3.52463164e-01 -1.86783652e-01] + [ -1.20321270e+00 3.95144053e-01 2.09975625e-01]] + + [[ 2.68240844e-01 -1.34931544e+00 1.34259455e+00] + [ 6.34339337e-01 -5.21231073e-02 -3.91895492e-01] + [ 1.53872699e-01 -5.07236962e-02 -2.90772390e-01]] + + [[ -5.07933749e-01 3.78036493e-01 7.41781186e-01] + [ 1.62736825e+00 1.24125644e+00 -3.97490478e-01] + [ 5.77762257e-01 1.10372911e+00 1.58060183e-01]]] + + + [[[ 5.31859839e-01 1.72805654e+00 -3.77124271e-01] + [ 1.24638369e+00 -1.54061928e+00 6.22001793e-01] + [ 1.92447446e+00 7.71351435e-01 -1.59998400e+00]] + + [[ 1.44289958e+00 5.41433535e-01 9.19769038e-01] + [ 9.92873720e-01 -9.05746035e-01 1.35906705e+00] + [ 1.38994943e+00 2.11451648e+00 -1.58783119e-01]] + + [[ -1.44024889e+00 -5.12269041e-01 8.56761529e-02] + [ 1.16668889e+00 7.58164067e-01 -1.04304927e+00] + [ 6.34138215e-01 -7.89939971e-01 -5.52376307e-01]]]]]] + +>>> print(output) +[[[[[[ 0.08801123 -0.15533912 -0.08897342] + [ 0.01158205 -0.01103314 0.02793931] + [-0.01269898 -0.09544773 0.03573112]] + + [[-0.15603164 -0.16063154 -0.09672774] + [ 0.15531734 0.05808824 -0.01653268] + [-0.06348733 -0.10497692 -0.13086422]] + + [[ 0.002062 -0.01604773 -0.14802884] + [-0.0934701 -0.06831796 0.07375477] + [-0.01157693 0.17962074 0.13433206]]] + + + [[[ 0.03571969 -0.20905718 -0.05286504] + [-0.18766534 -0.10728011 0.04605131] + [-0.07477143 0.02631984 0.02496208]] + + [[ 0.06653454 0.06536704 0.01587131] + [-0.00348636 -0.04439256 0.12680793] + [ 0.00328905 0.01904229 -0.06607334]] + + [[-0.04666118 -0.06754828 0.07643934] + [-0.05434367 -0.09878142 0.06385987] + [ 0.02643086 -0.01466259 -0.1031612 ]]] + + + [[[-0.0572568 0.13133277 -0.0435285 ] + [-0.11612531 0.09036689 -0.09608591] + [-0.01049453 -0.02091818 -0.00642477]] + + [[ 0.1255362 -0.07545673 -0.07554446] + [ 0.07270454 -0.24932131 -0.13024282] + [ 0.05507039 -0.0109083 0.00408967]] + + [[-0.1099453 -0.11417828 0.06235902] + [ 0.03701246 -0.02138007 -0.05719795] + [-0.02627739 -0.15853535 -0.01103899]]] + + + [[[ 0.10380347 -0.05826453 -0.00690799] + [ 0.01000955 -0.11808137 -0.039118 ] + [ 0.02591963 -0.03464907 -0.21320052]] + + [[-0.03449376 -0.00601143 0.05562805] + [ 0.09242225 0.01035819 0.09432289] + [-0.12854564 0.189775 -0.06698175]] + + [[ 0.03462109 0.02545513 -0.14716192] + [ 0.02003146 -0.03616474 0.04574323] + [ 0.04782774 -0.04594192 0.01773669]]]] + + + + [[[[ 0.04205685 -0.05454008 -0.0389443 ] + [ 0.07172828 0.03370164 0.00703573] + [ 0.01299563 -0.06371058 0.02505058]] + + [[-0.09191396 0.06227853 -0.15412274] + [ 0.09069916 0.01907965 -0.05783302] + [-0.03441796 -0.11438221 -0.1011953 ]] + + [[-0.00837748 -0.06554071 -0.14735688] + [-0.04640726 0.01484136 0.14445931] + [-0.09255736 -0.12196805 -0.0444463 ]]] + + + [[[ 0.01632853 0.01925437 0.02539274] + [-0.09239745 -0.13713452 0.06149488] + [-0.01742462 0.06624916 0.01490385]] + + [[ 0.03866836 0.19375585 0.06069621] + [-0.11291414 -0.29582706 0.11678439] + [-0.09451667 0.05238266 -0.05152772]] + + [[-0.11206269 0.09128021 0.09243178] + [ 0.01127258 -0.05845089 0.09795895] + [ 0.00747248 0.02055444 0.0121724 ]]] + + + [[[-0.11144694 -0.0030012 -0.03507657] + [-0.15461211 -0.00992483 0.02500556] + [-0.07733752 -0.09037463 0.02955181]] + + [[-0.00988597 0.0264726 -0.14286363] + [-0.06936073 -0.01345975 -0.16290392] + [-0.07821255 -0.02489748 0.05186536]] + + [[-0.12142604 0.04658077 0.00509979] + [-0.16115788 -0.19458961 -0.04082467] + [ 0.10544231 -0.10425973 0.01532217]]] + + + [[[ 0.08169251 0.05370622 0.00506061] + [ 0.08195242 0.08890768 0.03178475] + [-0.03648232 0.02655745 -0.18274172]] + + [[ 0.07358464 -0.09604233 0.06556321] + [-0.02229194 0.17364709 0.07240117] + [-0.18307404 0.04115544 -0.15400645]] + + [[ 0.0156146 -0.15857749 -0.12837477] + [ 0.07957774 0.06684072 0.0719762 ] + [-0.13781127 -0.03935293 -0.096707 ]]]]]] + +``` + --- ## TimeDistributed ## diff --git a/pyspark/bigdl/nn/layer.py b/pyspark/bigdl/nn/layer.py index 50a9a6cc010..19ab2061fff 100644 --- a/pyspark/bigdl/nn/layer.py +++ b/pyspark/bigdl/nn/layer.py @@ -3617,7 +3617,7 @@ class Pack(Layer): def __init__(self, dimension, bigdl_type="float"): super(Pack, self).__init__(None, bigdl_type, dimension) -class ConvLSTMPeephole(Layer): +class ConvLSTMPeephole2D(Layer): ''' | Convolution Long Short Term Memory architecture with peephole. @@ -3632,19 +3632,45 @@ class ConvLSTMPeephole(Layer): :param wRegularizer: instance of [[Regularizer]](eg. L1 or L2 regularization), applied to the input weights matrices :param uRegularizer: instance [[Regularizer]](eg. L1 or L2 regularization), applied to the recurrent weights matrices :param bRegularizer: instance of [[Regularizer]]applied to the bias. - :param with_peephold: whether use last cell status control a gate. + :param with_peephole: whether use last cell status control a gate. - >>> convlstm = ConvLSTMPeephole(4, 3, 3, 3, 1, L1Regularizer(0.5), L1Regularizer(0.5), L1Regularizer(0.5)) + >>> convlstm = ConvLSTMPeephole2D(4, 3, 3, 3, 1, L1Regularizer(0.5), L1Regularizer(0.5), L1Regularizer(0.5)) creating: createL1Regularizer creating: createL1Regularizer creating: createL1Regularizer - creating: createConvLSTMPeephole + creating: createConvLSTMPeephole2D ''' def __init__(self, input_size, output_size, kernel_i, kernel_c, stride, wRegularizer=None, uRegularizer=None, bRegularizer=None, with_peephole=True, bigdl_type="float"): - super(ConvLSTMPeephole, self).__init__(None, bigdl_type, input_size, output_size, kernel_i, kernel_c, stride, - wRegularizer, uRegularizer, bRegularizer, with_peephole) + super(ConvLSTMPeephole2D, self).__init__(None, bigdl_type, input_size, output_size, kernel_i, kernel_c, stride, + wRegularizer, uRegularizer, bRegularizer, with_peephole) + + +class ConvLSTMPeephole3D(Layer): + ''' + + :param input_size: number of input planes in the image given into forward() + :param output_size: number of output planes the convolution layer will produce + :param kernel_i Convolutional filter size to convolve input + :param kernel_c Convolutional filter size to convolve cell + :param stride The step of the convolution + :param wRegularizer: instance of [[Regularizer]](eg. L1 or L2 regularization), applied to the input weights matrices + :param uRegularizer: instance [[Regularizer]](eg. L1 or L2 regularization), applied to the recurrent weights matrices + :param bRegularizer: instance of [[Regularizer]]applied to the bias. + :param with_peephole: whether use last cell status control a gate. + + >>> convlstm = ConvLSTMPeephole3D(4, 3, 3, 3, 1, L1Regularizer(0.5), L1Regularizer(0.5), L1Regularizer(0.5)) + creating: createL1Regularizer + creating: createL1Regularizer + creating: createL1Regularizer + creating: createConvLSTMPeephole3D + ''' + + def __init__(self, input_size, output_size, kernel_i, kernel_c, stride, wRegularizer=None, uRegularizer=None, + bRegularizer=None, with_peephole=True, bigdl_type="float"): + super(ConvLSTMPeephole3D, self).__init__(None, bigdl_type, input_size, output_size, kernel_i, kernel_c, stride, + wRegularizer, uRegularizer, bRegularizer, with_peephole) def _test(): import doctest diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Cell.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Cell.scala index e1fbc432a08..6b260cfa326 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Cell.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Cell.scala @@ -22,6 +22,7 @@ import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric import com.intel.analytics.bigdl.utils.{T, Table} +import scala.collection.mutable import scala.reflect.ClassTag /** @@ -81,13 +82,13 @@ abstract class Cell[T : ClassTag]( * and recursively intialize all the tensors in the Table. * * @param hidden - * @param size batchSize + * @param batchSize batchSize * @return */ - def hidResize(hidden: Activity, size: Int, rows: Int = 1, columns: Int = 1): Activity = { + def hidResize(hidden: Activity, batchSize: Int, imageSize: Array[Int] = null): Activity = { if (hidden == null) { if (hiddensShape.length == 1) { - hidResize(Tensor[T](), size, rows, columns) + hidResize(Tensor[T](), batchSize) } else { val _hidden = T() var i = 1 @@ -95,25 +96,29 @@ abstract class Cell[T : ClassTag]( _hidden(i) = Tensor[T]() i += 1 } - hidResize(_hidden, size, rows, columns) + hidResize(_hidden, batchSize, imageSize) } } else { if (hidden.isInstanceOf[Tensor[T]]) { require(hidden.isInstanceOf[Tensor[T]], "Cell: hidden should be a Tensor") - hidden.toTensor.resize(size, hiddensShape(0)) + hidden.toTensor.resize(batchSize, hiddensShape(0)) } else { require(hidden.isInstanceOf[Table], "Cell: hidden should be a Table") var i = 1 - if (rows== 1 && columns==1) { + if (null == imageSize) { while (i <= hidden.toTable.length()) { - hidden.toTable[Tensor[T]](i).resize(size, hiddensShape(i - 1)) + hidden.toTable[Tensor[T]](i).resize(batchSize, hiddensShape(i - 1)) i += 1 } } else { + val sizes = new Array[Int](imageSize.length + 2) + sizes(0) = batchSize + Array.copy(imageSize, 0, sizes, 2, imageSize.size) while (i <= hidden.toTable.length()) { - hidden.toTable[Tensor[T]](i).resize(size, hiddensShape(i - 1), rows, columns) + sizes(1) = hiddensShape(i - 1) + hidden.toTable[Tensor[T]](i).resize(sizes) i += 1 } } diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole2D.scala similarity index 95% rename from spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole.scala rename to spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole2D.scala index 621db2fa6a8..51e991dd15d 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole2D.scala @@ -42,7 +42,7 @@ import scala.reflect.ClassTag applied to the bias. * @param withPeephole: whether use last cell status control a gate. */ -class ConvLSTMPeephole[T : ClassTag] ( +class ConvLSTMPeephole2D[T : ClassTag]( val inputSize: Int, val outputSize: Int, val kernelI: Int, @@ -202,10 +202,10 @@ class ConvLSTMPeephole[T : ClassTag] ( convlstm } - override def canEqual(other: Any): Boolean = other.isInstanceOf[ConvLSTMPeephole[T]] + override def canEqual(other: Any): Boolean = other.isInstanceOf[ConvLSTMPeephole2D[T]] override def equals(other: Any): Boolean = other match { - case that: ConvLSTMPeephole[T] => + case that: ConvLSTMPeephole2D[T] => super.equals(that) && (that canEqual this) && inputSize == that.inputSize && @@ -227,11 +227,11 @@ class ConvLSTMPeephole[T : ClassTag] ( cell.reset() } - override def toString: String = s"ConvLSTMPeephole($inputSize, $outputSize," + + override def toString: String = s"ConvLSTMPeephole2D($inputSize, $outputSize," + s"$kernelI, $kernelC, $stride)" } -object ConvLSTMPeephole { +object ConvLSTMPeephole2D { def apply[@specialized(Float, Double) T: ClassTag]( inputSize: Int, outputSize: Int, @@ -242,8 +242,8 @@ object ConvLSTMPeephole { uRegularizer: Regularizer[T] = null, bRegularizer: Regularizer[T] = null, withPeephole: Boolean = true - )(implicit ev: TensorNumeric[T]): ConvLSTMPeephole[T] = { - new ConvLSTMPeephole[T](inputSize, outputSize, kernelI, kernelC, stride, + )(implicit ev: TensorNumeric[T]): ConvLSTMPeephole2D[T] = { + new ConvLSTMPeephole2D[T](inputSize, outputSize, kernelI, kernelC, stride, wRegularizer, uRegularizer, bRegularizer, withPeephole) } } diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole3D.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole3D.scala new file mode 100644 index 00000000000..4180a7f056a --- /dev/null +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole3D.scala @@ -0,0 +1,234 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.bigdl.nn + +import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity} +import com.intel.analytics.bigdl.optim.Regularizer +import com.intel.analytics.bigdl.tensor.Tensor +import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric +import com.intel.analytics.bigdl.utils.{T, Table} + +import scala.reflect.ClassTag + +/** + * Convolution Long Short Term Memory architecture with peephole. + * Ref. A.: https://arxiv.org/abs/1506.04214 (blueprint for this module) + * B. https://github.com/viorik/ConvLSTM + * + * @param inputSize number of input planes in the image given into forward() + * @param outputSize number of output planes the convolution layer will produce + * @param kernelI Convolutional filter size to convolve input + * @param kernelC Convolutional filter size to convolve cell + * @param stride The step of the convolution + * @param wRegularizer: instance of [[Regularizer]] + (eg. L1 or L2 regularization), applied to the input weights matrices. + * @param uRegularizer: instance [[Regularizer]] + (eg. L1 or L2 regularization), applied to the recurrent weights matrices. + * @param bRegularizer: instance of [[Regularizer]] + applied to the bias. + * @param withPeephole: whether use last cell status control a gate. + */ +class ConvLSTMPeephole3D[T : ClassTag]( + val inputSize: Int, + val outputSize: Int, + val kernelI: Int, + val kernelC: Int, + val stride: Int, + var wRegularizer: Regularizer[T] = null, + var uRegularizer: Regularizer[T] = null, + var bRegularizer: Regularizer[T] = null, + val withPeephole: Boolean = true +)(implicit ev: TensorNumeric[T]) + extends Cell[T]( + hiddensShape = Array(outputSize, outputSize), + regularizers = Array(wRegularizer, uRegularizer, bRegularizer) + ) { + var inputGate: Sequential[T] = _ + var forgetGate: Sequential[T] = _ + var outputGate: Sequential[T] = _ + var hiddenLayer: Sequential[T] = _ + var cellLayer: Sequential[T] = _ + + override var cell: AbstractModule[Activity, Activity, T] = buildConvLSTM() + + def buildGate(): Sequential[T] = { + val i2g = Sequential() + .add(Contiguous()) + .add(VolumetricConvolution(inputSize, outputSize, kernelI, kernelI, kernelI, + stride, stride, stride, kernelI/2, kernelI/2, kernelI/2)) + val h2g = VolumetricConvolution(outputSize, outputSize, kernelC, kernelC, kernelC, + stride, stride, stride, kernelC/2, kernelC/2, kernelC/2, withBias = false) + + val gate = Sequential() + if (withPeephole) { + gate + .add(ParallelTable() + .add(i2g) + .add(h2g) + .add(CMul(Array(1, outputSize, 1, 1, 1)))) + } else { + gate.add(NarrowTable(1, 2)) + gate + .add(ParallelTable() + .add(i2g) + .add(h2g)) + } + + gate.add(CAddTable()) + .add(Sigmoid()) + } + + def buildInputGate(): Sequential[T] = { + inputGate = buildGate() + inputGate + } + + def buildForgetGate(): Sequential[T] = { + forgetGate = buildGate() + forgetGate + } + + def buildOutputGate(): Sequential[T] = { + outputGate = buildGate() + outputGate + } + + def buildHidden(): Sequential[T] = { + val hidden = Sequential() + .add(NarrowTable(1, 2)) + + val i2h = Sequential() + .add(Contiguous()) + .add(VolumetricConvolution(inputSize, outputSize, kernelI, kernelI, kernelI, + stride, stride, stride, kernelI/2, kernelI/2, kernelI/2)) + val h2h = VolumetricConvolution(outputSize, outputSize, kernelC, kernelC, kernelC, + stride, stride, stride, kernelC/2, kernelC/2, kernelC/2, withBias = false) + + hidden + .add(ParallelTable() + .add(i2h) + .add(h2h)) + .add(CAddTable()) + .add(Tanh()) + + this.hiddenLayer = hidden + hidden + } + + def buildCell(): Sequential[T] = { + buildInputGate() + buildForgetGate() + buildHidden() + + val forgetLayer = Sequential() + .add(ConcatTable() + .add(forgetGate) + .add(SelectTable(3))) + .add(CMulTable()) + + val inputLayer = Sequential() + .add(ConcatTable() + .add(inputGate) + .add(hiddenLayer)) + .add(CMulTable()) + + val cellLayer = Sequential() + .add(ConcatTable() + .add(forgetLayer) + .add(inputLayer)) + .add(CAddTable()) + + this.cellLayer = cellLayer + cellLayer + } + + def buildConvLSTM(): Sequential[T] = { + buildCell() + buildOutputGate() + + val convlstm = Sequential() + .add(FlattenTable()) + .add(ConcatTable() + .add(NarrowTable(1, 2)) + .add(cellLayer)) + .add(FlattenTable()) + + .add(ConcatTable() + .add(Sequential() + .add(ConcatTable() + .add(outputGate) + .add(Sequential() + .add(SelectTable(3)) + .add(Tanh()))) + .add(CMulTable()) + .add(Contiguous())) + .add(SelectTable(3))) + + .add(ConcatTable() + .add(SelectTable(1)) + .add(Identity())) + + cell = convlstm + convlstm + } + + override def canEqual(other: Any): Boolean = other.isInstanceOf[ConvLSTMPeephole2D[T]] + + override def equals(other: Any): Boolean = other match { + case that: ConvLSTMPeephole2D[T] => + super.equals(that) && + (that canEqual this) && + inputSize == that.inputSize && + outputSize == that.outputSize && + kernelI == that.kernelI && + kernelC == that.kernelC && + stride == that.stride + + case _ => false + } + + override def hashCode(): Int = { + val state = Seq(super.hashCode(), inputSize, outputSize, kernelI, kernelC, stride) + state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) + } + + override def reset(): Unit = { + super.reset() + cell.reset() + } + + override def toString: String = s"ConvLSTMPeephole3D($inputSize, $outputSize," + + s"$kernelI, $kernelC, $stride)" +} + +object ConvLSTMPeephole3D { + def apply[@specialized(Float, Double) T: ClassTag]( + inputSize: Int, + outputSize: Int, + kernelI: Int, + kernelC: Int, + stride: Int = 1, + wRegularizer: Regularizer[T] = null, + uRegularizer: Regularizer[T] = null, + bRegularizer: Regularizer[T] = null, + withPeephole: Boolean = true + )(implicit ev: TensorNumeric[T]): ConvLSTMPeephole3D[T] = { + new ConvLSTMPeephole3D[T](inputSize, outputSize, kernelI, kernelC, stride, + wRegularizer, uRegularizer, bRegularizer, withPeephole) + } +} + diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Recurrent.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Recurrent.scala index ece6770751c..a6b327c5aa4 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Recurrent.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Recurrent.scala @@ -81,12 +81,13 @@ class Recurrent[T : ClassTag]() /** * Clone N models; N depends on the time dimension of the input - * @param times - * @param batchSize - * @param hiddenSize + * @param sizes, the first element is batchSize, the second is times, the third is hiddensize + * the left is size of images */ - private def extend(times: Int, batchSize: Int, hiddenSize: Int, - rows: Int = 1, columns: Int = 1): Unit = { + private def extend(sizes: Array[Int]): Unit = { + val times = sizes(1) + val batchSize = sizes(0) + val imageSize = sizes.drop(3) if (hidden == null) { require((preTopology == null && modules.length == 1) || (topology != null && preTopology != null && modules.length == 2), @@ -98,7 +99,7 @@ class Recurrent[T : ClassTag]() val cell = cells.head // The cell will help initialize or resize the hidden variable. - hidden = cell.hidResize(hidden = null, size = batchSize, rows, columns) + hidden = cell.hidResize(hidden = null, batchSize = batchSize, imageSize) /* * Since the gradHidden is only used as an empty Tensor or Table during @@ -107,7 +108,7 @@ class Recurrent[T : ClassTag]() */ gradHidden = hidden } else { - cells.head.hidResize(hidden = hidden, size = batchSize, rows, columns) + cells.head.hidResize(hidden = hidden, batchSize = batchSize, imageSize) gradHidden = hidden } var t = cells.length @@ -199,8 +200,8 @@ class Recurrent[T : ClassTag]() } override def updateOutput(input: Tensor[T]): Tensor[T] = { - require(input.dim == 3 || input.dim == 5, - "Recurrent: input should be a 3D or 5D Tensor, e.g [batch, times, nDim], " + + require(input.dim == 3 || input.dim == 5 || input.dim == 6, + "Recurrent: input should be a 3D/5D/6D Tensor, e.g [batch, times, nDim], " + s"current input.dim = ${input.dim}") batchSize = input.size(batchDim) @@ -229,15 +230,11 @@ class Recurrent[T : ClassTag]() } val hiddenSize = topology.hiddensShape(0) - if (input.dim() == 3) { - output.resize(batchSize, times, hiddenSize) - // Clone N modules along the sequence dimension. - extend(times, batchSize, hiddenSize) - } else if (input.dim() == 5) { - output.resize(batchSize, times, hiddenSize, input.size(4), input.size(5)) - // Clone N modules along the sequence dimension. - extend(times, batchSize, hiddenSize, input.size(4), input.size(5)) - } + val outputSize = input.size() + outputSize(2) = hiddenSize + output.resize(outputSize) + // Clone N modules along the sequence dimension. + extend(outputSize) /** * currentInput forms a T() type. It contains two elements, hidden and input. diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/python/api/PythonBigDL.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/python/api/PythonBigDL.scala index c9037fcf518..a6a89e3b7bd 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/python/api/PythonBigDL.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/python/api/PythonBigDL.scala @@ -251,7 +251,7 @@ class PythonBigDL[T: ClassTag](implicit ev: TensorNumeric[T]) extends Serializab Recurrent[T]() } - def createConvLSTMPeephole( + def createConvLSTMPeephole2D( inputSize: Int, outputSize: Int, kernelI: Int, @@ -260,8 +260,22 @@ class PythonBigDL[T: ClassTag](implicit ev: TensorNumeric[T]) extends Serializab wRegularizer: Regularizer[T] = null, uRegularizer: Regularizer[T] = null, bRegularizer: Regularizer[T] = null, - withPeephole: Boolean = true): ConvLSTMPeephole[T] = { - ConvLSTMPeephole[T](inputSize, outputSize, kernelI, kernelC, stride, + withPeephole: Boolean = true): ConvLSTMPeephole2D[T] = { + ConvLSTMPeephole2D[T](inputSize, outputSize, kernelI, kernelC, stride, + wRegularizer, uRegularizer, bRegularizer, withPeephole) + } + + def createConvLSTMPeephole3D( + inputSize: Int, + outputSize: Int, + kernelI: Int, + kernelC: Int, + stride: Int = 1, + wRegularizer: Regularizer[T] = null, + uRegularizer: Regularizer[T] = null, + bRegularizer: Regularizer[T] = null, + withPeephole: Boolean = true): ConvLSTMPeephole3D[T] = { + ConvLSTMPeephole3D[T](inputSize, outputSize, kernelI, kernelC, stride, wRegularizer, uRegularizer, bRegularizer, withPeephole) } diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/CellSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/CellSpec.scala index 5e05166c1a4..97155736502 100644 --- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/CellSpec.scala +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/CellSpec.scala @@ -46,7 +46,7 @@ class CellSpec extends FlatSpec with Matchers { "A Cell" should "hidResize correctly" in { val cell = new CellUnit[Double](4) - val hidden = cell.hidResize(hidden = null, size = 5) + val hidden = cell.hidResize(hidden = null, batchSize = 5) hidden.isInstanceOf[Table] should be (true) var i = 1 diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeepholeSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole2DSpec.scala similarity index 98% rename from spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeepholeSpec.scala rename to spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole2DSpec.scala index ab8b33bc3de..8d437680bf6 100644 --- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeepholeSpec.scala +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole2DSpec.scala @@ -25,9 +25,9 @@ import scala.collection.mutable.ArrayBuffer import scala.math._ @com.intel.analytics.bigdl.tags.Parallel -class ConvLSTMPeepholeSpec extends FlatSpec with BeforeAndAfter with Matchers { +class ConvLSTMPeephole2DSpec extends FlatSpec with BeforeAndAfter with Matchers { - "A convLstm" should " work in BatchMode" in { + "A convLstm2d" should " work in BatchMode" in { val hiddenSize = 5 val inputSize = 3 val seqLength = 4 @@ -37,13 +37,13 @@ class ConvLSTMPeepholeSpec extends FlatSpec with BeforeAndAfter with Matchers { val rec = Recurrent[Double]() val model = Sequential[Double]() .add(rec - .add(ConvLSTMPeephole[Double]( + .add(ConvLSTMPeephole2D[Double]( inputSize, hiddenSize, kernalW, kernalH, 1, withPeephole = false))) - .add(View(hiddenSize * kernalH * kernalW)) +// .add(View(hiddenSize * kernalH * kernalW)) val input = Tensor[Double](batchSize, seqLength, inputSize, kernalW, kernalH).rand val output = model.forward(input).toTensor[Double] @@ -53,7 +53,7 @@ class ConvLSTMPeepholeSpec extends FlatSpec with BeforeAndAfter with Matchers { } } - "A ConvLSTMPeepwhole " should "generate corrent output" in { + "A ConvLSTMPeepwhole2D " should "generate corrent output" in { val hiddenSize = 5 val inputSize = 3 val seqLength = 4 @@ -113,7 +113,7 @@ class ConvLSTMPeepholeSpec extends FlatSpec with BeforeAndAfter with Matchers { val rec = Recurrent[Double]() val model = Sequential[Double]() .add(rec - .add(ConvLSTMPeephole[Double](inputSize, hiddenSize, 3, 3, 1, withPeephole = false))) + .add(ConvLSTMPeephole2D[Double](inputSize, hiddenSize, 3, 3, 1, withPeephole = false))) val weightData = Array( 0.1323664, 0.11453647, 0.08062653, 0.12153825, 0.09627097, 0.09425588, -0.12831208, @@ -441,7 +441,7 @@ class ConvLSTMPeepholeSpec extends FlatSpec with BeforeAndAfter with Matchers { }) } - "A ConvLSTMPeepwhole " should "generate corrent output2 when batch != 1" in { + "A ConvLSTMPeepwhole2D " should "generate corrent output2 when batch != 1" in { val hiddenSize = 4 val inputSize = 2 val seqLength = 2 @@ -500,7 +500,7 @@ class ConvLSTMPeepholeSpec extends FlatSpec with BeforeAndAfter with Matchers { val rec = Recurrent[Double]() val model = Sequential[Double]() .add(rec - .add(ConvLSTMPeephole[Double](inputSize, hiddenSize, 3, 3, 1, withPeephole = false))) + .add(ConvLSTMPeephole2D[Double](inputSize, hiddenSize, 3, 3, 1, withPeephole = false))) val weightData = Array( -0.0697708, 0.187022, 0.08511595, 0.096392, 0.004365, -0.181258, 0.0446674, diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole3DSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole3DSpec.scala new file mode 100644 index 00000000000..84b1b8921ec --- /dev/null +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConvLSTMPeephole3DSpec.scala @@ -0,0 +1,53 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.bigdl.torch + +import com.intel.analytics.bigdl.nn._ +import com.intel.analytics.bigdl.utils._ +import com.intel.analytics.bigdl.tensor.{Storage, Tensor} +import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers} + +import scala.collection.mutable.ArrayBuffer +import scala.math._ + +@com.intel.analytics.bigdl.tags.Parallel +class ConvLSTMPeephole3DSpec extends FlatSpec with BeforeAndAfter with Matchers { + + "A ConvLSTMPeepwhole3D" should " work in BatchMode" in { + val hiddenSize = 5 + val inputSize = 3 + val seqLength = 4 + val batchSize = 2 + val kernalW = 3 + val kernalH = 3 + val rec = Recurrent[Double]() + val model = Sequential[Double]() + .add(rec + .add(ConvLSTMPeephole3D[Double]( + inputSize, + hiddenSize, + kernalW, kernalH, + 1, withPeephole = true))) + + val input = Tensor[Double](batchSize, seqLength, inputSize, 3, 3, 3).rand + + for (i <- 1 to 3) { + val output = model.forward(input) + model.backward(input, output) + } + } +}