optimize channels last for BatchNorm2d on CPU (pytorch#48919)
Summary:
Pull Request resolved: pytorch#48919

move data indexing utils

parallel inference contiguous path

parallel inference channels last path

add dim apply

optimize update stats

add channels last support for backward

Revert "add channels last support for backward"

This reverts commit cc5e29dce44395250f8e2abf9772f0b99f4bcf3a.

Revert "optimize update stats"

This reverts commit 7cc6540701448b9cfd5833e36c745b5015ae7643.

Revert "add dim apply"

This reverts commit b043786d8ef72dee5cf85b5818fcb25028896ecd.

bug fix

add batchnorm nhwc test for cpu, including C=1 and HW=1

Test Plan: Imported from OSS

Reviewed By: glaringlee

Differential Revision: D25399468

Pulled By: VitalyFedyunin

fbshipit-source-id: a4cd7a09cd4e1a8f5cdd79c7c32c696d0db386bd
mingfeima authored and facebook-github-bot committed May 14, 2021
1 parent 0d11dbf commit 0be334a
Showing 7 changed files with 807 additions and 154 deletions.
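The summary above mentions a new channels-last (NHWC) BatchNorm test on CPU covering the C=1 and HW=1 edge cases. As a rough illustration of what such a check exercises (a standalone sketch, not the test added by this PR), the following libtorch snippet compares inference on a channels-last input against the NCHW-contiguous reference through the public at::batch_norm entry point:

```cpp
// Illustrative check only; the PR's actual test lives in PyTorch's test suite.
// Inference BatchNorm on a channels-last input is compared against the same
// computation on the NCHW-contiguous copy, including C == 1 and HW == 1.
#include <ATen/ATen.h>

static void check_channels_last_batch_norm(int64_t N, int64_t C, int64_t H, int64_t W) {
  auto input = at::randn({N, C, H, W});
  auto nhwc = input.contiguous(at::MemoryFormat::ChannelsLast);
  auto weight = at::randn({C});
  auto bias = at::randn({C});
  auto mean = at::randn({C});
  auto var = at::rand({C}) + 0.5;  // keep variance positive

  auto ref = at::batch_norm(input, weight, bias, mean, var,
                            /*training=*/false, /*momentum=*/0.1,
                            /*eps=*/1e-5, /*cudnn_enabled=*/false);
  auto out = at::batch_norm(nhwc, weight, bias, mean, var,
                            /*training=*/false, /*momentum=*/0.1,
                            /*eps=*/1e-5, /*cudnn_enabled=*/false);
  TORCH_CHECK(out.allclose(ref), "channels-last output differs from reference");
}

int main() {
  check_channels_last_batch_norm(2, 8, 7, 7);
  check_channels_last_batch_norm(2, 1, 7, 7);  // C == 1
  check_channels_last_batch_norm(2, 8, 1, 1);  // HW == 1
  return 0;
}
```

Because inference with fixed running statistics is a pure per-channel affine transform, the two outputs should agree to floating-point tolerance regardless of memory format.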
29 changes: 29 additions & 0 deletions aten/src/ATen/cpu/vec256/functional.h
@@ -231,4 +231,33 @@ inline void map3(
}
}

template <typename scalar_t, typename Op>
inline void map4(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* input_data1,
const scalar_t* input_data2,
const scalar_t* input_data3,
const scalar_t* input_data4,
int64_t size) {
using Vec = vec256::Vec256<scalar_t>;
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec1 = Vec::loadu(input_data1 + d);
Vec data_vec2 = Vec::loadu(input_data2 + d);
Vec data_vec3 = Vec::loadu(input_data3 + d);
Vec data_vec4 = Vec::loadu(input_data4 + d);
Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
output_vec.store(output_data + d);
}
if (size - d > 0) {
Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
Vec data_vec4 = Vec::loadu(input_data4 + d, size - d);
Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
output_vec.store(output_data + d, size - d);
}
}

}} // namespace at::vec256
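The map4 helper added above follows the existing map/map2/map3 pattern in this header: it applies a vectorized functor elementwise across four input buffers and handles the tail that is not a multiple of Vec::size() with partial loads and stores. A small usage sketch (illustrative only; the function name and fused expression are not from this PR):

```cpp
// Illustrative use of the map4 helper: compute
//   out[i] = a[i] * b[i] + c[i] * d[i]
// with Vec256 vectorization over the whole buffer, tail included.
#include <cstdint>
#include <ATen/cpu/vec256/vec256.h>
#include <ATen/cpu/vec256/functional.h>

void fused_mul_add(float* out, const float* a, const float* b,
                   const float* c, const float* d, int64_t n) {
  using Vec = at::vec256::Vec256<float>;
  at::vec256::map4<float>(
      [](Vec va, Vec vb, Vec vc, Vec vd) { return va * vb + vc * vd; },
      out, a, b, c, d, n);
}
```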
176 changes: 59 additions & 117 deletions aten/src/ATen/native/Normalization.cpp
@@ -17,7 +17,9 @@ static const int MIOPEN_DIM_MAX = 5;
namespace at { namespace native {

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(batch_norm_cpu_inference_contiguous_stub);
DEFINE_DISPATCH(batch_norm_cpu_stub);
DEFINE_DISPATCH(batch_norm_cpu_collect_stats_stub);
DEFINE_DISPATCH(batch_norm_cpu_backward_stub);

namespace {
void check_dims_match_num_input_features(const char* arg_name, int64_t expected, int64_t actual){
@@ -33,15 +35,6 @@ namespace {
}
}

// TensorAccessor when it is defined to work around undefined...
template <typename scalar_t>
static TensorAccessor<scalar_t, 1> conditional_accessor_1d(const Tensor& t) {
if (! t.defined()) {
return TensorAccessor<scalar_t, 1>(nullptr, nullptr, nullptr);
}
return t.accessor<scalar_t, 1>();
}

template<typename T>
struct InvStd {
T operator()(T var, double epsilon) const {
@@ -60,87 +53,8 @@ struct Var {
}
};

template<typename scalar_t>
void batch_norm_cpu_inference_collect_linear_and_constant_terms(
scalar_t* alpha, scalar_t* beta, int64_t n_channel,
const Tensor& weight /* optional */, const Tensor& bias /* optional */,
const Tensor& mean, const Tensor& variance, double eps) {

const scalar_t* weight_data = weight.defined() ? weight.data_ptr<scalar_t>() : nullptr;
const scalar_t* bias_data = bias.defined() ? bias.data_ptr<scalar_t>() : nullptr;
const scalar_t* mean_data = mean.data_ptr<scalar_t>();
const scalar_t* var_data = variance.data_ptr<scalar_t>();

/// Collect the linear and constant terms regarding the input.
/// output(n, c, h, w)
/// = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c)
/// + bias(c)
/// = input(n, c, h, w) * inv_var(c) * weight(c)
/// - mean(c) * inv_var(c) * weight(c) + bias(c),
/// where inv_var(c) = 1 / sqrt(var(c) + eps).
/// So the linear term, alpha(c) = inv_var(c) * weight(c),
/// the constant term beta(c) = bias(c) - mean(c) * inv_var(c) * weight(c)
/// Note that this is only a good idea if (input_size >> c), in degenerate
/// cases where image_size == 1 && batch_size == 1, it is slow.
for (int64_t c = 0; c < n_channel; c++) {
scalar_t inv_var = 1 / std::sqrt(var_data[c] + static_cast<scalar_t>(eps));
scalar_t weight_v = weight_data ? weight_data[c] : 1;
scalar_t bias_v = bias_data ? bias_data[c] : 0;
alpha[c] = inv_var * weight_v;
beta[c] = bias_v - mean_data[c] * inv_var * weight_v;
}
}

/// A fast path for CPU inference when all tensors are channels last contiguous.
/// This code achieves machine bandwidth peak without AVX support.
/// If this changes for future architectures, we can move it to the cpu/
/// directory.
template<typename scalar_t>
void batch_norm_cpu_inference_channels_last(Tensor& output, const Tensor& input,
const Tensor& weight /* optional */, const Tensor& bias /* optional */,
const Tensor& mean, const Tensor& variance, double eps) {

int64_t n_batch = input.size(0);
int64_t n_channel = input.size(1);
int64_t image_size = input.numel() / n_batch / n_channel;

scalar_t* output_data = output.data_ptr<scalar_t>();
const scalar_t* input_data = input.data_ptr<scalar_t>();

Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
scalar_t* alpha_data = alpha.data_ptr<scalar_t>();
scalar_t* beta_data = beta.data_ptr<scalar_t>();

batch_norm_cpu_inference_collect_linear_and_constant_terms<scalar_t>(
alpha_data, beta_data, n_channel, weight, bias, mean, variance, eps);

// Apply the linear terms to the input,
// output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
// No need to use parallel_for as this function is supposed to be
// memory-limited.
// Keep the loop structure simple to make sure compiler vectorization kicks in.
if (n_channel != 1) {
for (int64_t n = 0; n < n_batch; ++n) {
for (int64_t i = 0; i < image_size; ++i) {
for (int64_t c = 0; c < n_channel; ++c) {
// Keep all the offset calculation within the inner loop for
// simplicity. Compilers are very good at hoisting the common part
// outside.
int64_t offset = n * image_size * n_channel + i * n_channel + c;
output_data[offset] = input_data[offset] * alpha_data[c] + beta_data[c];
}
}
}
} else {
// n_channel == 1
for (int64_t n = 0; n < n_batch; ++n) {
for (int64_t i = 0; i < image_size; ++i) {
int64_t offset = n * image_size + i;
output_data[offset] = input_data[offset] * alpha_data[0] + beta_data[0];
}
}
}
static inline bool is_contiguous(const Tensor& t) {
return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast);
}

template<typename scalar_t>
@@ -150,29 +64,18 @@ std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
bool train, double eps) {

// Check if we should use the fast path for contiguous memory format
if (!train && input.is_contiguous()
bool all_contiguous = is_contiguous(input)
&& (!weight.defined() || weight.is_contiguous())
&& (!bias.defined() || bias.is_contiguous())
&& running_mean.is_contiguous()
&& running_var.is_contiguous()) {
&& running_var.is_contiguous();

Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
batch_norm_cpu_inference_contiguous_stub(kCPU, output, input, weight,
bias, running_mean, running_var, eps);
return std::make_tuple(output, save_mean, save_invstd);
}
Tensor output = at::empty_like(input, input.suggest_memory_format());

// Check if we should use the fast path for channel last memory format
if (!train && input.is_contiguous(at::MemoryFormat::ChannelsLast)
&& (!weight.defined() || weight.is_contiguous())
&& (!bias.defined() || bias.is_contiguous())
&& running_mean.is_contiguous()
&& running_var.is_contiguous()) {

Tensor output = at::empty_like(input, at::MemoryFormat::ChannelsLast);
batch_norm_cpu_inference_channels_last<scalar_t>(
output, input, weight, bias, running_mean, running_var, eps);
// inference contiguous path
if (all_contiguous) {
batch_norm_cpu_stub(kCPU, output, input, weight, bias,
save_mean, save_invstd, running_mean, running_var, train, eps);
return std::make_tuple(output, save_mean, save_invstd);
}

@@ -200,7 +103,6 @@ std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
auto b = bias.defined() ? as_nd(bias) :
at::detail::scalar_tensor_static(0, input.scalar_type(), kCPU);

Tensor output = at::empty(input.sizes(), input.options());
auto iter = TensorIteratorConfig()
.add_output(output)
.add_input(input)
@@ -242,6 +144,34 @@ std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
auto running_mean_a = conditional_accessor_1d<scalar_t>(running_mean);
auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);

bool all_contiguous = is_contiguous(input);
if (all_contiguous) {
auto _mean = at::empty({n_input}, input.options());
auto _var_sum = at::empty({n_input}, input.options());
auto _mean_a = _mean.accessor<scalar_t, 1>();
auto _var_sum_a = _var_sum.accessor<scalar_t, 1>();

batch_norm_cpu_collect_stats_stub(kCPU, _mean, _var_sum, input);

parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
for (int64_t f = b_begin; f < b_end; ++f) {
save_mean_a[f] = _mean_a[f];
save_var_transform_a[f] = VarTransform<accscalar_t>{}(_var_sum_a[f] / n, eps);

if (running_mean.defined()) {
running_mean_a[f] = momentum * _mean_a[f] + (1 - momentum) * running_mean_a[f];
}
if (running_var.defined()) {
accscalar_t unbiased_var = _var_sum_a[f] / (n - 1);
running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
}
}
});

return std::make_tuple(save_mean, save_var_transform);
}

// non-contiguous path
parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
for (int64_t f = b_begin; f < b_end; ++f) {
Tensor in = input.select(1, f);
@@ -270,25 +200,37 @@ std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
return std::make_tuple(save_mean, save_var_transform);
}


template<typename scalar_t>
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(const Tensor& grad_out_, const Tensor& input, const Tensor& weight,
const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
bool train, double eps, std::array<bool,3> grad_input_mask) {
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
const Tensor& grad_out_, const Tensor& input, const Tensor& weight,
const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
bool train, double eps, std::array<bool,3> grad_input_mask) {

using accscalar_t = at::acc_type<scalar_t, false>;

Tensor grad_input;
Tensor grad_weight;
Tensor grad_bias;
if (grad_input_mask[0]) {
grad_input = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
grad_input = at::empty_like(input, input.suggest_memory_format());
}
if (grad_input_mask[1]) {
grad_weight = at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
grad_weight = at::empty_like(weight, at::MemoryFormat::Contiguous);
}
if (grad_input_mask[2]) {
grad_bias = at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
grad_bias = at::empty_like(weight, at::MemoryFormat::Contiguous);
}

// since we are directly manipulating pointers in contiguous path,
// need to make sure input and grad_out have the same memory format.
bool all_contiguous = is_contiguous(input)
&& is_contiguous(grad_out_)
&& input.suggest_memory_format() == grad_out_.suggest_memory_format();

if (all_contiguous) {
batch_norm_cpu_backward_stub(kCPU, grad_input, grad_weight, grad_bias,
grad_out_, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
return std::make_tuple(grad_input, grad_weight, grad_bias);
}

auto weight_a = conditional_accessor_1d<scalar_t>(weight);
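The contiguous path added to batch_norm_cpu_update_stats_template above first gathers per-channel mean and variance sums through batch_norm_cpu_collect_stats_stub, then folds them into the saved and running statistics. A standalone sketch of that per-channel update, assuming the InvStd transform used during training (the struct and function names here are illustrative):

```cpp
// Per-channel statistics update, mirroring the contiguous path above:
//   save_mean    = batch mean
//   save_invstd  = 1 / sqrt(var_sum / n + eps)                  (biased variance)
//   running_mean = momentum * batch_mean + (1 - momentum) * running_mean
//   running_var  = momentum * var_sum / (n - 1) + (1 - momentum) * running_var
//                                                               (unbiased variance)
#include <cmath>

struct ChannelStats {
  double save_mean;
  double save_invstd;
  double running_mean;
  double running_var;
};

ChannelStats update_channel_stats(double batch_mean, double var_sum, double n,
                                  double eps, double momentum,
                                  double running_mean, double running_var) {
  ChannelStats s;
  s.save_mean = batch_mean;
  s.save_invstd = 1.0 / std::sqrt(var_sum / n + eps);
  s.running_mean = momentum * batch_mean + (1.0 - momentum) * running_mean;
  s.running_var = momentum * (var_sum / (n - 1.0)) + (1.0 - momentum) * running_var;
  return s;
}
```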
24 changes: 22 additions & 2 deletions aten/src/ATen/native/batch_norm.h
@@ -8,9 +8,29 @@ namespace at {
namespace native {

using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&,
const Tensor&, const Tensor&, const Tensor&, double);
const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
using batch_norm_collect_stats_fn = void (*)(Tensor&, Tensor&, const Tensor&);
using batch_norm_backward_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&,
const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);

DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_inference_contiguous_stub);
DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_stub);
DECLARE_DISPATCH(batch_norm_collect_stats_fn, batch_norm_cpu_collect_stats_stub);
DECLARE_DISPATCH(batch_norm_backward_fn, batch_norm_cpu_backward_stub);

// TensorAccessor when it is defined to work around undefined...
template <typename scalar_t>
static TensorAccessor<scalar_t, 1> conditional_accessor_1d(const Tensor& t) {
if (! t.defined()) {
return TensorAccessor<scalar_t, 1>(nullptr, nullptr, nullptr);
}
return t.accessor<scalar_t, 1>();
}

template <typename scalar_t>
static scalar_t* conditional_data_ptr(const Tensor& t) {
return t.defined() ? t.contiguous().data_ptr<scalar_t>()
: nullptr;
}

} // namespace native

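The stubs declared in this header follow ATen's DECLARE_DISPATCH / DEFINE_DISPATCH / REGISTER_DISPATCH pattern: Normalization.cpp defines them (see above), and a CPU kernel translation unit, which is not expanded in this view, registers the concrete implementations. A hedged sketch of that registration side; the kernel name, body, and file location are assumptions:

```cpp
// Hypothetical kernel translation unit under aten/src/ATen/native/cpu/.
// It implements the forward kernel and binds it to the stub declared above.
#include <ATen/ATen.h>
#include <ATen/native/batch_norm.h>

namespace at { namespace native {
namespace {

void batch_norm_cpu_kernel_impl(Tensor& output, const Tensor& input,
    const Tensor& weight, const Tensor& bias,
    const Tensor& save_mean, const Tensor& save_invstd,
    const Tensor& running_mean, const Tensor& running_var,
    bool train, double eps) {
  // Dispatch on memory format here: the contiguous (NCHW) and
  // channels-last (NHWC) implementations both live in this file.
}

} // namespace

REGISTER_DISPATCH(batch_norm_cpu_stub, &batch_norm_cpu_kernel_impl);

}} // namespace at::native
```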
(The remaining changed files are not expanded in this view.)
