[Prim] Add index_add_double_grad in dygraph composite #70367

Merged

@@ -84,6 +84,7 @@
"gather_nd_double_grad",
"reshape_double_grad",
"take_along_axis_double_grad",
"index_add_double_grad",
]

# white ops list whose kernel can automatically do type promotion.
@@ -36,4 +36,5 @@
'index_put_grad',
'gather_nd_grad',
'take_along_axis_grad',
'index_add_grad',
]
1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
@@ -52,6 +52,7 @@
- sign
- sigmoid
- index_put
- index_add
- isnan
- isfinite
- take_along_axis
@@ -1107,5 +1107,43 @@ void take_along_axis_double_grad(const Tensor& indices,
}
}

template <typename T>
void index_add_double_grad(const Tensor& index,
const Tensor& out_grad,
const paddle::optional<Tensor>& grad_x_grad,
const paddle::optional<Tensor>& grad_add_value_grad,
int axis,
Tensor* grad_out_grad) {
if (grad_out_grad) {
if (grad_x_grad && grad_add_value_grad) {
Tensor grad_out_grad_tmp = grad_x_grad.get();
grad_out_grad_tmp = index_add<T>(
grad_out_grad_tmp, index, grad_add_value_grad.get(), axis);
set_output<T>(grad_out_grad_tmp, grad_out_grad);

} else if (grad_x_grad) {
Tensor grad_out_grad_tmp = grad_x_grad.get();
set_output<T>(grad_out_grad_tmp, grad_out_grad);

} else if (grad_add_value_grad) {
Tensor DDadd_value = grad_add_value_grad.get();
Tensor grad_out_grad_tmp = full<T>(common::vectorize(out_grad.dims()),
0,
DDadd_value.dtype(),
DDadd_value.place());
grad_out_grad_tmp =
index_add<T>(grad_out_grad_tmp, index, DDadd_value, axis);
set_output<T>(grad_out_grad_tmp, grad_out_grad);

} else {
Tensor grad_out_grad_tmp = full<T>(common::vectorize(out_grad.dims()),
0,
out_grad.dtype(),
out_grad.place());
set_output<T>(grad_out_grad_tmp, grad_out_grad);
}
}
}

} // namespace prim
} // namespace paddle
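The composite rule above expresses the second-order gradient of index_add with first-order ops: grad_out_grad is grad_x_grad with grad_add_value_grad scatter-added at index along axis, and whichever operand is absent is treated as zeros. A rough NumPy sketch of the same branching (np_index_add below is a hypothetical reference helper, not a Paddle API; it assumes a non-negative axis):

```python
import numpy as np

def np_index_add(x, index, add_value, axis):
    # Hypothetical reference scatter-add: out[..., index[i], ...] += add_value[..., i, ...]
    out = x.copy()
    for i, idx in enumerate(index):
        dst = [slice(None)] * out.ndim
        dst[axis] = idx
        src = [slice(None)] * add_value.ndim
        src[axis] = i
        out[tuple(dst)] += add_value[tuple(src)]
    return out

def index_add_double_grad_ref(index, out_grad, grad_x_grad, grad_add_value_grad, axis):
    # Mirrors the four branches of the C++ composite rule above.
    if grad_x_grad is None and grad_add_value_grad is None:
        return np.zeros_like(out_grad)
    base = grad_x_grad.copy() if grad_x_grad is not None else np.zeros_like(out_grad)
    if grad_add_value_grad is None:
        return base
    return np_index_add(base, index, grad_add_value_grad, axis)
```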
24 changes: 14 additions & 10 deletions paddle/phi/kernels/cpu/index_add_grad_kernel.cc
@@ -45,17 +45,21 @@ void IndexAddGradKernel(const Context& ctx,
phi::DataType::INT64));

// get x_grad: copy out_grad to x_grad.
- ctx.template Alloc<T>(x_grad);
- phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
+ if (x_grad) {
+ ctx.template Alloc<T>(x_grad);
+ phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
+ }

- auto inputs = out_grad;
- // get add_value_grad by using index_select(out_grad, index, axis)
- if (index_type == phi::DataType::INT32) {
- IndexSelectInner<Context, T, int>(
- ctx, &inputs, index, add_value_grad, axis);
- } else if (index_type == phi::DataType::INT64) {
- IndexSelectInner<Context, T, int64_t>(
- ctx, &inputs, index, add_value_grad, axis);
+ if (add_value_grad) {
+ auto inputs = out_grad;
+ // get add_value_grad by using index_select(out_grad, index, axis)
+ if (index_type == phi::DataType::INT32) {
+ IndexSelectInner<Context, T, int>(
+ ctx, &inputs, index, add_value_grad, axis);
+ } else if (index_type == phi::DataType::INT64) {
+ IndexSelectInner<Context, T, int64_t>(
+ ctx, &inputs, index, add_value_grad, axis);
+ }
}
}

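In effect, the first-order kernel above (and the GPU/XPU variants below) computes x_grad as a plain copy of out_grad and add_value_grad as index_select(out_grad, index, axis); the new if (x_grad) / if (add_value_grad) guards simply skip whichever output the caller did not request. A small NumPy sketch of that relationship, offered only as a reference for intuition, not Paddle code:

```python
import numpy as np

def index_add_grad_ref(out_grad, index, axis, need_x_grad=True, need_add_value_grad=True):
    # x_grad: out_grad passes through unchanged (index_add adds into x, so d out / d x = I).
    x_grad = out_grad.copy() if need_x_grad else None
    # add_value_grad: gather the indexed slices of out_grad, i.e. index_select(out_grad, index, axis).
    add_value_grad = np.take(out_grad, index, axis=axis) if need_add_value_grad else None
    return x_grad, add_value_grad
```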
70 changes: 36 additions & 34 deletions paddle/phi/kernels/gpu/index_add_grad_kernel.cu
@@ -34,13 +34,9 @@ void IndexAddGradKernel(const Context& ctx,
int dim,
DenseTensor* x_grad,
DenseTensor* add_value_grad) {
- auto* output_grad_data = out_grad.data<T>();
- auto* in_grad_data = ctx.template Alloc<T>(x_grad);
- auto* add_value_grad_data = ctx.template Alloc<T>(add_value_grad);
-
- auto input_dim = x_grad->dims();
- auto output_dim = out_grad.dims();
- auto add_value_dim = add_value_grad->dims();
+ // x.shape == out.shape in index_grad op
+ auto input_dim = out_grad.dims();
+ auto add_value_dim = add_value.dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = common::stride(input_dim);
int64_t stride = stride_dim[dim];
@@ -59,42 +55,48 @@
phi::DataType::INT32,
phi::DataType::INT64));

- int64_t numel = add_value_grad->numel();
+ int64_t numel = add_value.numel();
if (numel == 0) {
return;
}
auto stream = ctx.stream();

// get x_grad: copy out_grad to x_grad.
- phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
+ if (x_grad) {
+ phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
+ }

// get add_value_grad: index_select(out_grad, index, axis)
- unsigned int block_dim = PADDLE_CUDA_NUM_THREADS;
- dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim);
- phi::backends::gpu::LimitGridDim(ctx, &grid_dim);
+ if (add_value_grad) {
+ auto* output_grad_data = out_grad.data<T>();
+ auto* add_value_grad_data = ctx.template Alloc<T>(add_value_grad);
+ unsigned int block_dim = PADDLE_CUDA_NUM_THREADS;
+ dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim);
+ phi::backends::gpu::LimitGridDim(ctx, &grid_dim);

- if (index_type == phi::DataType::INT64) {
- const int64_t* index_data = index.data<int64_t>();
- index_select_cuda_kernel<T, int64_t>
- <<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
- add_value_grad_data,
- index_data,
- numel,
- stride,
- size,
- delta,
- input_dim[dim]);
- } else {
- const int* index_data = index.data<int>();
- index_select_cuda_kernel<T, int>
- <<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
- add_value_grad_data,
- index_data,
- numel,
- stride,
- size,
- delta,
- input_dim[dim]);
+ if (index_type == phi::DataType::INT64) {
+ const int64_t* index_data = index.data<int64_t>();
+ index_select_cuda_kernel<T, int64_t>
+ <<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
+ add_value_grad_data,
+ index_data,
+ numel,
+ stride,
+ size,
+ delta,
+ input_dim[dim]);
+ } else {
+ const int* index_data = index.data<int>();
+ index_select_cuda_kernel<T, int>
+ <<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
+ add_value_grad_data,
+ index_data,
+ numel,
+ stride,
+ size,
+ delta,
+ input_dim[dim]);
+ }
}
}

13 changes: 11 additions & 2 deletions paddle/phi/kernels/xpu/index_add_grad_kernel.cc
@@ -29,8 +29,17 @@ void IndexAddGradKernel(const Context& ctx,
int dim,
DenseTensor* x_grad,
DenseTensor* add_value_grad) {
- phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
- phi::IndexSelectKernel<T, Context>(ctx, out_grad, index, dim, add_value_grad);
+ if (dim < 0) {
+ dim += out_grad.dims().size();
+ }
+
+ if (x_grad) {
+ phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
+ }
+ if (add_value_grad) {
+ phi::IndexSelectKernel<T, Context>(
+ ctx, out_grad, index, dim, add_value_grad);
+ }
}

} // namespace phi
13 changes: 13 additions & 0 deletions paddle/phi/ops/yaml/backward.yaml
@@ -1565,6 +1565,18 @@
func : imag_grad
data_type : complex(out_grad)

- backward_op : index_add_double_grad
forward : index_add_grad (Tensor index, Tensor add_value, Tensor grad_out, int axis) -> Tensor(grad_x), Tensor(grad_add_value)
args : (Tensor index, Tensor grad_out, Tensor grad_x_grad, Tensor grad_add_value_grad, int axis)
output : Tensor(grad_out_grad)
infer_meta :
func : UnchangedInferMeta
param: [grad_out]
data_transform :
skip_transform : index
composite : index_add_double_grad(index, grad_out, grad_x_grad, grad_add_value_grad, axis, grad_out_grad)
optional: grad_x_grad, grad_add_value_grad

- backward_op : index_add_grad
forward : index_add(Tensor x, Tensor index, Tensor add_value, int axis=0) -> Tensor(out)
args : (Tensor index, Tensor add_value, Tensor out_grad, int axis)
@@ -1575,6 +1587,7 @@
func : index_add_grad
data_type : out_grad
inplace : (out_grad -> x_grad)
backward : index_add_double_grad

- backward_op : index_put_double_grad
forward : index_put_grad (Tensor x, Tensor[] indices, Tensor value, Tensor grad_out, bool accumulate=false) -> Tensor(grad_x), Tensor(grad_value)
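With backward : index_add_double_grad registered on index_add_grad, a second differentiation through an index_add graph in dygraph now resolves to the composite rule instead of erroring out. A minimal usage sketch under that assumption (shapes and values are illustrative only):

```python
import paddle

x = paddle.randn([4, 3, 2])
x.stop_gradient = False
value = paddle.randn([4, 2, 2])
value.stop_gradient = False
index = paddle.to_tensor([0, 2], dtype='int64')

out = paddle.tanh(paddle.index_add(x, index, 1, value))

# First-order grad; create_graph=True keeps the backward graph differentiable.
(dx,) = paddle.grad(out, x, grad_outputs=paddle.ones_like(out), create_graph=True)
# Second-order grad w.r.t. x goes through index_add_double_grad.
(ddx,) = paddle.grad(dx, x, grad_outputs=paddle.ones_like(dx))
```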
12 changes: 9 additions & 3 deletions test/legacy_test/test_index_add_op.py
@@ -258,20 +258,26 @@ def run_imperative(self, device):

if self.check_backward:
dout_tensor = paddle.to_tensor(self.dout_np)
- paddle.autograd.backward([out], [dout_tensor], retain_graph=True)
+ (input_tensor_grad,) = paddle.autograd.grad(
+ [out], [input_tensor], dout_tensor
+ )
+ (add_value_grad,) = paddle.autograd.grad(
+ [out], [add_value], dout_tensor
+ )

(
ref_x_grad,
ref_add_value_grad,
) = self.compute_index_add_backward_ref()
np.testing.assert_allclose(
ref_x_grad,
- input_tensor.grad.numpy(),
+ input_tensor_grad.numpy(),
rtol=self.rtol,
atol=self.atol,
)
np.testing.assert_allclose(
ref_add_value_grad,
- add_value.grad.numpy(),
+ add_value_grad.numpy(),
rtol=self.rtol,
atol=self.atol,
)
58 changes: 58 additions & 0 deletions test/prim/prim/vjp/eager/test_comp_eager_index_add_double_grad.py
@@ -0,0 +1,58 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import parameterized as param

import paddle


@param.parameterized_class(
('x', 'index', 'axis', 'value', 'cotangent', 'dtype'),
[
(
np.random.randn(4, 3, 2), # x
np.random.randint(-3, 3, size=(16,)), # index
1, # axis
np.random.randint(0, 3, size=(4, 16, 2)), # value
np.random.rand(4, 3, 2), # cotangent
np.float32, # dtype
),
],
)
class TestIndexAddTanhDoubleGrad(unittest.TestCase):
def test_index_add_tanh_double_grad(self):
x_tensor = paddle.to_tensor(
self.x, dtype=self.dtype, stop_gradient=False
)
value_tensor = paddle.to_tensor(
self.value, dtype=self.dtype, stop_gradient=False
)
index_tensor = paddle.to_tensor(self.index, dtype="int64")
dout_tensor = paddle.to_tensor(
self.cotangent, dtype=self.dtype, stop_gradient=False
)
out = paddle.index_add(x_tensor, index_tensor, self.axis, value_tensor)

out = paddle.tanh(out)

dx = paddle.grad(out, x_tensor, dout_tensor, create_graph=True)[0]

ddx = paddle.grad(dx, dout_tensor, create_graph=True)[0]


if __name__ == '__main__':
unittest.main()
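The new test is a smoke test: it only verifies that the second paddle.grad call runs through the new double-grad path. One possible value check (a sketch, not part of this PR) relies on the fact that the x_grad path of index_add_grad is an identity copy, so differentiating dx with respect to the cotangent should give tanh'(index_add(x, index, axis, value)):

```python
import numpy as np
import paddle

x = paddle.randn([4, 3, 2])
x.stop_gradient = False
value = paddle.randn([4, 2, 2])
index = paddle.to_tensor([0, 2], dtype='int64')
dout = paddle.ones([4, 3, 2])
dout.stop_gradient = False

out0 = paddle.index_add(x, index, 1, value)
out = paddle.tanh(out0)

(dx,) = paddle.grad(out, x, dout, create_graph=True)
# d(dx)/d(dout) with an all-ones cotangent should equal tanh'(out0) = 1 - tanh(out0)^2.
(ddx,) = paddle.grad(dx, dout)

np.testing.assert_allclose(
    ddx.numpy(), (1 - paddle.tanh(out0) ** 2).numpy(), rtol=1e-5, atol=1e-5
)
```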