LayerNorm Support in autodiff: (pytorch#50467)
Summary:
1. Extend autodiff by adding an entry for `layer_norm` to the symbolic script; the differentiable graph now uses `native_layer_norm_backward`.
2. Add a backward function, `layer_norm_double_backward`, for `native_layer_norm_backward`; this preserves double-backward support for LayerNorm in autodiff/ScriptModule.
3. Add a Python test that verifies autodiff on `layer_norm` with the various configurations of optional tensors (this verifies the fix for pytorch#49430); a sketch of that style of check is shown below.
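
A minimal sketch, not taken from the PR's test file, of the kind of check item 3 describes: run `layer_norm` through a scripted function and through eager mode for each combination of the optional `weight`/`bias` tensors, then compare input gradients. The shapes, the warm-up count, and the function names here are illustrative assumptions.

```python
import torch
import torch.nn.functional as F
from typing import Optional

@torch.jit.script
def ln_scripted(x: torch.Tensor,
                weight: Optional[torch.Tensor],
                bias: Optional[torch.Tensor]) -> torch.Tensor:
    return F.layer_norm(x, [5], weight, bias, 1e-5)

def ln_eager(x, weight, bias):
    return F.layer_norm(x, [5], weight, bias, 1e-5)

x = torch.randn(4, 3, 5, dtype=torch.double, requires_grad=True)
w = torch.randn(5, dtype=torch.double, requires_grad=True)
b = torch.randn(5, dtype=torch.double, requires_grad=True)

for weight, bias in [(None, None), (w, None), (None, b), (w, b)]:
    for _ in range(3):  # warm up so the profiling executor builds the autodiff graph
        ln_scripted(x, weight, bias)
    eager_grad, = torch.autograd.grad(ln_eager(x, weight, bias).sum(), x)
    script_grad, = torch.autograd.grad(ln_scripted(x, weight, bias).sum(), x)
    assert torch.allclose(eager_grad, script_grad)
```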

Pull Request resolved: pytorch#50467

Reviewed By: eellison

Differential Revision: D30232864

Pulled By: jansel

fbshipit-source-id: b9c33075386aff96afff7415df9f94388bfb474a

Co-authored-by: Ryan Spring <[email protected]>
Co-authored-by: Jie <[email protected]>
2 people authored and facebook-github-bot committed Aug 12, 2021
1 parent b004307 commit ed0b8a3
Showing 7 changed files with 141 additions and 151 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/layer_norm_kernel.cpp
@@ -288,7 +288,7 @@ void LayerNormBackwardKernelImpl(
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, X.scalar_type(),
"LayerNormBackwardKernelImpl", [&]() {
LayerNormBackwardKernelImplInternal<scalar_t>(
dY, X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
});
}

2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -442,7 +442,7 @@ void LayerNormBackwardKernelImpl(
"LayerNormBackwardKernelImpl",
[&]() {
LayerNormBackwardKernelImplInternal<scalar_t>(
dY, X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
});
}

8 changes: 7 additions & 1 deletion tools/autograd/derivatives.yaml
@@ -1010,7 +1010,13 @@
save_invstd: not_implemented("native_batch_norm_backward save_invstd")

- name: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_layer_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, normalized_shape, eps, grad_input_mask) : (grads[0].defined() ? native_layer_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
input, weight, bias: "grad.defined() ? native_layer_norm_backward(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"

- name: native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask)
bias: Tensor()
mean: not_implemented("native_layer_norm_backward mean")
rstd: not_implemented("native_layer_norm_backward rstd")

- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, int N, int C, int HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input.is_contiguous() ? input : input.contiguous(), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
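With these entries, the first backward of `layer_norm` goes through `native_layer_norm_backward`, whose own backward is `layer_norm_double_backward`. A hedged, eager-mode sketch of exercising both derivative levels (shapes and eps are arbitrary choices, not values from the PR):

```python
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck, gradgradcheck

x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
w = torch.randn(4, dtype=torch.double, requires_grad=True)
b = torch.randn(4, dtype=torch.double, requires_grad=True)

def fn(x, w, b):
    return F.layer_norm(x, [4], w, b, 1e-5)

assert gradcheck(fn, (x, w, b))      # first derivative: native_layer_norm_backward
assert gradgradcheck(fn, (x, w, b))  # second derivative: layer_norm_double_backward
```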
195 changes: 112 additions & 83 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -3125,109 +3125,138 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(

}

std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_layer_norm_backward(
const Tensor& dY,
const Tensor& dmean,
const Tensor& drstd,
const Tensor& X,
const Tensor& mean,
const Tensor& rstd,
std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
const Tensor& input_t,
const c10::optional<Tensor>& gamma,
const Tensor& ggI,
const Tensor& ggG,
const Tensor& ggB,
const Tensor& gO_t,
const Tensor& save_mean_t,
const Tensor& save_invstd_t,
IntArrayRef normalized_shape,
double eps,
std::array<bool, 3> grad_input_mask) {
std::array<bool, 3> output_mask) {

const int normalized_ndim = normalized_shape.size();
const auto input_shape = X.sizes();
const auto input_ndim = X.dim();
const auto input_shape = input_t.sizes();
const auto input_ndim = input_t.dim();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
const int axis = input_ndim - normalized_ndim;
const int64_t M =
c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
const int64_t N =
c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());
//printf("M: %ld, N: %ld", M, N);

Tensor dX;
Tensor dgamma;
Tensor dbeta;
auto input = input_t.reshape({M, N});
auto gO = gO_t.reshape({M, N});
auto save_mean = save_mean_t.reshape({M, 1});
auto save_invstd = save_invstd_t.reshape({M, 1});

const Tensor X_tensor = X.reshape({M, N});
const Tensor mean_tensor = mean.reshape({M, 1});
const Tensor rstd_tensor = rstd.reshape({M, 1});
const double s = 1.0 / static_cast<double>(N);
bool affine = isDefined(gamma);
Tensor gamma_expanded;
Tensor ggG_expanded, ggB_expanded;
if (affine) {
gamma_expanded = gamma->reshape({1, N});
if (ggG.defined()) {
ggG_expanded = ggG.reshape({1, N});
}
if (ggB.defined()) {
ggB_expanded = ggB.reshape({1, N});
}
} else {
gamma_expanded = at::ones({1}, input.options());
}

Tensor dY_tensor;
if (dY.defined()) {
dY_tensor = dY.reshape({M, N});
Tensor ggI_expanded;
if (ggI.defined()) {
ggI_expanded = ggI.reshape({M, N});
}

if (grad_input_mask[0]) {
Tensor gamma_tensor;
if (isDefined(gamma)) {
gamma_tensor = gamma->reshape({1, N});
}
Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
Tensor var;
Tensor dvar;
if (drstd.defined()) {
var = ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
dvar = -0.5 * rstd_cube * drstd.view({M, 1});
}
Tensor ds;
Tensor db;
if (dY.defined()) {
ds = (isDefined(gamma) ? dY_tensor * X_tensor * gamma_tensor
: dY_tensor * X_tensor)
.sum(1)
.unsqueeze_(-1);
db = (isDefined(gamma) ? dY_tensor * gamma_tensor : dY_tensor)
.sum(1)
.unsqueeze_(-1);
const Tensor& a = rstd_tensor;
const Tensor b = (db * mean_tensor - ds) * rstd_cube * s;
const Tensor c = -b * mean_tensor - db * rstd_tensor * s;
if (isDefined(gamma)) {
dX = a * dY_tensor * gamma_tensor + b * X_tensor + c;
} else {
dX = a * dY_tensor + b * X_tensor + c;
}
if (dmean.defined() && drstd.defined()) {
dX += var_std_mean_backward(
{dvar, dmean.view({M, 1})},
X_tensor,
var,
mean_tensor,
/*dim=*/IntArrayRef{1},
/*correction=*/0,
/*keepdim=*/true,
/*is_std=*/false);
}
dX = dX.reshape_as(X);
} else if (dmean.defined() && drstd.defined()) {
dX = var_std_mean_backward(
{dvar, dmean.view({M, 1})},
X_tensor,
var,
mean_tensor,
/*dim=*/IntArrayRef{1},
/*correction=*/0,
/*keepdim=*/true,
/*is_std=*/false)
.reshape_as(X);
}
// for half inputs, save_mean, save_invstd are float
// (ideally, we would cast everything else, but not now)
auto mu = save_mean.to(input.scalar_type());
auto input_sub_mu = input - mu;
auto sigma2_eps_neg_1_2 = save_invstd.to(input.scalar_type());
auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);

Tensor gI;
// calculate gI
auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;

if (ggI.defined()) {

auto gxhat = gO * gamma_expanded;
auto gxhat_mu_sum = (gxhat * input_sub_mu).sum(1, true);
auto gxhat_sum = gxhat.sum(1, true);

auto ggI_sum = ggI_expanded.sum(1, true);
auto ggI_mu_sum = (ggI_expanded * input_sub_mu).sum(1, true);

auto all_sub = ((ggI_sum * gxhat_sum).div_(N)).sub_((ggI_expanded * gxhat).sum(1, true)).add_(
(sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum).mul_(3. / N));
auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(N);
auto gI_1t = (ggI_mu_sum * sigma2_eps_neg_3_2).div_(N) * (gxhat_sum.div(N) - gxhat);
auto gI_2t = (gxhat_mu_sum * sigma2_eps_neg_3_2).div_(N) * (ggI_sum.div(N) - ggI_expanded);

gI = (gI_0t.add_(gI_1t).add_(gI_2t));
}

if (grad_input_mask[1] && dY.defined()) {
dgamma = (dY_tensor * (X_tensor - mean_tensor) * rstd_tensor)
.sum(0)
.reshape_as(toNonOptTensor(gamma));
// add contribution of gamma term to gI
if (affine && ggG.defined()) {
auto t0 = gO * ggG_expanded * sigma2_eps_neg_1_2;
auto t1 = (sigma2_eps_neg_1_2 * (gO * ggG_expanded).sum(1, true)).div_(-N);
auto t2 = (input_mu_sigma2_neg_3_2 * (gO * ggG_expanded * input_sub_mu).sum(1,true)).div_(-N);
auto gI_G_term = t0.add_(t1).add_(t2);
gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
}
if (grad_input_mask[2] && dY.defined()) {
dbeta = dY_tensor.sum(0).reshape_as(toNonOptTensor(gamma));


if (gI.defined()) {
//printf("=== computing gI\n");
gI = gI.reshape_as(input_t);
}

return std::make_tuple(dX, dgamma, dbeta);
// this is the grad_input for the first backward function
auto first_bwd_fn_grad_input = [&](const Tensor& gO_local, const Tensor& gamma_local) -> Tensor {
auto h0 = (gamma_local * sigma2_eps_neg_1_2).div_(N);
auto h1 = (N * gO_local).sub_(gO_local.sum(1,true)).sub_(
input_sub_mu.mul(sigma2_eps_neg_1) * (gO_local * input_sub_mu).sum(1,true));
return h0 * h1;
};

// calculate gG
Tensor gG;
if (affine && ggI.defined()) {
gG = first_bwd_fn_grad_input(ggI_expanded, at::ones({}, sigma2_eps_neg_1_2.options()));
gG = (gO * gG).sum(0);
gG = gG.reshape_as(*gamma);
}

// calculate ggO
Tensor ggO;
// contribution of input term
if (ggI.defined()) {
ggO = first_bwd_fn_grad_input(ggI_expanded, gamma_expanded);
}
if (ggG.defined()) {
auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
}
if (ggB.defined()) {
auto ggO_B_term = ggB_expanded;
ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
}
if (ggO.defined()) {
ggO = ggO.expand({M, N}).reshape_as(input_t);
}

if (output_mask[1] && !gG.defined()) {
AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
}

return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
}

std::tuple<Tensor, Tensor, Tensor>
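For reference, a short PyTorch sketch of the grad-input formula that the `first_bwd_fn_grad_input` lambda above implements, written in the same `(M, N)` layout the C++ code reshapes to; the helper name is illustrative and not part of the codebase. Here `invstd` is `rstd`, i.e. `1/sqrt(var + eps)`.

```python
import torch

def layer_norm_grad_input(gO: torch.Tensor, x: torch.Tensor,
                          mean: torch.Tensor, invstd: torch.Tensor,
                          gamma: torch.Tensor) -> torch.Tensor:
    # gO, x: (M, N); mean, invstd: (M, 1); gamma: (1, N)
    N = x.size(1)
    x_sub_mu = x - mean
    h0 = gamma * invstd / N
    h1 = (N * gO
          - gO.sum(1, keepdim=True)
          - x_sub_mu * invstd.pow(2) * (gO * x_sub_mu).sum(1, keepdim=True))
    return h0 * h1  # grad_input of layer norm on the (M, N) view
```

The double backward then differentiates this expression with respect to the input, `gamma`, and `gO`, which is what the `gI`, `gG`, and `ggO` branches above compute.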
22 changes: 11 additions & 11 deletions torch/csrc/autograd/FunctionsManual.h
@@ -223,18 +223,18 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
const Tensor & weight_);
Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_);
std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask);
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_layer_norm_backward(
const Tensor& dY,
const Tensor& dmean,
const Tensor& drstd,
const Tensor& X,
const Tensor& mean,
const Tensor& rstd,
const c10::optional<Tensor>& gamma,
std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
const Tensor & input,
const c10::optional<Tensor> & gamma,
const Tensor & ggI,
const Tensor & ggG,
const Tensor & ggB,
const Tensor & gO,
const Tensor & save_mean,
const Tensor & save_invstd,
IntArrayRef normalized_shape,
double eps,
std::array<bool, 3> grad_input_mask);
std::array<bool,3> output_mask);

std::tuple<Tensor, Tensor> householder_product_backward(const Tensor& grad, const Tensor& input, const Tensor& tau);
std::tuple<Tensor, Tensor> polar_backward(
const Tensor& grad,
55 changes: 5 additions & 50 deletions torch/csrc/jit/runtime/symbolic_script.cpp
@@ -1141,64 +1141,19 @@ const std::vector<std::string> functions = {
return output, backward
# disable the layernorm AD temporarily because of bug in https://github.com/pytorch/pytorch/issues/19769
def layer_norm_disabled(input : Tensor,
def layer_norm(input : Tensor,
normalized_shape : List[int],
weight : Optional[Tensor],
bias : Optional[Tensor],
eps : float,
cudnn_enable : bool):
input_ndim = input.dim()
normalized_ndim = len(normalized_shape)
n = 1
for i in range(input_ndim - normalized_ndim):
n *= input.size(i)
input_reshape = input.contiguous().view(1, n, -1)
bn_out, save1, save2, reserve, impl_idx = torch._batch_norm_impl_index(
input_reshape, None, None, None, None, True,
0.0, eps, cudnn_enable)
bn_out = bn_out.view(input.size())
if weight is not None and bias is not None:
output = bias.addcmul(bn_out, weight, value=1)
elif weight is not None:
output = bn_out.mul(weight)
elif bias is not None:
output = bn_out.add(bias)
else:
output = bn_out
def backward(grad_output):
if weight is not None and bias is not None:
grad_bn_out = grad_output * weight
grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
grad_bias = grad_output._grad_sum_to_size(bias.size())
elif weight is not None:
grad_bn_out = grad_output * weight
grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
grad_bias = None
elif bias is not None:
grad_bn_out = grad_output
grad_weight= None
grad_bias = grad_output._grad_sum_to_size(bias.size())
else:
grad_bn_out = grad_output
grad_weight= None
grad_bias = None
grad_bn_out = grad_bn_out.contiguous().view(1, n, -1)
output, mean, rstd = torch.native_layer_norm(input, normalized_shape, weight, bias, eps)
grad_input, _, _ = torch._batch_norm_impl_index_backward(
impl_idx, input_reshape, grad_bn_out, None, None, None,
save1, save2, True, eps, [True, False, False], reserve)
grad_input = grad_input.view(input.size())
def backward(grad_output):
output_mask = [True, weight is not None, bias is not None]
grad_input, grad_weight, grad_bias = torch.native_layer_norm_backward(grad_output, input, normalized_shape, mean, rstd, weight, bias, output_mask)
return grad_input, None, grad_weight, grad_bias, None, None
return output, backward
def AD_fused_dropout_backward(grad,
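With the `layer_norm` entry re-enabled, double backward through a scripted function should work end-to-end: the symbolic script supplies `native_layer_norm_backward`, and `layer_norm_double_backward` differentiates it. A minimal sketch, with warm-up count and shapes as illustrative assumptions:

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def layer_norm_fn(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return F.layer_norm(x, [4], w, b, 1e-5)

x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
w = torch.randn(4, dtype=torch.double, requires_grad=True)
b = torch.randn(4, dtype=torch.double, requires_grad=True)

for _ in range(3):  # let the profiling executor build the differentiable graph
    layer_norm_fn(x, w, b)

gx, = torch.autograd.grad(layer_norm_fn(x, w, b).sum(), x, create_graph=True)
ggx, = torch.autograd.grad(gx.sum(), x)  # second backward, through layer_norm_double_backward
```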
8 changes: 4 additions & 4 deletions torch/testing/_internal/jit_metaprogramming_utils.py
@@ -113,14 +113,14 @@
'', (False, 'aten::_batch_norm_impl_index')),
('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
('layer_norm', (S, S, S, S), ([5],), '',
(False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
(True, ['aten::native_layer_norm'])),
('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
(False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
(True, ['aten::native_layer_norm'])),
('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
(False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
(True, ['aten::native_layer_norm'])),
('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
non_differentiable(torch.rand(S))), 'with_weight_and_bias',
(False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
(True, ['aten::native_layer_norm'])),
('group_norm', (S, S, S), (1, torch.rand(5),),),
('local_response_norm', (S, S, S), (2, ),),
('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
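The updated test tuples above flip the autodiff flag to `True` and expect the differentiable graph to contain `aten::native_layer_norm` instead of the old `aten::_batch_norm_impl_index` decomposition. One way to eyeball this outside the test harness is sketched below; `torch.jit.last_executed_optimized_graph` is an internal JIT debugging helper, so treat this as a diagnostic sketch rather than stable API.

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def ln(x: torch.Tensor) -> torch.Tensor:
    return F.layer_norm(x, [4])

x = torch.randn(3, 4, requires_grad=True)
for _ in range(3):  # warm up the profiling executor
    ln(x).sum().backward()

graph_str = str(torch.jit.last_executed_optimized_graph())
assert "native_layer_norm" in graph_str
assert "_batch_norm_impl_index" not in graph_str
```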
