Revert "Batch Norm Consolidation (pytorch#116092)"
This reverts commit 5680f56.

Reverted pytorch#116092 on behalf of https://github.com/jeffdaily because it broke ROCm: the PR signal was clean but trunk was not, so the merge should have been blocked but wasn't ([comment](pytorch#116092 (comment)))
pytorchmergebot committed Mar 6, 2024
1 parent 8dd4b6a commit b529c19
Showing 35 changed files with 70 additions and 753 deletions.
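
For context before the diff: the consolidation being reverted had factored the cuDNN/MIOpen eligibility checks out of the individual batch-norm entry points into a shared `_select_batch_norm_backend` helper returning a `BatchNormBackend` enum, and this revert inlines those checks again. The sketch below condenses that selection pattern from the Normalization.cpp hunks that follow; it is an abridged illustration, not the full predicate, which also covers dtype combinations, running-stat presence, batch-size limits, memory format, and cuDNN version/epsilon constraints.

```cpp
// Abridged sketch of the backend-selection helper this commit removes
// (see the Normalization.cpp hunks below for the complete predicate).
#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/detail/CUDAHooksInterface.h>

enum class BatchNormBackend { Native, Cudnn, Miopen };

BatchNormBackend select_batch_norm_backend_sketch(
    const at::Tensor& input, const at::Tensor& weight) {
  const bool cudnn_enabled = at::globalContext().userEnabledCuDNN();
  // cuDNN path: CUDA tensors in a cuDNN-enabled build, non-bfloat16,
  // spatial (>= 3-d) input; version, epsilon, and batch-size checks elided.
  if (input.is_cuda() && cudnn_enabled
      && at::detail::getCUDAHooks().compiledWithCuDNN()
      && input.scalar_type() != at::kBFloat16
      && weight.defined()
      && input.dim() >= 3) {
    return BatchNormBackend::Cudnn;
  }
  // MIOpen path (ROCm builds): gated by the same user flag; the reverted
  // code additionally rejected double/half/bfloat16 and, at some call
  // sites, channels-last layouts.
  if (input.is_cuda() && cudnn_enabled
      && at::detail::getCUDAHooks().compiledWithMIOpen()
      && input.scalar_type() != at::kDouble) {
    return BatchNormBackend::Miopen;
  }
  return BatchNormBackend::Native;
}
```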
150 changes: 30 additions & 120 deletions aten/src/ATen/native/Normalization.cpp
@@ -29,11 +29,6 @@
#include <ATen/ops/_native_batch_norm_legit_native.h>
#include <ATen/ops/_native_batch_norm_legit_no_training.h>
#include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
#include <ATen/ops/_batch_norm_with_update.h>
#include <ATen/ops/_batch_norm_with_update_native.h>
#include <ATen/ops/_batch_norm_no_update.h>
#include <ATen/ops/_batch_norm_no_update_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/alias.h>
#include <ATen/ops/batch_norm.h>
#include <ATen/ops/batch_norm_native.h>
@@ -484,58 +479,10 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
return std::make_tuple(grad_input, grad_weight, grad_bias);
}

BatchNormBackend _select_batch_norm_backend(
const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
const Tensor& running_var, bool training, double eps) {

auto& ctx = at::globalContext();
bool cudnn_enabled = ctx.userEnabledCuDNN();

if (
input.is_cuda()
&& input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
&& (input.scalar_type() != at::kHalf
|| weight.scalar_type() == at::kFloat)
&& weight.defined() && bias.defined()
&& ((running_mean.defined() && running_var.defined())
|| (!running_mean.defined() && !running_var.defined() && training))
&& (input.dim() >= 3)
&& ((input.sym_size(0) <= 880801 && training) // spatial, training
||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
&& detail::getCUDAHooks().compiledWithCuDNN()
&& eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
&& cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
&& input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
) {
return BatchNormBackend::Cudnn;
}

if (
input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& input.scalar_type() != at::kDouble
&& input.scalar_type() != at::kBFloat16
&& (weight.scalar_type() != at::kHalf)
&& weight.defined() && bias.defined()
&& ((running_mean.defined() && running_var.defined())
|| (!running_mean.defined() && !running_var.defined() && training))
&& detail::getCUDAHooks().compiledWithMIOpen()
&& cudnn_enabled
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
) {
return BatchNormBackend::Miopen;
}

return BatchNormBackend::Native;
}


// _batch_norm_impl_index(_backward) are used in the JIT to be able to keep the run-time selection
// of backends, while enabling it to keep the information about the used backend, so that it can
// use its corresponding backward implementation.
// XXX: The indices of backends need to be kept synchronized between this function and its _backward.
// TODO: remove cudnn_enabled arg
std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
const Tensor& input, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */, const c10::optional<Tensor>& running_mean_opt /* optional */, const c10::optional<Tensor>& running_var_opt /* optional */,
bool training, double momentum, double eps, bool cudnn_enabled) {
@@ -580,9 +527,24 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
}

BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
const bool use_cudnn = (
input.is_cuda()
&& input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
&& (input.scalar_type() != at::kHalf
|| weight.scalar_type() == at::kFloat)
&& weight.defined() && bias.defined()
&& ((running_mean.defined() && running_var.defined())
|| (!running_mean.defined() && !running_var.defined() && training))
&& (input.dim() >= 3)
&& ((input.sym_size(0) <= 880801 && training) // spatial, training
||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
&& detail::getCUDAHooks().compiledWithCuDNN()
&& eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
&& cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
&& input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
);

if (backend == BatchNormBackend::Cudnn) {
if (use_cudnn) {
auto input_c = input.contiguous(input.suggest_memory_format());
auto weight_c = weight.contiguous();
auto bias_c = bias.contiguous();
@@ -599,7 +561,19 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(

Tensor reserve = at::empty({0}, input.options().dtype(kByte));

if (backend == BatchNormBackend::Miopen) {
bool use_miopen = (input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& input.scalar_type() != at::kDouble
&& input.scalar_type() != at::kBFloat16
&& (weight.scalar_type() != at::kHalf)
&& weight.defined() && bias.defined()
&& ((running_mean.defined() && running_var.defined())
|| (!running_mean.defined() && !running_var.defined() && training))
&& detail::getCUDAHooks().compiledWithMIOpen()
&& cudnn_enabled
);

if (use_miopen && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d) {
return std::tuple_cat(
at::miopen_batch_norm(
input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -663,7 +637,6 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
}

// TODO: remove cudnn_enabled arg
Tensor batch_norm(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
@@ -674,30 +647,6 @@ Tensor batch_norm(
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
training, momentum, eps, cudnn_enabled));
// TODO: switch to the new stack after the 2 week FC window
// if (training) {
// BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
// if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
// auto input_c = input;
// if (backend == BatchNormBackend::Cudnn) {
// input_c = input.contiguous(input.suggest_memory_format());
// } else {
// input_c = input.contiguous();
// }
// auto weight_c = weight.contiguous();
// auto bias_c = bias.contiguous();
// auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
// auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
// return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
// const_cast<Tensor&>(rvar_c), momentum, eps));
// } else {
// return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
// const_cast<Tensor&>(running_var), momentum, eps));
// }
// } else {
// return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
// momentum, eps));
// }
}

Tensor instance_norm(
@@ -849,38 +798,6 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const c10:
return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
}

std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cpu(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
Tensor output, save_mean, save_var;
std::tie(output, save_mean, save_var) =
batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps);
Tensor reserve = at::empty({0}, input.options().dtype(kByte));
return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}

std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
Tensor& running_mean, Tensor& running_var, double momentum, double eps,
Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
std::tie(out, save_mean, save_var) =
batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
}


std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
double momentum, double eps) {
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
Tensor output, save_mean, save_var;
std::tie(output, save_mean, save_var) =
batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
Tensor reserve = at::empty({0}, input.options().dtype(kByte));
return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
@@ -909,13 +826,6 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const T
return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
}

std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
const Tensor& grad_output, const Tensor& input, const Tensor& weight,
const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
}

std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
bool train, double eps, std::array<bool,3> grad_input_mask) {
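
The `_batch_norm_with_update_cpu` wrappers deleted above illustrate the calling convention the consolidation had introduced: reuse the existing native kernel and return a fourth, empty `reserve` tensor so the schema matches the cuDNN path, whose forward emits a real workspace. Below is a minimal sketch of that convention; the function name is illustrative rather than the registered op, and it uses only `at::native_batch_norm`, which exists on both sides of this revert.

```cpp
// Sketch of the four-output convention the deleted CPU wrappers followed:
// the native kernel's three outputs plus an empty byte "reserve" tensor
// kept only for schema parity with the cuDNN backend.
#include <ATen/ATen.h>
#include <tuple>

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
batch_norm_with_update_sketch(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
    at::Tensor& running_mean, at::Tensor& running_var,
    double momentum, double eps) {
  // Training-mode batch norm that updates the running stats in place.
  auto [output, save_mean, save_var] = at::native_batch_norm(
      input, weight, bias, running_mean, running_var,
      /*training=*/true, momentum, eps);
  // The CPU kernel produces no workspace, so reserve is zero-length.
  at::Tensor reserve = at::empty({0}, input.options().dtype(at::kByte));
  return {output, save_mean, save_var, reserve};
}
```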
8 changes: 0 additions & 8 deletions aten/src/ATen/native/Normalization.h
@@ -8,12 +8,4 @@ namespace at::native {
using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);

enum class BatchNormBackend {
Native,
Cudnn,
Miopen,
};

TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);

} // namespace at::native
78 changes: 0 additions & 78 deletions aten/src/ATen/native/cuda/Normalization.cu
@@ -1,7 +1,5 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/Normalization.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/Resize.h>
@@ -14,21 +12,15 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_batch_norm_with_update_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/batch_norm_backward_elemt_native.h>
#include <ATen/ops/batch_norm_backward_reduce_native.h>
#include <ATen/ops/batch_norm_elemt_native.h>
#include <ATen/ops/batch_norm_gather_stats_native.h>
#include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
#include <ATen/ops/batch_norm_stats_native.h>
#include <ATen/ops/batch_norm_update_stats_native.h>
#include <ATen/ops/cudnn_batch_norm.h>
#include <ATen/ops/cudnn_batch_norm_backward.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/miopen_batch_norm.h>
#include <ATen/ops/miopen_batch_norm_backward.h>
#include <ATen/ops/native_batch_norm_backward_native.h>
#include <ATen/ops/native_batch_norm_native.h>
#include <ATen/ops/scalar_tensor.h>
@@ -481,54 +473,6 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const c10
return std::make_tuple(output, save_mean, save_invstd);
}

std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
Tensor output, save_mean, save_var, reserve;

BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
if (backend == BatchNormBackend::Cudnn) {
std::tie(output, save_mean, save_var, reserve) =
at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
} else if (backend == BatchNormBackend::Miopen) {
reserve = at::empty({0}, input.options().dtype(kByte));
std::tie(output, save_mean, save_var) =
at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
} else {
reserve = at::empty({0}, input.options().dtype(kByte));
std::tie(output, save_mean, save_var) =
batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps);
}
return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}

std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
Tensor& running_mean, Tensor& running_var, double momentum, double eps,
Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});

BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
if (backend == BatchNormBackend::Cudnn) {
std::tie(out, save_mean, save_var, reserve) =
at::cudnn_batch_norm_out(out, save_mean, save_var, reserve, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
} else if (backend == BatchNormBackend::Miopen) {
std::tie(out, save_mean, save_var) =
at::miopen_batch_norm_out(out, save_mean, save_var, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
} else {
std::tie(out, save_mean, save_var) =
batch_norm_cuda_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
}
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
}

std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
}
@@ -545,28 +489,6 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cuda_out(const
return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd);
}

std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cuda(
const Tensor& grad_output, const Tensor& input, const Tensor& weight,
const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
const Tensor& dummy_bias = at::empty(1);
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();});

BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps);

if (backend == BatchNormBackend::Cudnn) {
return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve);
} else if (backend == BatchNormBackend::Miopen) {
return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps);
} else {
return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask);
}
}

std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
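
The deleted `_new_batch_norm_backward_cuda` above routed the backward pass to whichever backend produced the forward's saved stats, forwarding the cuDNN `reserve` workspace only where it is actually consumed. A condensed sketch of that dispatch, built from the same ATen calls visible in the hunk (the enum mirrors the selection sketch near the top of this page, and the function name is illustrative):

```cpp
// Sketch: backend-matched batch-norm backward dispatch, condensed from the
// deleted _new_batch_norm_backward_cuda. Only cuDNN consumes the reserve
// workspace captured during the forward pass.
#include <ATen/ATen.h>
#include <array>
#include <tuple>

enum class BatchNormBackend { Native, Cudnn, Miopen };

std::tuple<at::Tensor, at::Tensor, at::Tensor> batch_norm_backward_sketch(
    BatchNormBackend backend, const at::Tensor& grad_output,
    const at::Tensor& input, const at::Tensor& weight,
    const at::Tensor& running_mean, const at::Tensor& running_var,
    const at::Tensor& save_mean, const at::Tensor& save_var,
    double eps, const at::Tensor& reserve) {
  if (backend == BatchNormBackend::Cudnn) {
    return at::cudnn_batch_norm_backward(
        input, grad_output, weight, running_mean, running_var,
        save_mean, save_var, eps, reserve);
  }
  if (backend == BatchNormBackend::Miopen) {
    return at::miopen_batch_norm_backward(
        input, grad_output, weight, running_mean, running_var,
        save_mean, save_var, eps);
  }
  // Native fallback: compute all three gradients.
  return at::native_batch_norm_backward(
      grad_output, input, weight, running_mean, running_var,
      save_mean, save_var, /*train=*/true, eps, {true, true, true});
}
```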
(The remaining 32 changed files in this commit are not shown.)
