Batch Norm Consolidation (pytorch#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend-agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from its complexity, an important problem with the
above decomposition hierarchy is CUDA numerics in export
flows. We observed significantly worse convergence when
training a mobilenetv2-like model with the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
and later move them to CUDA may silently see worse
accuracy even when cuDNN is installed, because they end
up on the inferior kernel. This issue is summarized in
pytorch#111384.

Instead, the new hierarchy obtained by consolidating the
existing batch norm ops looks like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```
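
For illustration, here is a minimal sketch of how `aten.batch_norm` is expected to route once the decomps are switched over, adapted from the commented-out code this commit adds in `Normalization.cpp` (the wrapper name is ours, and the optional-tensor handling is simplified):

```
#include <ATen/ATen.h>
#include <tuple>

// Hedged sketch, not the final implementation.
at::Tensor batch_norm_sketch(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
    at::Tensor& running_mean, at::Tensor& running_var,
    bool training, double momentum, double eps) {
  if (training) {
    // Training: _batch_norm_with_update picks the cuDNN, MIOpen, or
    // native kernel internally and updates the running stats in place.
    return std::get<0>(at::_batch_norm_with_update(
        input, weight, bias, running_mean, running_var, momentum, eps));
  }
  // Eval: running stats are read but not updated.
  return std::get<0>(at::_batch_norm_no_update(
      input, weight, bias, running_mean, running_var, momentum, eps));
}
```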

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants of this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```
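
These variants share a common 4-output contract. As a hedged sketch based on the CPU and CUDA kernels added below (the wrapper name and its parameter setup are assumptions for illustration):

```
#include <ATen/ATen.h>
#include <tuple>

// Hedged sketch of the shared calling convention.
at::Tensor run_bn_with_update(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
    at::Tensor& running_mean, at::Tensor& running_var,
    double momentum, double eps) {
  auto [output, save_mean, save_var, reserve] =
      at::_batch_norm_with_update(
          input, weight, bias, running_mean, running_var, momentum, eps);
  // `reserve` is an opaque workspace: the cuDNN path fills it and
  // cudnn_batch_norm_backward consumes it; the native and MIOpen
  // paths return it empty.
  return output;
}
```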

Note that this commit only adds the new op and its
variants; it does not yet change the decomps to produce
these ops in the graph. That switch will happen after the
2 week FC window, and the ops in the old stack are planned
for removal after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: pytorch#111384

Differential Revision: [D54805279](https://our.internmc.facebook.com/intern/diff/D54805279)
Co-authored-by: Tugsbayasgalan Manlaibaatar <[email protected]>
Pull Request resolved: pytorch#116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
andrewor14 authored and pytorchmergebot committed Mar 18, 2024
1 parent a17cd22 commit 773ae81
Showing 36 changed files with 779 additions and 72 deletions.
150 changes: 120 additions & 30 deletions aten/src/ATen/native/Normalization.cpp
@@ -29,6 +29,11 @@
#include <ATen/ops/_native_batch_norm_legit_native.h>
#include <ATen/ops/_native_batch_norm_legit_no_training.h>
#include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
#include <ATen/ops/_batch_norm_with_update.h>
#include <ATen/ops/_batch_norm_with_update_native.h>
#include <ATen/ops/_batch_norm_no_update.h>
#include <ATen/ops/_batch_norm_no_update_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/alias.h>
#include <ATen/ops/batch_norm.h>
#include <ATen/ops/batch_norm_native.h>
@@ -479,10 +484,58 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
return std::make_tuple(grad_input, grad_weight, grad_bias);
}

BatchNormBackend _select_batch_norm_backend(
const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
const Tensor& running_var, bool training, double eps) {

auto& ctx = at::globalContext();
bool cudnn_enabled = ctx.userEnabledCuDNN();

if (
input.is_cuda()
&& input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
&& (input.scalar_type() != at::kHalf
|| weight.scalar_type() == at::kFloat)
&& weight.defined() && bias.defined()
&& ((running_mean.defined() && running_var.defined())
|| (!running_mean.defined() && !running_var.defined() && training))
&& (input.dim() >= 3)
&& ((input.sym_size(0) <= 880801 && training) // spatial, training
||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
&& detail::getCUDAHooks().compiledWithCuDNN()
&& eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
&& cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
&& input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
) {
return BatchNormBackend::Cudnn;
}

if (
input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& input.scalar_type() != at::kDouble
&& input.scalar_type() != at::kBFloat16
&& (weight.scalar_type() != at::kHalf)
&& weight.defined() && bias.defined()
&& ((running_mean.defined() && running_var.defined())
|| (!running_mean.defined() && !running_var.defined() && training))
&& detail::getCUDAHooks().compiledWithMIOpen()
&& cudnn_enabled
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
) {
return BatchNormBackend::Miopen;
}

return BatchNormBackend::Native;
}


// _batch_norm_impl_index(_backward) are used in the JIT to keep the run-time selection
// of backends, while also recording which backend was used so that the
// corresponding backward implementation can be chosen.
// XXX: The indices of backends need to be kept synchronized between this function and its _backward.
// TODO: remove cudnn_enabled arg
std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
const Tensor& input, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */, const c10::optional<Tensor>& running_mean_opt /* optional */, const c10::optional<Tensor>& running_var_opt /* optional */,
bool training, double momentum, double eps, bool cudnn_enabled) {
@@ -527,24 +580,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
}

-const bool use_cudnn = (
-    input.is_cuda()
-    && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
-    && (input.scalar_type() != at::kHalf
-      || weight.scalar_type() == at::kFloat)
-    && weight.defined() && bias.defined()
-    && ((running_mean.defined() && running_var.defined())
-      || (!running_mean.defined() && !running_var.defined() && training))
-    && (input.dim() >= 3)
-    && ((input.sym_size(0) <= 880801 && training) // spatial, training
-      ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
-    && detail::getCUDAHooks().compiledWithCuDNN()
-    && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
-    && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
-    && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
-);
+BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);

-if (use_cudnn) {
+if (backend == BatchNormBackend::Cudnn) {
auto input_c = input.contiguous(input.suggest_memory_format());
auto weight_c = weight.contiguous();
auto bias_c = bias.contiguous();
@@ -561,19 +599,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(

Tensor reserve = at::empty({0}, input.options().dtype(kByte));

-bool use_miopen = (input.is_cuda()
-    && input.dim() <= MIOPEN_DIM_MAX
-    && input.scalar_type() != at::kDouble
-    && input.scalar_type() != at::kBFloat16
-    && (weight.scalar_type() != at::kHalf)
-    && weight.defined() && bias.defined()
-    && ((running_mean.defined() && running_var.defined())
-      || (!running_mean.defined() && !running_var.defined() && training))
-    && detail::getCUDAHooks().compiledWithMIOpen()
-    && cudnn_enabled
-);
-
-if (use_miopen && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d) {
+if (backend == BatchNormBackend::Miopen) {
return std::tuple_cat(
at::miopen_batch_norm(
input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -637,6 +663,7 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
}

// TODO: remove cudnn_enabled arg
Tensor batch_norm(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
@@ -647,6 +674,30 @@ Tensor batch_norm(
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
training, momentum, eps, cudnn_enabled));
// TODO: switch to the new stack after the 2 week FC window
// if (training) {
// BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
// if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
// auto input_c = input;
// if (backend == BatchNormBackend::Cudnn) {
// input_c = input.contiguous(input.suggest_memory_format());
// } else {
// input_c = input.contiguous();
// }
// auto weight_c = weight.contiguous();
// auto bias_c = bias.contiguous();
// auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
// auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
// return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
// const_cast<Tensor&>(rvar_c), momentum, eps));
// } else {
// return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
// const_cast<Tensor&>(running_var), momentum, eps));
// }
// } else {
// return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
// momentum, eps));
// }
}

Tensor instance_norm(
@@ -798,6 +849,38 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const c10:
return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
}

std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cpu(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
Tensor output, save_mean, save_var;
std::tie(output, save_mean, save_var) =
batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps);
Tensor reserve = at::empty({0}, input.options().dtype(kByte));
return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}

std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
Tensor& running_mean, Tensor& running_var, double momentum, double eps,
Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
std::tie(out, save_mean, save_var) =
batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
}


std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
double momentum, double eps) {
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
Tensor output, save_mean, save_var;
std::tie(output, save_mean, save_var) =
batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
Tensor reserve = at::empty({0}, input.options().dtype(kByte));
return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
@@ -826,6 +909,13 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const T
return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
}

std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
const Tensor& grad_output, const Tensor& input, const Tensor& weight,
const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
}

std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
bool train, double eps, std::array<bool,3> grad_input_mask) {
8 changes: 8 additions & 0 deletions aten/src/ATen/native/Normalization.h
@@ -8,4 +8,12 @@ namespace at::native {
using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);

enum class BatchNormBackend {
Native,
Cudnn,
Miopen,
};

TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);

} // namespace at::native
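
For context, a hedged sketch of how this newly exposed helper is consumed, mirroring the dispatch sites elsewhere in this commit (the wrapper function and its arguments are assumptions for illustration):

```
#include <ATen/ATen.h>
#include <ATen/native/Normalization.h>

// Hedged usage sketch of the backend-selection helper.
void dispatch_example(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
    const at::Tensor& running_mean, const at::Tensor& running_var, double eps) {
  using at::native::BatchNormBackend;
  BatchNormBackend backend = at::native::_select_batch_norm_backend(
      input, weight, bias, running_mean, running_var, /*training=*/true, eps);
  if (backend == BatchNormBackend::Cudnn) {
    // cuDNN kernel: also produces a reserve tensor for backward.
  } else if (backend == BatchNormBackend::Miopen) {
    // MIOpen kernel (ROCm builds).
  } else {
    // Native ATen kernel on CPU or CUDA.
  }
}
```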
78 changes: 78 additions & 0 deletions aten/src/ATen/native/cuda/Normalization.cu
@@ -1,5 +1,7 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/Normalization.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/Resize.h>
@@ -12,15 +14,21 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_batch_norm_with_update_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/batch_norm_backward_elemt_native.h>
#include <ATen/ops/batch_norm_backward_reduce_native.h>
#include <ATen/ops/batch_norm_elemt_native.h>
#include <ATen/ops/batch_norm_gather_stats_native.h>
#include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
#include <ATen/ops/batch_norm_stats_native.h>
#include <ATen/ops/batch_norm_update_stats_native.h>
#include <ATen/ops/cudnn_batch_norm.h>
#include <ATen/ops/cudnn_batch_norm_backward.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/miopen_batch_norm.h>
#include <ATen/ops/miopen_batch_norm_backward.h>
#include <ATen/ops/native_batch_norm_backward_native.h>
#include <ATen/ops/native_batch_norm_native.h>
#include <ATen/ops/scalar_tensor.h>
@@ -473,6 +481,54 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const c10
return std::make_tuple(output, save_mean, save_invstd);
}

std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
Tensor output, save_mean, save_var, reserve;

BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
if (backend == BatchNormBackend::Cudnn) {
std::tie(output, save_mean, save_var, reserve) =
at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
} else if (backend == BatchNormBackend::Miopen) {
reserve = at::empty({0}, input.options().dtype(kByte));
std::tie(output, save_mean, save_var) =
at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
} else {
reserve = at::empty({0}, input.options().dtype(kByte));
std::tie(output, save_mean, save_var) =
batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps);
}
return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}

std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
Tensor& running_mean, Tensor& running_var, double momentum, double eps,
Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});

BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
if (backend == BatchNormBackend::Cudnn) {
std::tie(out, save_mean, save_var, reserve) =
at::cudnn_batch_norm_out(out, save_mean, save_var, reserve, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
} else if (backend == BatchNormBackend::Miopen) {
std::tie(out, save_mean, save_var) =
at::miopen_batch_norm_out(out, save_mean, save_var, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
} else {
std::tie(out, save_mean, save_var) =
batch_norm_cuda_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
}
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
}
@@ -489,6 +545,28 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cuda_out(const
return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd);
}

std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cuda(
const Tensor& grad_output, const Tensor& input, const Tensor& weight,
const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
const Tensor& dummy_bias = at::empty(1);
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();});

BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps);

if (backend == BatchNormBackend::Cudnn) {
return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve);
} else if (backend == BatchNormBackend::Miopen) {
return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps);
} else {
return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask);
}
}

std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);