add Half support for maxpool on CPU (pytorch#98819)
### Testing
Single socket (28 cores):

shape | fp32 forward / ms | fp16 forward / ms | bf16 forward / ms | fp32 backward / ms | fp16 backward / ms | bf16 backward / ms
-- | -- | -- | -- | -- | -- | --
size: (1, 56, 264, 264), kernel: 3,   stride: 1, mem_format: contig | 4.12895 | 6.9669 | 5.30297 | 0.55775 | 1.98917 | 0.72233
size: (1, 56, 264, 264), kernel: 3,   stride: 1, mem_format: CL | 0.85093 | 1.88813 | 1.38063 | 5.5742 | 36.5086 | 10.58552
size: (32, 16, 200, 200), kernel: 3,   stride: 1, mem_format: contig | 22.37212 | 37.90383 | 30.94482 | 6.85868 | 10.6116 | 3.9993
size: (32, 16, 200, 200), kernel: 3,   stride: 1, mem_format: CL | 5.41658 | 4.71098 | 4.66578 | 6.69875 | 14.7171 | 5.1167
size: (32, 32, 100, 100), kernel: 3,   stride: 1, mem_format: contig | 10.69831 | 18.0468 | 13.71657 | 2.61192 | 4.96172 | 1.68635
size: (32, 32, 100, 100), kernel: 3,   stride: 1, mem_format: CL | 2.52637 | 2.0096 | 2.0055 | 2.60314 | 7.2093 | 2.49843
size: (4, 19, 10, 16, 16), kernel: 3,   stride: 1, mem_format: contig | 0.47605 | 0.88398 | 0.65326 | 0.06525 | 0.115489 | 0.0674
size: (4, 19, 10, 16, 16), kernel: 3,   stride: 1, mem_format: CL3d | 0.10902 | 0.25293 | 0.157475 | 0.11386 | 0.53319 | 0.17836

Single core:

shape | fp32 forward / ms | fp16 forward / ms | bf16 forward / ms | fp32 backward / ms | fp16 backward / ms | bf16 backward / ms
-- | -- | -- | -- | -- | -- | --
size: (1, 56, 264, 264), kernel: 3,   stride: 1, mem_format: contig | 90.9809 | 163.473 | 126.1276 | 6.57721 | 41.40833 | 11.82505
size: (1, 56, 264, 264), kernel: 3,   stride: 1, mem_format: CL | 9.88405 | 38.39137 | 29.62069 | 7.10636 | 36.97535 | 11.0525
size: (32, 16, 200, 200), kernel: 3,   stride: 1, mem_format: contig | 476.782 | 855.4769 | 648.2248 | 46.6488 | 219.2586 | 67.10599
size: (32, 16, 200, 200), kernel: 3,   stride: 1, mem_format: CL | 80.29271 | 91.33854 | 87.80345 | 48.81692 | 203.9974 | 63.39004
size: (32, 32, 100, 100), kernel: 3,   stride: 1, mem_format: contig | 235.2113 | 419.0799 | 315.4284 | 20.6049 | 107.1524 | 32.39169
size: (32, 32, 100, 100), kernel: 3,   stride: 1, mem_format: CL | 29.47653 | 33.54905 | 32.82823 | 22.59674 | 98.5586 | 30.05763
size: (4, 19, 10, 16, 16), kernel: 3,   stride: 1, mem_format: contig | 7.90684 | 13.9208 | 10.03272 | 0.23725 | 1.35269 | 0.41728
size: (4, 19, 10, 16, 16), kernel: 3,   stride: 1, mem_format: CL3d | 2.33638 | 3.36894 | 2.64635 | 0.26535 | 1.244 | 0.38895
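
The exact benchmark script is not included in the commit message; the sketch below is a hypothetical minimal reproduction of one row of the tables (the `bench` helper, iteration count, and printed format are illustrative, not taken from the PR). It times `max_pool2d` forward and backward on CPU for a given dtype and memory format; for the single-core rows, pin the thread count with `torch.set_num_threads(1)`.

```python
# Hypothetical reproduction sketch (assumed; not the script used for the tables above).
# Times max_pool2d forward and backward on CPU for one shape/dtype/memory-format combination.
import time
import torch

def bench(dtype, shape=(1, 56, 264, 264), kernel=3, stride=1,
          memory_format=torch.contiguous_format, iters=100):
    # Leaf tensor in the requested precision and memory format.
    x = (torch.randn(shape, dtype=dtype)
             .to(memory_format=memory_format)
             .requires_grad_(True))
    pool = torch.nn.MaxPool2d(kernel_size=kernel, stride=stride)

    # Warm-up; keep the graph so backward can be replayed in the timing loop.
    out = pool(x)
    grad = torch.ones_like(out)
    out.backward(grad, retain_graph=True)

    # Forward timing.
    t0 = time.perf_counter()
    for _ in range(iters):
        pool(x)
    fwd_ms = (time.perf_counter() - t0) / iters * 1e3

    # Backward timing (re-uses the retained graph).
    t0 = time.perf_counter()
    for _ in range(iters):
        x.grad = None
        out.backward(grad, retain_graph=True)
    bwd_ms = (time.perf_counter() - t0) / iters * 1e3
    return fwd_ms, bwd_ms

if __name__ == "__main__":
    # torch.set_num_threads(1)  # uncomment to reproduce the single-core configuration
    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        fwd, bwd = bench(dtype, memory_format=torch.channels_last)
        print(f"{dtype}: forward {fwd:.5f} ms, backward {bwd:.5f} ms")
```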

Pull Request resolved: pytorch#98819
Approved by: https://github.com/mingfeima, https://github.com/mikaylagawarecki
CaoE authored and pytorchmergebot committed Sep 5, 2023
1 parent 1e0e55c commit 42f94d7
Showing 5 changed files with 110 additions and 99 deletions.
107 changes: 57 additions & 50 deletions aten/src/ATen/native/cpu/MaxPoolKernel.cpp
@@ -5,6 +5,7 @@
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/Pool.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
@@ -60,13 +61,15 @@ vec::Vectorized<int64_t> is_nan_vec<int64_t>(vec::Vectorized<int64_t> vec) {
return ret;
}

template <typename scalar_t, typename accscalar_t>
inline void compute_internal(
template <typename scalar_t, typename opmath_t>
inline
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
compute_internal(
scalar_t* input_data,
scalar_t* out_data,
accscalar_t* max_ptr,
vec::int_same_size_t<accscalar_t>* index_ptr,
int64_t* ind,
opmath_t* max_ptr,
vec::int_same_size_t<opmath_t>* index_ptr,
int64_t* ind,
int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
int64_t n,
int64_t len,
@@ -78,7 +81,7 @@ inline void compute_internal(
int64_t dilationH,
int64_t dilationW) {
using Vec = vec::Vectorized<scalar_t>;
using integer_t = vec::int_same_size_t<accscalar_t>;
using integer_t = vec::int_same_size_t<opmath_t>;
using iVec = vec::Vectorized<integer_t>;
// Pass I: init out lane
iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
@@ -130,13 +133,16 @@ inline void compute_internal(
}
}

template <>
inline void compute_internal(
BFloat16* input_data,
BFloat16* out_data,
float* max_ptr,
int32_t* index_ptr,
int64_t* ind,
// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
template <typename scalar_t, typename opmath_t>
inline
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
compute_internal(
scalar_t* input_data,
scalar_t* out_data,
opmath_t* max_ptr,
vec::int_same_size_t<opmath_t>* index_ptr,
int64_t* ind,
int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
int64_t n,
int64_t len,
@@ -147,34 +153,34 @@ inline void compute_internal(
int64_t dilationD,
int64_t dilationH,
int64_t dilationW) {
using bVec = vec::Vectorized<BFloat16>;
using fVec = vec::Vectorized<float>;
using Vec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<opmath_t>;
using iVec = vec::Vectorized<int32_t>;
// Pass I: init out lane
iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
fVec out_vec = fVec(-std::numeric_limits<float>::infinity());
fVec out_vec = fVec(-std::numeric_limits<opmath_t>::infinity());
int64_t d1 = 0;
for (; d1 < len; d1 += fVec::size()) {
index0_vec.store(index_ptr + d1);
out_vec.store(max_ptr + d1);
}
for (; d1 < size; d1++) {
ind[d1] = ih0 * input_width + iw0;
max_ptr[d1] = -std::numeric_limits<float>::infinity();
max_ptr[d1] = -std::numeric_limits<opmath_t>::infinity();
}
// Pass II: compute local max
for (int64_t id = id0; id < id1; id += dilationD) {
for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
BFloat16* in = input_data + (n * input_depth * input_height * input_width +
scalar_t* in = input_data + (n * input_depth * input_height * input_width +
id * input_height * input_width + ih * input_width + iw) * channels;

int64_t d2 = 0;
for (; d2 < len; d2 += bVec::size()) {
for (; d2 < len; d2 += Vec::size()) {
iVec index_ivec = iVec(id * input_height * input_width + ih * input_width + iw);
bVec val_bvec = bVec::loadu(in + d2);
Vec val_bvec = Vec::loadu(in + d2);
fVec val_fvec0, val_fvec1;
std::tie(val_fvec0, val_fvec1) = convert_bfloat16_float(val_bvec);
std::tie(val_fvec0, val_fvec1) = convert_to_float<scalar_t>(val_bvec);

iVec maxindex_ivec0 = iVec::loadu(index_ptr + d2);
iVec maxindex_ivec1 = iVec::loadu(index_ptr + d2 + iVec::size());
@@ -200,9 +206,9 @@ inline void compute_internal(
}
for (; d2 < size; d2++) {
int64_t index = id * input_height * input_width + ih * input_width + iw;
float val = float(in[d2]);
opmath_t val = opmath_t(in[d2]);
int64_t maxindex = ind[d2];
float maxval = max_ptr[d2];
opmath_t maxval = max_ptr[d2];

bool mask = (val > maxval) || std::isnan(val);
max_ptr[d2] = mask ? val : maxval;
@@ -211,16 +217,16 @@ inline void compute_internal(
}
}
}
// Convert max values from float to bfloat16
// Convert max values from float to bfloat16/half
int64_t d3 = 0;
for (; d3 < len; d3 += bVec::size()) {
for (; d3 < len; d3 += Vec::size()) {
fVec max_fvec0 = fVec::loadu(max_ptr + d3);
fVec max_fvec1 = fVec::loadu(max_ptr + d3 + fVec::size());
bVec max_bvec = convert_float_bfloat16(max_fvec0, max_fvec1);
Vec max_bvec = convert_from_float<scalar_t>(max_fvec0, max_fvec1);
max_bvec.store(out_data + d3);
}
for (; d3 < size; d3++) {
out_data[d3] = BFloat16(max_ptr[d3]);
out_data[d3] = scalar_t(max_ptr[d3]);
}
}

@@ -281,7 +287,7 @@ void cpu_max_pool(
int64_t output_height = output.size(-2);
int64_t output_width = output.size(-1);

using accscalar_t = at::opmath_type<scalar_t>;
using opmath_t = at::opmath_type<scalar_t>;
// parallel on dim N, C
at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
for (int64_t c = begin; c < end; c++) {
@@ -306,17 +312,18 @@ void cpu_max_pool(

// compute local max
int64_t maxindex = id0 * input_height * input_width + ih0 * input_width + iw0;
accscalar_t maxval;
if (std::numeric_limits<accscalar_t>::has_infinity) {
maxval = -std::numeric_limits<accscalar_t>::infinity();
opmath_t maxval;
if (std::numeric_limits<opmath_t>::has_infinity) {
maxval = -std::numeric_limits<opmath_t>::infinity();
} else {
maxval = std::numeric_limits<accscalar_t>::min();
maxval = std::numeric_limits<opmath_t>::min();
}

for (int64_t id = id0; id < id1; id += dilationD) {
for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
int64_t index = id * input_height * input_width + ih * input_width + iw;
accscalar_t val = input_ptr[index];
opmath_t val = input_ptr[index];
if ((val > maxval) || is_nan(static_cast<double>(val))) {
maxval = val;
maxindex = index;
@@ -396,9 +403,9 @@ void cpu_max_pool_channels_last(
int64_t output_height = output.size(-2);
int64_t output_width = output.size(-1);

using accscalar_t = at::opmath_type<scalar_t>;
using opmath_t = at::opmath_type<scalar_t>;
using Vec = vec::Vectorized<scalar_t>;
using integer_t = vec::int_same_size_t<accscalar_t>;
using integer_t = vec::int_same_size_t<opmath_t>;
// for the convience of vectorization, use integer of the same size of scalar_t,
// e.g. int32_t for float, int64_t for double
// need to make sure doesn't overflow
@@ -418,11 +425,11 @@ void cpu_max_pool_channels_last(
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
std::unique_ptr<integer_t []> index_buffer(new integer_t[len]);
integer_t * index_ptr = index_buffer.get();
// temp buffer holding max value with accscalar_t
std::unique_ptr<accscalar_t []> max_arr;
accscalar_t* max_ptr = nullptr;
if (!std::is_same<scalar_t, accscalar_t>::value) {
max_arr = std::make_unique<accscalar_t[]>(size);
// temp buffer holding max value with opmath_t
std::unique_ptr<opmath_t []> max_arr;
opmath_t* max_ptr = nullptr;
if (!std::is_same<scalar_t, opmath_t>::value) {
max_arr = std::make_unique<opmath_t[]>(size);
max_ptr = max_arr.get();
}

@@ -598,13 +605,13 @@ void max_pool2d_kernel_impl(
int dilationW, int dilationH) {
switch (input.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool2d", [&] {
AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool2d", [&] {
cpu_max_pool<scalar_t, /*is 3d*/false>(output, indices, input, {kW, kH}, {dW, dH}, {padW, padH}, {dilationW, dilationH});
});
break;
}
case at::MemoryFormat::ChannelsLast: {
AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool2d_channels_last", [&] {
AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool2d_channels_last", [&] {
cpu_max_pool_channels_last<scalar_t, false>(output, indices, input, {kW, kH}, {dW, dH}, {padW, padH}, {dilationW, dilationH});
});
break;
@@ -637,7 +644,7 @@ void max_pool3d_kernel_impl(
DimVector indices_sizes(indices.sizes().begin(), indices.sizes().end());
indices_sizes.insert(indices_sizes.begin(), 1);
indices.resize_(indices_sizes, at::MemoryFormat::ChannelsLast3d);
AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool3d_channels_last", [&] {
AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d_channels_last", [&] {
cpu_max_pool_channels_last<scalar_t, /*is 3d*/true>(output, indices, input_cl_check,
{kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
});
@@ -648,14 +655,14 @@ void max_pool3d_kernel_impl(
}
switch (input.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool3d", [&] {
AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d", [&] {
cpu_max_pool<scalar_t, /*is 3d*/true>(output, indices, input,
{kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
});
break;
}
case at::MemoryFormat::ChannelsLast3d: {
AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool3d_channels_last", [&] {
AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d_channels_last", [&] {
cpu_max_pool_channels_last<scalar_t, true>(output, indices, input,
{kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
});
Expand All @@ -672,13 +679,13 @@ void max_pool2d_backward_kernel_impl(
const Tensor& indices) {
switch (grad_output.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool2d_backward", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool2d_backward", [&] {
cpu_max_pool_backward<scalar_t, /*is 3d*/ false>(grad_input, grad_output, indices);
});
break;
}
case at::MemoryFormat::ChannelsLast: {
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
cpu_max_pool_backward_channels_last<scalar_t, /*is 3d*/ false>(grad_input, grad_output, indices);
});
break;
@@ -705,7 +712,7 @@ void max_pool3d_backward_kernel_impl(
sizes.insert(sizes.begin(), 1);
grad_input.resize_(sizes, at::MemoryFormat::ChannelsLast3d);
auto _indices = indices.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d);
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
cpu_max_pool_backward_channels_last<scalar_t, /*is_3d*/ true>(grad_input, grad_output_cl_check, _indices);
});
grad_input.squeeze_(0);
@@ -714,13 +721,13 @@ void max_pool3d_backward_kernel_impl(
}
switch (grad_output.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool3d_backward", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward", [&] {
cpu_max_pool_backward<scalar_t, /*is_3d*/ true>(grad_input, grad_output, indices);
});
break;
}
case at::MemoryFormat::ChannelsLast3d: {
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
cpu_max_pool_backward_channels_last<scalar_t, /*is_3d*/ true>(grad_input, grad_output, indices);
});
break;
45 changes: 25 additions & 20 deletions aten/src/ATen/native/cpu/MaxPooling.cpp
@@ -1,7 +1,7 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/core/Tensor.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/MaxPooling.h>
#include <c10/util/irange.h>
@@ -31,25 +31,30 @@ void max_pool1d_impl(
Tensor& output,
const Tensor& input,
const PoolingParams1D& p) {
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool1d_impl", [&] {
const Tensor in = input.contiguous();
scalar_t* const OP = output.data_ptr<scalar_t>();
const scalar_t* const IP = in.data_ptr<scalar_t>();

// Value used for padding
scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
? -std::numeric_limits<scalar_t>::infinity()
: std::numeric_limits<scalar_t>::lowest();

at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
for (const auto it : c10::irange(begin, end)) {
scalar_t* op = OP + it * p.OW;
const scalar_t* ip = IP + it * p.IW;
std::fill_n(op, p.OW, FILL);
max_pool1d_kernel(op, ip, p);
}
});
});
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::BFloat16,
ScalarType::Half,
input.scalar_type(),
"max_pool1d_impl",
[&] {
const Tensor in = input.contiguous();
scalar_t* const OP = output.data_ptr<scalar_t>();
const scalar_t* const IP = in.data_ptr<scalar_t>();

// Value used for padding
scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
? -std::numeric_limits<scalar_t>::infinity()
: std::numeric_limits<scalar_t>::lowest();

at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
for (const auto it : c10::irange(begin, end)) {
scalar_t* op = OP + it * p.OW;
const scalar_t* ip = IP + it * p.IW;
std::fill_n(op, p.OW, FILL);
max_pool1d_kernel(op, ip, p);
}
});
});
}

} // namespace