Accept non-standard bools in more CUDA kernels

This fixes all remaining CUDA kernels, except those using `cub` or
`thrust`, to accept boolean tensors with values other than 1 or 0.

I do this by using `c10::load` in more places, and by adding a
`load_vector` helper to `MemoryAccess.cuh` that does the same thing
for vectorized loads.
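
For background, a minimal host-side sketch of the failure mode the
byte-wise load guards against; `load_bool` is an illustrative stand-in
for the real `c10::load<bool>` from `c10/util/Load.h`:

    #include <cstdio>
    #include <cstring>

    // Simplified stand-in for c10::load<bool>: read the raw byte, then
    // normalize, so any non-zero byte becomes a well-formed `true`.
    static bool load_bool(const void* src) {
      return *reinterpret_cast<const unsigned char*>(src) != 0;
    }

    int main() {
      unsigned char raw = 0x02;          // a bool byte that is neither 0 nor 1
      bool b;
      std::memcpy(&b, &raw, sizeof(b));  // bool object with a non-standard value

      // Reading `b` directly is undefined behaviour; comparisons such as
      // (b == true) may evaluate to false even though the byte is non-zero.
      // Loading through the byte gives a well-defined result:
      std::printf("%d\n", load_bool(&b) ? 1 : 0);  // prints 1
      return 0;
    }

The new `load_vector` overload for `bool` in `MemoryAccess.cuh` applies
the same normalization element-wise: it loads the bytes as `uint8_t`
and converts each one back to `bool`.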

Pull Request resolved: pytorch#78957

Approved by: https://github.com/mruberry
peterbell10 authored and pytorchmergebot committed Jun 9, 2022
1 parent 4945c72 commit cd9e158
Showing 8 changed files with 116 additions and 118 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -162,7 +162,7 @@ invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t stri
std::index_sequence<INDEX...>) {
(void)strides;
(void)i;
return f(*(typename traits::template arg<INDEX>::type*)(data[INDEX] + i * strides[INDEX])...);
return f(c10::load<typename traits::template arg<INDEX>::type>(data[INDEX] + i * strides[INDEX])...);
}

template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
24 changes: 20 additions & 4 deletions aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -96,7 +96,7 @@ struct multi_outputs_store_helper {
struct LoadWithoutCast {
template<typename scalar_t>
__device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
return *(reinterpret_cast<scalar_t *>(base_ptr) + offset);
return c10::load(reinterpret_cast<scalar_t *>(base_ptr) + offset);
}
};

@@ -161,6 +161,24 @@ struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
scalar_t val[vec_size];
};

template <int vec_size, typename scalar_t>
__device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
using vec_t = aligned_vector<scalar_t, vec_size>;
auto *from = reinterpret_cast<const vec_t *>(base_ptr);
return from[offset];
}

template <int vec_size>
__device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
// See NOTE [Loading boolean values]
auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
aligned_vector<bool, vec_size> ret;
for (int i = 0; i < vec_size; ++i) {
ret.val[i] = bool(tmp.val[i]);
}
return ret;
}

namespace policies {

// Assumption:
@@ -236,13 +254,11 @@ struct vectorized {

template<typename accessor_t, typename scalar_t>
__device__ inline void load_single_arg(accessor_t to, scalar_t *from) {
using vec_t = aligned_vector<scalar_t, vec_size>;
vec_t *from_ = reinterpret_cast<vec_t *>(from);
int thread_idx = threadIdx.x;
#pragma unroll
for (int i = 0; i < loop_size; i++) {
int index = thread_idx + i * num_threads();
vec_t v = from_[index];
auto v = load_vector<vec_size>(from, index);
#pragma unroll
for (int j = 0; j < vec_size; j++) {
to(vec_size * i + j) = v.val[j];
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/ROCmLoops.cuh
@@ -134,7 +134,7 @@ template <typename traits, typename func_t, typename index_t, size_t... INDEX>
C10_HOST_DEVICE typename traits::result_type
invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i,
std::index_sequence<INDEX...>) {
return f(*(typename traits::template arg<INDEX>::type*)(data[INDEX] + i * strides[INDEX])...);
return f(c10::load<typename traits::template arg<INDEX>::type>(data[INDEX] + i * strides[INDEX])...);
}

template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
@@ -257,7 +257,7 @@ __global__ void elementwise_kernel(int N, func_t f, array_t data) {
if (idx + num_threads() * i < N) {
#pragma unroll
for (int j = 0; j < arity; j++) {
args[i][j] = *(args_base[j] + i * num_threads());
args[i][j] = c10::load(args_base[j] + i * num_threads());
}
}
}
20 changes: 9 additions & 11 deletions aten/src/ATen/native/cuda/Reduce.cuh
@@ -509,7 +509,7 @@ struct ReduceOp {
data -= shift;
end += shift;
if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){
value = ops.reduce(value, data[threadIdx.x], threadIdx.x - shift);
value = ops.reduce(value, c10::load(data + threadIdx.x), threadIdx.x - shift);
}
end -= align_elements;
data += align_elements;
@@ -531,15 +531,11 @@
value_list[i] = ident;
}

scalar_t values[input_vec_size];

load_t *values_vector = reinterpret_cast<load_t*>(&values[0]);

while (idx * input_vec_size + input_vec_size - 1 < end) {
*values_vector = reinterpret_cast<const load_t*>(data)[idx];
const auto values_vec = memory::load_vector<input_vec_size>(data, idx);
#pragma unroll
for (index_t i = 0; i < input_vec_size; i++) {
value_list[i] = ops.reduce(value_list[i], values[i], shift + idx * input_vec_size + i);
value_list[i] = ops.reduce(value_list[i], values_vec.val[i], shift + idx * input_vec_size + i);
}
idx += stride;
}
@@ -549,7 +545,8 @@
if (config.should_reduce_tail()) {
int idx = tail_start + threadIdx.x;
if (idx < end) {
value_list[0] = ops.reduce(value_list[0], data[idx], idx + shift);
const auto value = c10::load(data + idx);
value_list[0] = ops.reduce(value_list[0], value, idx + shift);
}
}

@@ -569,7 +566,6 @@

using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
using load_t = at::native::memory::aligned_vector<scalar_t, output_vec_size>;
const load_t* data = reinterpret_cast<const load_t*>(data_);

// Multiple accumulators to remove dependency between unrolled loops.
arg_vec_t value_list[vt0];
@@ -587,7 +583,8 @@
while (idx + (vt0 - 1) * stride < end) {
#pragma unroll
for (index_t i = 0; i < vt0; i++) {
values[i] = data[calc(idx + i * stride) / output_vec_size];
const auto offset = calc(idx + i * stride) / output_vec_size;
values[i] = memory::load_vector<output_vec_size>(data_, offset);
}
#pragma unroll
for (index_t i = 0; i < vt0; i++) {
@@ -606,7 +603,8 @@
if (idx >= end) {
break;
}
values[i] = data[calc(idx) / output_vec_size];
const auto offset = calc(idx) / output_vec_size;
values[i] = memory::load_vector<output_vec_size>(data_, offset);
idx += stride;
}
idx = idx_;
12 changes: 7 additions & 5 deletions aten/src/ATen/native/cuda/ScanKernels.cu
@@ -6,6 +6,7 @@
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <c10/util/accumulate.h>
#include <c10/util/Load.h>

#include <ATen/cuda/cub.cuh>

@@ -61,15 +62,15 @@ __global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *se
int col2 = block_col + num_threads_x + threadIdx.x;
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = row_self[col1];
row_buf[threadIdx.x] = c10::load(&row_self[col1]);
row_idx_buf[threadIdx.x] = col1;
} else {
row_buf[threadIdx.x] = init;
// No need to set the index here as the value in init will never be selected
}

if (col2 < row_size) {
row_buf[num_threads_x + threadIdx.x] = row_self[col2];
row_buf[num_threads_x + threadIdx.x] = c10::load(&row_self[col2]);
row_idx_buf[num_threads_x + threadIdx.x] = col2;
} else {
row_buf[num_threads_x + threadIdx.x] = init;
@@ -142,8 +143,9 @@ __global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scala
int64_t out_idx = 0;

for (auto col = decltype(row_size){0}; col < row_size; ++col) {
if(at::_isnan(*self) || (!at::_isnan(out) && binary_op(*self, out))) {
out = *self;
const auto val = c10::load(self);
if(at::_isnan(val) || (!at::_isnan(out) && binary_op(val, out))) {
out = val;
out_idx = col;
}
*values = out;
@@ -267,7 +269,7 @@ __global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_,
scalar_t acc = init;

for (uint32_t col = 0; col < row_size; ++col) {
acc = binary_op(acc, *src);
acc = binary_op(acc, c10::load(src));
*tgt = acc;

src += num_irows;
95 changes: 65 additions & 30 deletions aten/src/ATen/native/cuda/jit_utils.cpp
@@ -209,6 +209,38 @@ struct alignas(2) BFloat16 {
}
)ESCAPE";

// From c10/util/Load.h
const std::string load_support_literal = R"ESCAPE(
namespace c10 {
template <typename T>
struct LoadImpl {
__device__ static T apply(const void *src) {
return *reinterpret_cast<const T*>(src);
}
};
template <>
struct LoadImpl<bool> {
__device__ static bool apply(const void *src) {
static_assert(sizeof(bool) == sizeof(char), "");
return LoadImpl<char>::apply(src);
}
};
template <typename T>
__device__ T load(const void *src) {
return LoadImpl<T>::apply(src);
}
template <typename scalar_t>
__device__ scalar_t load(const scalar_t *src) {
return LoadImpl<scalar_t>::apply(src);
}
} // namespace c10
)ESCAPE";

// copy-pasted from c10/util/TypeCast.h and c10/core/DynamicCast.h
const std::string dynamic_cast_support_literal = R"ESCAPE(
@@ -280,30 +312,10 @@
}
};
template <typename T>
struct LoadImpl {
__device__ static T apply(const void *src) {
return *reinterpret_cast<const T*>(src);
}
};
template <>
struct LoadImpl<bool> {
__device__ static bool apply(const void *src) {
static_assert(sizeof(bool) == sizeof(char), "");
return LoadImpl<char>::apply(src);
}
};
template <typename T>
__device__ T load(const void *src) {
return LoadImpl<T>::apply(src);
}
// Fetch a value with dynamic type src_type from ptr, and cast it to static type dest_t.
#define FETCH_AND_CAST_CASE(type, scalartype) \
case ScalarType::scalartype: \
return static_cast_with_inter_type<dest_t, type>::apply(load<type>(ptr));
return static_cast_with_inter_type<dest_t, type>::apply(c10::load<type>(ptr));
template<typename dest_t>
__device__ inline dest_t fetch_and_cast(const ScalarType src_type, const void *ptr) {
switch (src_type) {
@@ -364,7 +376,7 @@ const std::string no_dynamic_cast_support_literal = R"ESCAPE(
struct LoadWithoutCast {
template <typename scalar_t>
__device__ scalar_t load(char* base_ptr, uint32_t offset, int arg=0) {
return *(reinterpret_cast<scalar_t*>(base_ptr) + offset);
return c10::load(reinterpret_cast<scalar_t*>(base_ptr) + offset);
}
};
@@ -466,6 +478,7 @@ const std::string offset_calc_template = R"ESCAPE(

const std::string jit_code_template = R"ESCAPE(
${load_support}
${dynamic_casting_string}
@@ -529,9 +542,11 @@ const std::string jit_code_template = R"ESCAPE(

const std::string jit_vectorized_code_template = R"ESCAPE(
${load_support}
template <typename scalar_t>
__device__ __inline__ scalar_t load(char* base_ptr, uint32_t offset) {
return *(reinterpret_cast<scalar_t*>(base_ptr) + offset);
return c10::load(reinterpret_cast<scalar_t*>(base_ptr) + offset);
}
template<typename scalar_t>
@@ -545,6 +560,24 @@ const std::string jit_vectorized_code_template = R"ESCAPE(
scalar_t val[vec_size];
};
template <int vec_size, typename scalar_t>
__device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
using vec_t = aligned_vector<scalar_t, vec_size>;
auto *from = reinterpret_cast<const vec_t *>(base_ptr);
return from[offset];
}
template <int vec_size>
__device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
// See NOTE [Loading boolean values]
auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
aligned_vector<bool, vec_size> ret;
for (int i = 0; i < vec_size; ++i) {
ret.val[i] = bool(tmp.val[i]);
}
return ret;
}
${functor}
// TODO: setup grid-stride loop
@@ -556,6 +589,7 @@ const std::string jit_vectorized_code_template = R"ESCAPE(
${compute_type} scalar_val${extra_params}) //[${nInputs}+${nOutputs}],
{
constexpr int vec_size = ${vec_size};
using scalar_t = ${scalar_type};
int remaining = N - block_work_size * blockIdx.x;
auto thread_idx = threadIdx.x;
int idx = blockIdx.x;
@@ -591,11 +625,9 @@ const std::string jit_vectorized_code_template = R"ESCAPE(
} else {
static constexpr int loop_size = thread_work_size / vec_size;
//actual loading
using vec_t_input = aligned_vector<${scalar_type}, vec_size>;
${vector_inputs}
#pragma unroll
for (int i = 0; i<loop_size; i++){
vec_t_input v;
${load_vectorized_inputs}
thread_idx += num_threads;
}
@@ -846,6 +878,8 @@ std::string generate_code(
env.s("complex_half_body_string", "");
}

env.s("load_support", load_support_literal);

if (!vectorized) {
if (!dynamic_casting) {
env.s("loader", "LoadWithoutCast");
@@ -895,9 +929,9 @@ std::string generate_code(
std::stringstream vector_inputs;
for (const auto i : c10::irange(nInputs)){
auto i_string = std::to_string(i);
vector_inputs << "vec_t_input * vec" << i_string <<
" = reinterpret_cast<vec_t_input *>(data[" << i_string << "+" << nOutputs << "])" <<
" + block_work_size / vec_size * idx;\n";
vector_inputs << "auto * input" << i_string <<
" = reinterpret_cast<const scalar_t*>(data[" << i_string << "+" << nOutputs << "])" <<
" + block_work_size * idx;\n";
}
env.s("vector_inputs", vector_inputs.str());

@@ -913,10 +947,11 @@ std::string generate_code(
std::stringstream load_vectorized_inputs;
for (const auto i : c10::irange(nInputs)) {
auto i_string = std::to_string(i);
load_vectorized_inputs << "v = vec" << i_string << "[thread_idx];\n";
load_vectorized_inputs << "const auto vec" << i_string << " = load_vector<vec_size>("
<< "input" << i_string << ", thread_idx);\n";
load_vectorized_inputs << "#pragma unroll\n";
load_vectorized_inputs << "for (int j=0; j < vec_size; j++){\n";
load_vectorized_inputs << " arg" << i_string << "[vec_size * i + j] = v.val[j];\n";
load_vectorized_inputs << " arg" << i_string << "[vec_size * i + j] = vec" << i_string << ".val[j];\n";
load_vectorized_inputs << "}\n";
}
env.s("load_vectorized_inputs", load_vectorized_inputs.str());
6 changes: 6 additions & 0 deletions c10/util/Load.h
@@ -16,6 +16,7 @@ template <>
struct LoadImpl<bool> {
C10_HOST_DEVICE static bool apply(const void* src) {
static_assert(sizeof(bool) == sizeof(char), "");
// NOTE: [Loading boolean values]
// Protect against invalid boolean values by loading as a byte
// first, then converting to bool (see gh-54789).
return *reinterpret_cast<const unsigned char*>(src);
@@ -29,4 +30,9 @@ C10_HOST_DEVICE T load(const void* src) {
return c10::detail::LoadImpl<T>::apply(src);
}

template <typename scalar_t>
C10_HOST_DEVICE scalar_t load(const scalar_t* src) {
return c10::detail::LoadImpl<scalar_t>::apply(src);
}

} // namespace c10