Skip to content

Commit

Permalink
Small cleanup of transpose_functor_gpu.cu.cc.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 236785253
  • Loading branch information
martinwicke authored and tensorflower-gardener committed Mar 5, 2019
1 parent 4484797 commit 98c4548
Showing 1 changed file with 20 additions and 49 deletions.
69 changes: 20 additions & 49 deletions tensorflow/core/kernels/transpose_functor_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,67 +168,38 @@ struct TransposeUsingTile<complex128, conjugate> {
} // namespace internal

// Transpose kernel specialized for GPU Device.
// Expands to a single `case DIM:` arm of a switch statement: it calls
// internal::TransposeUsingEigen<GPUDevice, T, DIM> with the arguments
// (d, in, perm, conjugate, out) and then breaks out of the switch.
// The expansion refers to T, d, in, perm, conjugate and out by name, so those
// must be in scope wherever the macro is used. The expansion ends in `break`
// without a semicolon; the use site supplies it.
#define HANDLE_DIM(DIM) \
case DIM: \
internal::TransposeUsingEigen<GPUDevice, T, DIM>(d, in, perm, conjugate, \
out); \
break

// Partial specialization of Transpose for GPUDevice, parameterized on the
// element type T and the compile-time `conjugate` flag. Its only member,
// run(), chooses which internal transpose helper to invoke.
//
// NOTE(review): this listing looks like a flattened diff whose +/- markers
// were lost. Case labels 2-8 appear twice inside the switch (once spelled
// out, once through HANDLE_DIM), which would be duplicate case labels if the
// text were compiled as shown. Confirm against the real file which set is
// current before relying on either.
template <typename T, bool conjugate>
struct Transpose<GPUDevice, T, conjugate> {
// Forwards d, in, perm and out to one of the internal transpose helpers,
// selected by the number of dimensions of `in`.
static void run(const GPUDevice& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out) {
// Fewer than two dimensions: return without calling any helper.
if (in.dims() < 2) return;
// Try the tile-based helper first; when it returns true nothing else runs.
// Presumably `true` means the tiled path performed the transpose -- confirm
// against TransposeUsingTile's definition.
if (internal::TransposeUsingTile<T, conjugate>::run(d, in, perm, out)) {
return;
}

// Otherwise select a helper by the number of dimensions of `in`.
switch (in.dims()) {
// NOTE(review): each spelled-out case below calls TransposeUsingTile again
// and only falls back to TransposeUsingEigen when it returns false. These
// arms look like the removed (pre-commit) side of the diff; verify.
case 2:
if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
out)) {
internal::TransposeUsingEigen<GPUDevice, T, 2>(d, in, perm, conjugate,
out);
}
break;
case 3:
if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
out)) {
internal::TransposeUsingEigen<GPUDevice, T, 3>(d, in, perm, conjugate,
out);
}
break;
case 4:
if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
out)) {
internal::TransposeUsingEigen<GPUDevice, T, 4>(d, in, perm, conjugate,
out);
}
break;
case 5:
if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
out)) {
internal::TransposeUsingEigen<GPUDevice, T, 5>(d, in, perm, conjugate,
out);
}
break;
case 6:
if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
out)) {
internal::TransposeUsingEigen<GPUDevice, T, 6>(d, in, perm, conjugate,
out);
}
break;
case 7:
if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
out)) {
internal::TransposeUsingEigen<GPUDevice, T, 7>(d, in, perm, conjugate,
out);
}
break;
case 8:
if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
out)) {
internal::TransposeUsingEigen<GPUDevice, T, 8>(d, in, perm, conjugate,
out);
}
break;
// Each HANDLE_DIM(N) expands to `case N:` calling
// internal::TransposeUsingEigen<GPUDevice, T, N> followed by `break`.
HANDLE_DIM(2);
HANDLE_DIM(3);
HANDLE_DIM(4);
HANDLE_DIM(5);
HANDLE_DIM(6);
HANDLE_DIM(7);
HANDLE_DIM(8);
default:
// Any dimension count not matched by a case above goes to TransposeSimple.
internal::TransposeSimple<T, conjugate>(d, in, perm, out);
break;
}
}
};

// Remove the HANDLE_DIM definition so it is not visible to the code that
// follows in this file.
#undef HANDLE_DIM

template <bool conjugate>
struct Transpose<GPUDevice, string, conjugate> {
static void run(const GPUDevice& d, const Tensor& in,
Expand Down

0 comments on commit 98c4548

Please sign in to comment.