Skip to content

Commit

Permalink
polish cudnn related code and fix bug. (PaddlePaddle#15164)
Browse files Browse the repository at this point in the history
* staged.

* polish code

* polish code. test=develop

* polish code. test=develop

* api change. test=develop

* fix default value. test=develop

* fix default value. test=develop
  • Loading branch information
dzhwinter authored and ceci3 committed Mar 4, 2019
1 parent 8e094f7 commit 4449e85
Show file tree
Hide file tree
Showing 11 changed files with 543 additions and 128 deletions.
4 changes: 4 additions & 0 deletions cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,11 @@ function(op_library TARGET)
# pybind USE_OP_DEVICE_KERNEL for CUDNN
list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0)
if(${TARGET} STREQUAL "activation")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, CUDNN);\n")
else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
endif()
endif()

# pybind USE_OP_DEVICE_KERNEL for MIOPEN
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h"
Expand Down
40 changes: 40 additions & 0 deletions paddle/fluid/operators/activation_cudnn.cu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/cudnn_desc.h"

namespace paddle {
namespace operators {
using framework::Tensor;
using platform::ActivationDescriptor;
using platform::TensorDescriptor;

template <typename Functor>
class CudnnActivationKernel
: public framework::OpKernel<Functor::ElEWISE_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
framework::Tensor *X, *Out;
ExtractActivationTensor(context, X, Out);
ActivationDescriptor act_desc;
TensorDescriptor x_desc, out_desc;
x_desc.set(detail::Ref(X));
out_desc.set(detail::Ref(Out));
}
};

} // namespace operators
} // namespace paddle
175 changes: 175 additions & 0 deletions paddle/fluid/operators/activation_cudnn_op.cu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/cudnn_desc.h"

namespace paddle {
namespace operators {
using framework::Tensor;
using platform::ActivationDescriptor;
using platform::TensorDescriptor;
using platform::CUDADeviceContext;

template <typename T>
struct CudnnActivationFunctor {
using ELEMENT_TYPE = T;
CudnnActivationFunctor(const CUDADeviceContext& ctx, const T& c,
const cudnnActivationMode_t& m)
: ctx_(ctx), coef_(c), mode_(m) {}
void operator()(const Tensor& x, Tensor* out) {
ActivationDescriptor act_desc;
act_desc.set(mode_, coef_);
TensorDescriptor x_desc, out_desc;
x_desc.set(x);
out_desc.set(detail::Ref(out));
PADDLE_ENFORCE(platform::dynload::cudnnActivationForward(
ctx_.cudnn_handle(), act_desc.desc(),
platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
platform::CudnnDataType<T>::kZero(), out_desc.desc(),
out->mutable_data<T>(ctx_.GetPlace())));
}
const CUDADeviceContext& ctx_;
const T coef_;
const cudnnActivationMode_t mode_;
};

template <typename T>
struct CudnnActivationGradFunctor {
using ELEMENT_TYPE = T;
CudnnActivationGradFunctor(const CUDADeviceContext& ctx, const T& c,
const cudnnActivationMode_t& m)
: ctx_(ctx), coef_(c), mode_(m) {}
void operator()(const Tensor& x, const Tensor& out, const Tensor dout,
Tensor* dx) {
ActivationDescriptor act_desc;
act_desc.set(mode_, coef_);
TensorDescriptor x_desc, out_desc, dout_desc, dx_desc;
x_desc.set(x);
out_desc.set(out);
dout_desc.set(dout);
dx_desc.set(detail::Ref(dx));
PADDLE_ENFORCE(platform::dynload::cudnnActivationBackward(
ctx_.cudnn_handle(), act_desc.desc(),
platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
platform::CudnnDataType<T>::kZero(), dx_desc.desc(),
dx->mutable_data<T>(ctx_.GetPlace())));
}
const CUDADeviceContext& ctx_;
const T coef_;
const cudnnActivationMode_t mode_;
};

template <typename T>
struct CudnnReluFunctor : public CudnnActivationFunctor<T> {
explicit CudnnReluFunctor(const CUDADeviceContext& ctx)
: CudnnActivationFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
};
template <typename T>
struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
};

template <typename T>
struct CudnnRelu6Functor : public CudnnActivationFunctor<T> {
explicit CudnnRelu6Functor(const CUDADeviceContext& ctx)
: CudnnActivationFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {}
};
template <typename T>
struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {
}
};

template <typename T>
struct CudnnSigmoidFunctor : public CudnnActivationFunctor<T> {
explicit CudnnSigmoidFunctor(const CUDADeviceContext& ctx)
: CudnnActivationFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
};
template <typename T>
struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
};

template <typename T>
struct CudnnTanhFunctor : public CudnnActivationFunctor<T> {
explicit CudnnTanhFunctor(const CUDADeviceContext& ctx)
: CudnnActivationFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
};
template <typename T>
struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
};

template <typename Functor>
class CudnnActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* X = nullptr;
framework::Tensor* Out = nullptr;
ExtractActivationTensor(context, &X, &Out);
Out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<CUDADeviceContext>();
Functor functor(dev_ctx);
functor(detail::Ref(X), Out);
}
};

template <typename Functor>
class CudnnActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor *X, *Out, *dOut;
X = Out = dOut = nullptr;
framework::Tensor* dX = nullptr;
ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX);
dX->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<CUDADeviceContext>();
Functor functor(dev_ctx);
functor(detail::Ref(X), detail::Ref(Out), detail::Ref(dOut), dX);
}
};

} // namespace operators
} // namespace paddle

namespace plat = paddle::platform;
namespace ops = paddle::operators;

#define FOR_EACH_CUDNN_OP_FUNCTOR(__macro) \
__macro(relu, CudnnReluFunctor, CudnnReluGradFunctor); \
__macro(relu6, CudnnRelu6Functor, CudnnRelu6GradFunctor); \
__macro(sigmoid, CudnnTanhFunctor, CudnnTanhGradFunctor); \
__macro(tanh, CudnnTanhFunctor, CudnnTanhGradFunctor)

#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace, \
ops::CudnnActivationKernel<ops::functor<float>>, \
ops::CudnnActivationKernel<ops::functor<double>>); \
REGISTER_OP_KERNEL( \
act_type##_grad, CUDNN, plat::CUDAPlace, \
ops::CudnnActivationGradKernel<ops::grad_functor<float>>, \
ops::CudnnActivationGradKernel<ops::grad_functor<double>>);

FOR_EACH_CUDNN_OP_FUNCTOR(REGISTER_ACTIVATION_CUDNN_KERNEL);
47 changes: 30 additions & 17 deletions paddle/fluid/operators/activation_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,36 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \
void Make() override { \
AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator"); \
AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \
AddAttr<bool>( \
"is_test", \
"(bool, default false) Set to true for inference only, false " \
"for training. Some layers may run faster when this is true.") \
.SetDefault(false); \
AddComment(OP_COMMENT); \
} \
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \
void Make() override { \
AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator"); \
AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \
AddAttr<bool>("use_cudnn", \
"(bool, default false) Only used in cudnn kernel, need " \
"install cudnn") \
.SetDefault(false); \
AddAttr<bool>( \
"is_test", \
"(bool, default false) Set to true for inference only, false " \
"for training. Some layers may run faster when this is true.") \
.SetDefault(false); \
AddComment(OP_COMMENT); \
} \
}

#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
Expand Down Expand Up @@ -67,6 +74,12 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_CUDA
auto it1 = oper.Attrs().find("use_cudnn");
if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
library = framework::LibraryType::kCUDNN;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
auto it = oper.Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
Expand Down
Loading

0 comments on commit 4449e85

Please sign in to comment.