diff --git a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc index b4e5d0ae58a416..ff3972f1ff28ea 100644 --- a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc @@ -168,60 +168,29 @@ struct TransposeUsingTile { } // namespace internal // Transpose kernel specialized for GPU Device. +#define HANDLE_DIM(DIM) \ + case DIM: \ + internal::TransposeUsingEigen(d, in, perm, conjugate, \ + out); \ + break + template struct Transpose { static void run(const GPUDevice& d, const Tensor& in, const gtl::ArraySlice perm, Tensor* out) { + if (in.dims() < 2) return; + if (internal::TransposeUsingTile::run(d, in, perm, out)) { + return; + } + switch (in.dims()) { - case 2: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; - case 3: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; - case 4: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; - case 5: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; - case 6: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; - case 7: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; - case 8: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; + HANDLE_DIM(2); + HANDLE_DIM(3); + HANDLE_DIM(4); + HANDLE_DIM(5); + HANDLE_DIM(6); + HANDLE_DIM(7); + HANDLE_DIM(8); default: internal::TransposeSimple(d, in, perm, out); break; @@ -229,6 +198,8 @@ struct Transpose { } }; +#undef HANDLE_DIM + template struct Transpose { static void run(const GPUDevice& d, const Tensor& in,