Commit
add Cosine Similarity Backward function.
xutianbing committed Feb 8, 2017
1 parent 9ee7236 commit ccac20d
Showing 4 changed files with 352 additions and 3 deletions.
122 changes: 122 additions & 0 deletions paddle/function/CosSimOp.cpp
@@ -86,8 +86,130 @@ class CosSimForwardFunc : public FunctionBase {
real scale_;
};

template <>
void CosSimBackward<DEVICE_TYPE_CPU>(const CpuMatrix* out_grad,
const CpuMatrix* out_val,
const CpuMatrix* in1_val,
const CpuMatrix* in2_val,
CpuMatrix* in1_grad,
CpuMatrix* in2_grad,
real scale) {
CHECK(out_grad && out_val && in1_val && in2_val && in1_grad && in2_grad);
CHECK_EQ(out_val->useGpu_, false) << "Matrix type is GPU, CPU required";

const real* grad = out_grad->getData();
const real* out = out_val->getData();
const real* prev_out_x = in1_val->getData();
const real* prev_out_y = in2_val->getData();
real* prev_grad_x = in1_grad->getData();
real* prev_grad_y = in2_grad->getData();

size_t num_samples = out_grad->getHeight();
size_t dim = in1_val->getWidth();
CHECK_EQ(in2_val->getHeight(), in2_grad->getHeight());
CHECK(in2_val->getHeight() == 1LU || in2_val->getHeight() == num_samples);
size_t inc = (in2_val->getHeight() == 1LU) ? 0 : dim;
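// when input 2 has a single row, inc == 0, so every sample reads from and
// accumulates into that same shared row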
for (size_t i = 0; i < num_samples; ++i,
prev_out_x += dim,
prev_out_y += inc,
prev_grad_x += dim,
prev_grad_y += inc) {
real square_sum_x = 0;
real square_sum_y = 0;
real xy = 0;
for (size_t j = 0; j < dim; ++j) {
square_sum_x += prev_out_x[j] * prev_out_x[j];
square_sum_y += prev_out_y[j] * prev_out_y[j];
xy += prev_out_x[j] * prev_out_y[j];
}
CHECK(square_sum_x > 0 && square_sum_y > 0);
if (xy == 0) {
real reciprocal =
1.0f / (std::sqrt(square_sum_x) * std::sqrt(square_sum_y));
for (size_t j = 0; j < dim; ++j) {
prev_grad_x[j] += scale * grad[i] * prev_out_y[j] * reciprocal;
prev_grad_y[j] += scale * grad[i] * prev_out_x[j] * reciprocal;
}
} else {
real reciprocal_xy = 1.0f / xy;
real reciprocal_square_sum_x = 1.0f / square_sum_x;
real reciprocal_square_sum_y = 1.0f / square_sum_y;
for (size_t j = 0; j < dim; ++j) {
prev_grad_x[j] +=
out[i] * grad[i] * (prev_out_y[j] * reciprocal_xy -
prev_out_x[j] * reciprocal_square_sum_x);
prev_grad_y[j] +=
out[i] * grad[i] * (prev_out_x[j] * reciprocal_xy -
prev_out_y[j] * reciprocal_square_sum_y);
}
}
}
}
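For reference, the gradient this loop accumulates, derived here (not part of the commit) from the forward definition \( \text{out} = \text{scale} \cdot \frac{x \cdot y}{\|x\|\,\|y\|} \):

\[
\frac{\partial\,\text{out}}{\partial x_j}
  = \text{scale}\left(\frac{y_j}{\|x\|\,\|y\|} - \frac{(x \cdot y)\,x_j}{\|x\|^3\,\|y\|}\right)
  = \text{out}\left(\frac{y_j}{x \cdot y} - \frac{x_j}{\|x\|^2}\right),
\]

and symmetrically for \( y_j \). The factored form on the right is what the else branch computes; it divides by \( x \cdot y \), so when \( x \cdot y = 0 \) (where out is also 0) the code falls back to the unfactored form, which collapses to \( \text{scale} \cdot y_j / (\|x\|\,\|y\|) \), exactly the xy == 0 branch.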

/**
 * \param inputs[0] forward output value, size: nSamples * 1.
 * \param inputs[1] forward input value 1, size: nSamples * dim.
 * \param inputs[2] forward input value 2, size: n2 * dim (n2 == 1 or n2 == nSamples).
 * \param inputs[3] gradient of forward input 1, size: nSamples * dim.
 * \param inputs[4] gradient of forward input 2, size: n2 * dim (n2 == 1 or n2 == nSamples).
 * \param outputs[0] gradient of forward output, size: nSamples * 1.
 */
template <DeviceType Device>
class CosSimBackwardFunc : public FunctionBase {
void init(const FuncConfig& config) override {
scale_ = config.get<real>("scale");
}

void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
CHECK_EQ(inputs.size(), 5);
CHECK_EQ(outputs.size(), 1);
CHECK_EQ(inouts.size(), 0);
/// dim of out_grad and out_val == 1, column vector
CHECK_EQ(outputs[0].dims_[1], 1UL);
CHECK_EQ(inputs[0].dims_[1], 1UL);
/// nSamples of out_grad == out_val == in_val1 == in_grad1
CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
CHECK_EQ(inputs[1].dims_[0], outputs[0].dims_[0]);
CHECK_EQ(inputs[3].dims_[0], outputs[0].dims_[0]);
/// dim of in_val1 == in_val2 == in_grad1 == in_grad2
CHECK_EQ(inputs[2].dims_[1], inputs[1].dims_[1]);
CHECK_EQ(inputs[3].dims_[1], inputs[1].dims_[1]);
CHECK_EQ(inputs[4].dims_[1], inputs[1].dims_[1]);

CHECK(outputs[0].getData() && inputs[0].getData() && inputs[1].getData() &&
inputs[2].getData() && inputs[3].getData() && inputs[4].getData());
const auto out_grad = std::make_shared<typename MatrixT<Device>::type>(
outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
const auto out_val = std::make_shared<typename MatrixT<Device>::type>(
inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
const auto in1_val = std::make_shared<typename MatrixT<Device>::type>(
inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]);
const auto in2_val = std::make_shared<typename MatrixT<Device>::type>(
inputs[2].getData(), inputs[2].dims_[0], inputs[2].dims_[1]);
auto in1_grad = std::make_shared<typename MatrixT<Device>::type>(
inputs[3].getData(), inputs[3].dims_[0], inputs[3].dims_[1]);
auto in2_grad = std::make_shared<typename MatrixT<Device>::type>(
inputs[4].getData(), inputs[4].dims_[0], inputs[4].dims_[1]);

CosSimBackward<Device>(out_grad.get(),
out_val.get(),
in1_val.get(),
in2_val.get(),
in1_grad.get(),
in2_grad.get(),
scale_);
}

private:
real scale_;
};

REGISTER_TYPED_FUNC(CosSimForward, CPU, CosSimForwardFunc);
REGISTER_TYPED_FUNC(CosSimBackward, CPU, CosSimBackwardFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(CosSimForward, GPU, CosSimForwardFunc);
REGISTER_TYPED_FUNC(CosSimBackward, GPU, CosSimBackwardFunc);
#endif
} // namespace paddle
21 changes: 21 additions & 0 deletions paddle/function/CosSimOp.h
@@ -37,4 +37,25 @@ void CosSimForward(typename MatrixT<Device>::type* output,
const typename MatrixT<Device>::type* input2,
real scale);

/**
 * \brief Backward function of Cosine Similarity.
 *
 * \param[in]  out_grad  gradient of the forward output.
 * \param[in]  out_value forward output value.
 * \param[in]  in1_value forward input value 1.
 * \param[in]  in2_value forward input value 2.
 * \param[out] in1_grad  gradient of forward input 1.
 * \param[out] in2_grad  gradient of forward input 2.
 * \param[in]  scale     scaling factor, default 1.0.
 *
 */
template <DeviceType Device>
void CosSimBackward(const typename MatrixT<Device>::type* out_grad,
const typename MatrixT<Device>::type* out_value,
const typename MatrixT<Device>::type* in1_value,
const typename MatrixT<Device>::type* in2_value,
typename MatrixT<Device>::type* in1_grad,
typename MatrixT<Device>::type* in2_grad,
real scale);

} // namespace paddle
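For illustration, a minimal CPU-side sketch of calling the functions declared above. This is not part of the commit: the shapes are made up, and the CpuMatrix(height, width) constructor and the randomizeUniform()/zeroMem() helpers are assumed to come from paddle/math/Matrix.h.

#include "paddle/function/CosSimOp.h"
#include "paddle/math/Matrix.h"

using namespace paddle;  // NOLINT

void cosSimBackwardExample() {
  const size_t numSamples = 4;
  const size_t dim = 8;

  CpuMatrix in1Val(numSamples, dim);   // forward input 1
  CpuMatrix in2Val(numSamples, dim);   // forward input 2 (1 x dim would broadcast)
  CpuMatrix outVal(numSamples, 1);     // forward output value
  CpuMatrix outGrad(numSamples, 1);    // gradient arriving at the forward output
  CpuMatrix in1Grad(numSamples, dim);  // gradient of forward input 1
  CpuMatrix in2Grad(numSamples, dim);  // gradient of forward input 2

  in1Val.randomizeUniform();
  in2Val.randomizeUniform();
  outGrad.randomizeUniform();
  in1Grad.zeroMem();
  in2Grad.zeroMem();

  // Run forward first so out_value is consistent with the inputs,
  // then accumulate the input gradients.
  CosSimForward<DEVICE_TYPE_CPU>(&outVal, &in1Val, &in2Val, /*scale=*/1.0);
  CosSimBackward<DEVICE_TYPE_CPU>(&outGrad, &outVal, &in1Val, &in2Val,
                                  &in1Grad, &in2Grad, /*scale=*/1.0);
}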
142 changes: 140 additions & 2 deletions paddle/function/CosSimOpGpu.cu
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_base.h"
#include "hl_device_functions.cuh"
#include "CosSimOp.h"

namespace paddle {
@@ -79,7 +80,7 @@ void hlCossim(real* output,

KeCosSim<block_size><<<grid, threads, 0, STREAM_DEFAULT>>>
(output, input1, input2, width, input1_height, input2_height, scale);
CHECK_SYNC("hl_cossim failed");
CHECK_SYNC("hlCossim failed");
}

template <>
@@ -91,12 +92,149 @@ void CosSimForward<DEVICE_TYPE_GPU>(GpuMatrix* out_mat,
CHECK(in1_mat->useGpu_ == true && in2_mat->useGpu_ == true)
<< "Matrix type are not GPU";

-  size_t numSamples = out_mat->getHeight();
+  size_t num_samples = out_mat->getHeight();
size_t dim = in1_mat->getWidth();
real* out = out_mat->getData();
const real* x = in1_mat->getData();
const real* y = in2_mat->getData();
hlCossim(out, x, y, dim, in1_mat->getHeight(), in2_mat->getHeight(), scale);
}

template<int block_size>
__global__ void KeCosSimDerivative(const real* grad,
const real* output,
const real* prev_out_x,
const real* prev_out_y,
real* prev_grad_x,
real* prev_grad_y,
size_t width,
size_t input1_height,
size_t input2_height,
real scale) {
const int ty = blockIdx.y;
int tid = threadIdx.x;

__shared__ real xx[block_size];
__shared__ real yy[block_size];
__shared__ real xy[block_size];

xx[tid] = 0.0;
yy[tid] = 0.0;
xy[tid] = 0.0;
__syncthreads();

prev_out_x += ty * width;
prev_grad_x += ty * width;
if (input2_height > 1) {
prev_out_y += ty * width;
prev_grad_y += ty * width;
}
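// per-thread partial sums of x*x, y*y and x*y for this row, strided by block_size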
for (int index = tid; index < width; index += block_size) {
real x = prev_out_x[index];
real y = prev_out_y[index];
xx[tid] += x * x;
yy[tid] += y * y;
xy[tid] += x * y;
}
__syncthreads();

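// tree reduction in shared memory; afterwards xx[0], yy[0] and xy[0] hold the complete row sums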
for (int s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
xx[tid] += xx[tid + s];
yy[tid] += yy[tid + s];
xy[tid] += xy[tid + s];
}
__syncthreads();
}
if (xy[0] == 0) {
real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0]));
for (int index = tid; index < width; index += block_size) {
prev_grad_x[index] +=
scale * grad[ty] * prev_out_y[index] * reciprocal;
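// with a shared second input (input2_height == 1), every block updates the
// same prev_grad_y row, hence the atomic add in the else branch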
if (input2_height > 1) {
prev_grad_y[index] +=
scale * grad[ty] * prev_out_x[index] * reciprocal;
} else {
paddle::paddleAtomicAdd(prev_grad_y + index,
scale * grad[ty] * prev_out_x[index] * reciprocal);
}
}
} else {
real reciprocalXY = 1.0 / xy[0];
real reciprocalSquareSumX = 1.0 / xx[0];
real reciprocalSquareSumY = 1.0 / yy[0];
for (int index = tid; index < width; index += block_size) {
prev_grad_x[index] += output[ty] * grad[ty] *
(prev_out_y[index] * reciprocalXY -
prev_out_x[index] * reciprocalSquareSumX);
if (input2_height > 1) {
prev_grad_y[index] += output[ty] * grad[ty] *
(prev_out_x[index] * reciprocalXY -
prev_out_y[index] * reciprocalSquareSumY);
} else {
paddle::paddleAtomicAdd(prev_grad_y + index, output[ty] * grad[ty] *
(prev_out_x[index] * reciprocalXY -
prev_out_y[index] * reciprocalSquareSumY));
}
}
}
}

void hlCossimDerivative(const real* grad,
const real* output,
const real* prev_out_x,
const real* prev_out_y,
real* prev_grad_x,
real* prev_grad_y,
size_t width,
size_t input1_height,
size_t input2_height,
real scale) {
CHECK_NOTNULL(grad);
CHECK_NOTNULL(output);
CHECK_NOTNULL(prev_out_x);
CHECK_NOTNULL(prev_out_y);
CHECK_NOTNULL(prev_grad_x);
CHECK_NOTNULL(prev_grad_y);
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, input1_height);
KeCosSimDerivative<block_size><<<grid, threads, 0, STREAM_DEFAULT>>>
(grad, output, prev_out_x, prev_out_y, prev_grad_x, prev_grad_y, width,
input1_height, input2_height, scale);
CHECK_SYNC("hlCossimDerivate failed");
}

template <>
void CosSimBackward<DEVICE_TYPE_GPU>(const GpuMatrix* out_grad,
const GpuMatrix* out_val,
const GpuMatrix* in1_val,
const GpuMatrix* in2_val,
GpuMatrix* in1_grad,
GpuMatrix* in2_grad,
real scale) {
CHECK(out_grad && out_val && in1_val && in2_val && in1_grad && in2_grad);
CHECK(out_grad->useGpu_ && out_val->useGpu_ && in1_val->useGpu_
&& in2_val->useGpu_ && in1_grad->useGpu_ && in2_grad->useGpu_)
<< "Matrix types are not equally GPU";

size_t dim = in1_val->getWidth();
const real* grad = out_grad->getData();
const real* out = out_val->getData();
const real* prev_out_x = in1_val->getData();
const real* prev_out_y = in2_val->getData();
real* prev_grad_x = in1_grad->getData();
real* prev_grad_y = in2_grad->getData();
hlCossimDerivative(grad,
out,
prev_out_x,
prev_out_y,
prev_grad_x,
prev_grad_y,
dim,
in1_val->getHeight(),
in2_val->getHeight(),
scale);
}

} // namespace paddle
