Commit
Support non-contiguous tensors for unary ops (#6119)
cpuhrsch authored and apaszke committed Apr 27, 2018
1 parent a6bfa16 commit ae35e0e
Showing 13 changed files with 681 additions and 383 deletions.
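The user-visible effect of the change is that CPU unary ops no longer need contiguous inputs to stay on the fast native path: as the old IMPLEMENT_UNARY_OP macro in UnaryOps.cpp below shows, non-contiguous tensors previously fell back to at::_op_out, while this commit handles them natively through the apply utilities in CPUApplyUtils.h and drops the corresponding CPU entries from Declarations.cwrap. A minimal sketch of the call pattern in question, written against the present-day libtorch C++ API rather than the 2018 ATen calls and not taken from this commit:

#include <torch/torch.h>
#include <cassert>

int main() {
  torch::Tensor a = torch::rand({4, 8});
  torch::Tensor b = a.t();            // transposed view: same storage, swapped strides
  assert(!b.is_contiguous());         // elements are not laid out contiguously in memory
  b.ceil_();                          // in-place unary op on a non-contiguous tensor
  torch::Tensor c = torch::ceil(b);   // out-of-place variant, routed through the _out kernels
  return 0;
}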
599 changes: 340 additions & 259 deletions aten/src/ATen/CPUApplyUtils.h

Large diffs are not rendered by default.
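The CPUApplyUtils.h diff carries most of the changed lines but is not rendered here. Purely as an illustration of what such an apply utility does (this is not the ATen implementation, and every name below is hypothetical), the core job is to walk a tensor of arbitrary strides element by element and invoke a callback on each element; ATen's version additionally coalesces dimensions, sorts strides, and parallelizes over the outer loop.

#include <cstdint>
#include <functional>
#include <vector>

// Apply f to every element of a strided tensor described by (data, sizes, strides).
void strided_apply(float* data,
                   const std::vector<int64_t>& sizes,
                   const std::vector<int64_t>& strides,
                   const std::function<void(float&)>& f) {
  const size_t dims = sizes.size();
  if (dims == 0) { f(*data); return; }           // 0-dimensional (scalar) tensor
  for (int64_t s : sizes) if (s == 0) return;    // empty tensor: nothing to do
  std::vector<int64_t> idx(dims, 0);
  while (true) {
    // element offset = sum over dimensions of index[d] * stride[d]
    int64_t offset = 0;
    for (size_t d = 0; d < dims; ++d) offset += idx[d] * strides[d];
    f(data[offset]);
    // odometer-style increment of the multi-dimensional index
    size_t d = dims;
    while (true) {
      --d;
      if (++idx[d] < sizes[d]) break;
      idx[d] = 0;
      if (d == 0) return;                        // wrapped the outermost dimension: done
    }
  }
}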

10 changes: 0 additions & 10 deletions aten/src/ATen/Declarations.cwrap
@@ -1114,7 +1114,6 @@
- Int
- Short
backends:
- CPU
- CUDA
variants:
- method
@@ -1161,7 +1160,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@@ -1354,7 +1352,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@@ -1400,7 +1397,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@@ -1475,7 +1471,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@@ -1695,7 +1690,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@@ -1741,7 +1735,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@@ -1758,7 +1751,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@@ -1775,7 +1767,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
@@ -1792,7 +1783,6 @@
types:
- floating_point
backends:
- CPU
- CUDA
variants:
- method
2 changes: 1 addition & 1 deletion aten/src/ATen/Parallel.h
@@ -15,7 +15,7 @@ namespace internal {
// for a certain number of workers. If there are multiple threads making
// a request at the size of the maximum number of threads, they will
// be allocated a number proportional to the other requests.
void init_tbb_num_threads();
AT_API void init_tbb_num_threads();
// This parameter is heuristically chosen to determine the minimum amount of
// work that warrants parallelism. For example, when summing an array, it is
// deemed inefficient to parallelise over arrays shorter than 32768. Further,
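The truncated comment above describes a grain-size heuristic: below roughly 32768 elements, the cost of waking worker threads outweighs the benefit of splitting the work. A sketch of how such a threshold is typically applied; the constant and function names here are hypothetical, not ATen's:

#include <cstdint>
#include <numeric>

constexpr int64_t kMinWorkForParallelism = 32768;  // hypothetical name for the heuristic

double sum(const double* data, int64_t n) {
  if (n < kMinWorkForParallelism) {
    // Too little work to amortize thread start-up: run serially.
    return std::accumulate(data, data + n, 0.0);
  }
  // Otherwise split [0, n) into chunks of at least kMinWorkForParallelism
  // elements and hand them to the worker pool (partitioning omitted here,
  // so this branch also runs serially in the sketch).
  return std::accumulate(data, data + n, 0.0);
}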
2 changes: 1 addition & 1 deletion aten/src/ATen/TensorUtils.h
@@ -72,7 +72,7 @@ void checkDefined(CheckedFrom c, const TensorArg& t);
void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t);

// FixMe: does TensorArg slow things down?
void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> t, at::Backend backend);
AT_API void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> t, at::Backend backend);

// Methods for getting data_ptr if tensor is defined
void * maybe_data_ptr(const Tensor& tensor);
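Both small header hunks above do the same thing: they prefix an existing declaration with AT_API, ATen's export macro, presumably so the symbols can be referenced from outside the library by the code this commit adds. For context only, a hypothetical caller of the checkBackend overload shown here (CheckedFrom is just a string naming the calling function) could look like:

#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>

// Hypothetical example: verify that all inputs live on the CPU backend
// before running a CPU-only kernel.
void my_cpu_op(const at::Tensor& result, const at::Tensor& self) {
  at::checkBackend("my_cpu_op", {result, self}, at::Backend::CPU);
  // ... kernel body ...
}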
130 changes: 94 additions & 36 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -3,48 +3,106 @@
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtils.h"
#include "cpu/UnaryOpsKernel.h"

#include "ATen/CPUApplyUtils.h"
#include "ATen/Parallel.h"
#include "ATen/native/cpu/UnaryOpsKernel.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <vector>

#include <map>

namespace at { namespace native {

#define IMPLEMENT_UNARY_OP(op) \
Tensor op(const Tensor& self) { \
Tensor result = self.type().tensor(); \
return at::op ## _out(result, self); \
} \
Tensor& op##_(Tensor& self) { \
return at::op ## _out(self, self); \
} \
Tensor& _ ## op ## _out_cuda(Tensor& result, const Tensor& self) { \
return at::_ ## op ## _out(result, self); \
} \
Tensor& _ ## op ## _out_cpu(Tensor& result, const Tensor& self) { \
if (result.is_contiguous() && self.is_contiguous()) { \
result.resize_(self.sizes()); \
if (result.numel() > 0) { \
op ## Impl(result, self); \
} \
return result; \
} \
return at::_ ## op ## _out(result, self); \
}
namespace at {
namespace native {

#define IMPLEMENT_UNARY_OP_PREQUEL(op) \
Tensor op(const Tensor& self) { \
Tensor result = self.type().tensor(); \
return at::op##_out(result, self); \
} \
Tensor& _##op##__cuda(Tensor& self) { \
return at::_##op##_out(self, self); \
} \
Tensor& _##op##_out_cuda(Tensor& result, const Tensor& self) { \
return at::_##op##_out(result, self); \
}

#define IMPLEMENT_UNARY_OP_FLOAT_CMATH(op) \
Tensor& _##op##__cpu(Tensor& self_) { \
if (self_.numel() > 0) { \
Tensor self = sort_strides(self_); \
AT_DISPATCH_FLOATING_TYPES(self.type(), op, [&] { \
CPU_tensor_parallel_apply1<scalar_t>( \
self, [](scalar_t& y) { y = std::op(y); }); \
}); \
} \
return self_; \
} \
Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
result.resize_(self.sizes()); \
if (result.numel() > 0) { \
AT_DISPATCH_FLOATING_TYPES(self.type(), op, [&] { \
CPU_tensor_parallel_apply2<scalar_t, scalar_t>( \
result, self, [](scalar_t& y, scalar_t& x) { y = std::op(x); }); \
}); \
} \
return result; \
}

IMPLEMENT_UNARY_OP(abs)
IMPLEMENT_UNARY_OP(ceil)
IMPLEMENT_UNARY_OP(cos)
IMPLEMENT_UNARY_OP(exp)
IMPLEMENT_UNARY_OP(floor)
IMPLEMENT_UNARY_OP(log)
IMPLEMENT_UNARY_OP(round)
IMPLEMENT_UNARY_OP(sin)
IMPLEMENT_UNARY_OP(sqrt)
IMPLEMENT_UNARY_OP(trunc)

}} // namespace at::native
#define IMPLEMENT_UNARY_OP_VEC(op) \
Tensor& _##op##__cpu(Tensor& self_) { \
if (self_.numel() > 0) { \
Tensor self = sort_strides(self_); \
if (self.is_contiguous()) { \
op##Impl(self, self); \
} else { \
AT_DISPATCH_FLOATING_TYPES(self.type(), op, [&] { \
CPU_tensor_parallel_apply1<scalar_t>( \
self, [](scalar_t& y) { y = std::op(y); }); \
}); \
} \
} \
return self_; \
} \
Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
result.resize_(self.sizes()); \
if (result.numel() > 0) { \
if (result.is_contiguous() && self.is_contiguous()) { \
op##Impl(result, self); \
} else { \
AT_DISPATCH_FLOATING_TYPES(self.type(), op, [&] { \
CPU_tensor_parallel_apply2<scalar_t, scalar_t>( \
result, self, [](scalar_t& y, scalar_t& x) { y = std::op(x); }); \
}); \
} \
} \
return result; \
}

IMPLEMENT_UNARY_OP_PREQUEL(abs)
IMPLEMENT_UNARY_OP_PREQUEL(ceil)
IMPLEMENT_UNARY_OP_PREQUEL(cos)
IMPLEMENT_UNARY_OP_PREQUEL(exp)
IMPLEMENT_UNARY_OP_PREQUEL(floor)
IMPLEMENT_UNARY_OP_PREQUEL(log)
IMPLEMENT_UNARY_OP_PREQUEL(round)
IMPLEMENT_UNARY_OP_PREQUEL(sin)
IMPLEMENT_UNARY_OP_PREQUEL(sqrt)
IMPLEMENT_UNARY_OP_PREQUEL(trunc)

IMPLEMENT_UNARY_OP_VEC(abs)
IMPLEMENT_UNARY_OP_VEC(ceil)
IMPLEMENT_UNARY_OP_FLOAT_CMATH(cos)
IMPLEMENT_UNARY_OP_FLOAT_CMATH(exp)
IMPLEMENT_UNARY_OP_VEC(floor)
IMPLEMENT_UNARY_OP_FLOAT_CMATH(log)
IMPLEMENT_UNARY_OP_VEC(round)
IMPLEMENT_UNARY_OP_FLOAT_CMATH(sin)
IMPLEMENT_UNARY_OP_VEC(sqrt)
IMPLEMENT_UNARY_OP_VEC(trunc)
}
} // namespace at
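To make the macros above concrete, IMPLEMENT_UNARY_OP_VEC(ceil) expands (modulo whitespace and the comments added here) to the two functions below: contiguous tensors go through the vectorized ceilImpl kernel registered in UnaryOpsKernel.cpp, and every other layout goes through the strided CPU_tensor_parallel_apply helpers from CPUApplyUtils.h.

Tensor& _ceil__cpu(Tensor& self_) {
  if (self_.numel() > 0) {
    Tensor self = sort_strides(self_);
    if (self.is_contiguous()) {
      ceilImpl(self, self);                      // vectorized path
    } else {
      AT_DISPATCH_FLOATING_TYPES(self.type(), ceil, [&] {
        CPU_tensor_parallel_apply1<scalar_t>(    // strided (non-contiguous) path
            self, [](scalar_t& y) { y = std::ceil(y); });
      });
    }
  }
  return self_;
}
Tensor& _ceil_out_cpu(Tensor& result, const Tensor& self) {
  result.resize_(self.sizes());
  if (result.numel() > 0) {
    if (result.is_contiguous() && self.is_contiguous()) {
      ceilImpl(result, self);                    // vectorized path
    } else {
      AT_DISPATCH_FLOATING_TYPES(self.type(), ceil, [&] {
        CPU_tensor_parallel_apply2<scalar_t, scalar_t>(
            result, self, [](scalar_t& y, scalar_t& x) { y = std::ceil(x); });
      });
    }
  }
  return result;
}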
93 changes: 30 additions & 63 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -7,12 +7,14 @@
#include "ATen/cpu/vec256/vec256.h"
#include "ATen/native/cpu/CapabilityDispatch.h"

namespace at { namespace native { namespace {
namespace at { namespace native {
namespace {

using namespace vec256;

template <typename scalar_t, typename F>
static void unary_kernel(scalar_t* arr_out, const scalar_t* arr_in, int64_t size, F func) {
static void
unary_kernel(scalar_t* arr_out, const scalar_t* arr_in, int64_t size, F func) {
using Vec = Vec256<scalar_t>;
int64_t size_rounded = size - (size % Vec::size);
int64_t k = 0;
@@ -52,94 +54,59 @@ static void parallel_apply(Tensor& result, const Tensor& self, F f) {

static void abs_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_ALL_TYPES(self.type(), "abs", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.abs();
});
});
parallel_apply<scalar_t>(
result,
self,
[](const Vec256<scalar_t>& x) { return x.abs(); }); });
}

static void ceil_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "ceil", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.ceil();
});
});
}

static void cos_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "cos", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.cos();
});
});
}

static void exp_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "exp", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.exp();
});
});
parallel_apply<scalar_t>(
result,
self,
[](const Vec256<scalar_t>& x) { return x.ceil(); }); });
}

static void floor_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "floor", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.floor();
});
});
}

static void log_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "log", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.log();
});
});
parallel_apply<scalar_t>(
result,
self,
[](const Vec256<scalar_t>& x) { return x.floor(); }); });
}

static void round_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "round", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.round();
});
});
}

static void sin_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "sin", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.sin();
});
});
parallel_apply<scalar_t>(
result,
self,
[](const Vec256<scalar_t>& x) { return x.round(); }); });
}

static void sqrt_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "sqrt", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.sqrt();
});
});
parallel_apply<scalar_t>(
result,
self,
[](const Vec256<scalar_t>& x) { return x.sqrt(); }); });
}

static void trunc_kernel(Tensor& result, const Tensor& self) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "trunc", [&] {
parallel_apply<scalar_t>(result, self, [](const Vec256<scalar_t>& x) {
return x.trunc();
});
});
parallel_apply<scalar_t>(
result,
self,
[](const Vec256<scalar_t>& x) { return x.trunc(); }); });
}

} // anonymous namespace
} // anonymous namespace


REGISTER_DISPATCH(absImpl, &abs_kernel);
REGISTER_DISPATCH(ceilImpl, &ceil_kernel);
REGISTER_DISPATCH(cosImpl, &cos_kernel);
REGISTER_DISPATCH(expImpl, &exp_kernel);
REGISTER_DISPATCH(floorImpl, &floor_kernel);
REGISTER_DISPATCH(logImpl, &log_kernel);
REGISTER_DISPATCH(roundImpl, &round_kernel);
REGISTER_DISPATCH(sinImpl, &sin_kernel);
REGISTER_DISPATCH(sqrtImpl, &sqrt_kernel);
REGISTER_DISPATCH(truncImpl, &trunc_kernel);

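The unary_kernel template near the top of this file is cut off by the hunk boundary. The pattern it follows, handling the bulk of the array in Vec256-sized chunks and finishing the remainder with a scalar tail loop, can be sketched without the Vec256 machinery as follows; WIDTH stands in for Vec::size, std::sqrt stands in for func, and none of this is the actual kernel:

#include <cstdint>
#include <cmath>

static void unary_kernel_sketch(float* out, const float* in, int64_t size) {
  constexpr int64_t WIDTH = 8;                   // plays the role of Vec256<float>::size
  int64_t size_rounded = size - (size % WIDTH);  // largest multiple of WIDTH, as in the diff
  int64_t k = 0;
  for (; k < size_rounded; k += WIDTH) {
    // In the real kernel this is one Vec256 load, one vectorized func, one store.
    for (int64_t j = 0; j < WIDTH; ++j) {
      out[k + j] = std::sqrt(in[k + j]);
    }
  }
  for (; k < size; ++k) {                        // scalar tail for the leftover elements
    out[k] = std::sqrt(in[k]);
  }
}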
(The diffs for the remaining seven changed files are not rendered in this view.)
