diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index 4c172be6f0d..e931c6995e0 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1874,16 +1874,17 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): return tf.pad(x, pattern) -def stack(x): +def stack(x, axis=0): """Stacks a list of rank `R` tensors into a rank `R+1` tensor. # Arguments - x: Tensor or variable. + x: List of tensors. + axis: Axis along which to perform stacking. # Returns A tensor. """ - return tf.stack(x) + return tf.stack(x, axis=axis) def one_hot(indices, num_classes): diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 75c6488e5da..72099a906b8 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -985,8 +985,8 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): return T.set_subtensor(output[indices], x) -def stack(x): - return T.stack(*x) +def stack(x, axis=0): + return T.stack(x, axis=axis) def one_hot(indices, num_classes):