Merge pull request PaddlePaddle#15844 from panyx0718/infer
add per kernel config and remove const_cast.
panyx0718 authored Feb 25, 2019
2 parents dec9cf5 + 5dd281f commit 44e7fcd
Showing 14 changed files with 251 additions and 192 deletions.
18 changes: 16 additions & 2 deletions paddle/fluid/framework/operator.cc
@@ -904,6 +904,16 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
this->InferShape(&infer_shape_ctx);
}

std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
const OpKernelType& key) const {
auto config_iter = kernel_configs_map_.find(key);
std::vector<KernelConfig>* kernel_configs = nullptr;
if (config_iter != kernel_configs_map_.end()) {
kernel_configs = &(config_iter->second);
}
return kernel_configs;
}

void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const {
RuntimeContext ctx(Inputs(), Outputs(), scope);
@@ -921,7 +931,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelMap& kernels = kernels_iter->second;

auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx));
ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;

auto kernel_iter = kernels.find(expected_kernel_key);
@@ -940,6 +950,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
KernelTypeToString(expected_kernel_key));
}

std::vector<KernelConfig>* kernel_configs =
GetKernelConfig(expected_kernel_key);

// do data transform; transfer_scope below holds the transformed variables
std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope =
@@ -957,7 +970,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
this->InferShape(&infer_shape_ctx);
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
// not Scope. Imperative mode only passes inputs and gets outputs.
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, ctx));
kernel_iter->second(
ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs));

if (!transfered_inplace_vars.empty()) {
// some in-place variables have been transferred.
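For reference, a minimal sketch (not part of this diff) of how a kernel's Compute body could consume the configs that RunImpl now threads through ExecutionContext. It relies on the GetKernelConfig<T> accessor added to ExecutionContext in operator.h below; the cache type and the index 0 are assumptions for illustration.

#include "paddle/fluid/framework/operator.h"

#ifdef PADDLE_WITH_CUDA
// Hypothetical helper, assuming the operator stored one
// AlgorithmsCache<cudnnConvolutionFwdAlgo_t> at index 0 of its config vector.
void ComputeWithCachedAlgo(const paddle::framework::ExecutionContext& ctx) {
  using Cache = paddle::framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>;
  // Throws via PADDLE_ENFORCE if no config was registered at this index.
  Cache& algo_cache = ctx.GetKernelConfig<Cache>(0);
  // The kernel would now call algo_cache.GetAlgorithm(...) with its shapes
  // and a cuDNN search functor, instead of stashing the cache in a Scope
  // variable behind a const_cast.
  (void)algo_cache;
}
#endif  // PADDLE_WITH_CUDA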
37 changes: 35 additions & 2 deletions paddle/fluid/framework/operator.h
@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
@@ -184,12 +185,30 @@ class OperatorBase {
const platform::Place& place) const = 0;
};

#ifdef PADDLE_WITH_CUDA
using KernelConfig = boost::variant<
std::shared_ptr<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>,
std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>,
std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>>;
#else
using KernelConfig = boost::variant<boost::blank>;
#endif

using OpKernelConfigsMap =
std::unordered_map<OpKernelType, std::vector<KernelConfig>,
OpKernelType::Hash>;

class ExecutionContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext& device_context,
const RuntimeContext& ctx)
: op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
const RuntimeContext& ctx,
std::vector<KernelConfig>* configs)
: op_(op),
scope_(scope),
device_context_(device_context),
ctx_(ctx),
kernel_configs_(configs) {}

const OperatorBase& op() const { return op_; }

@@ -398,11 +417,20 @@ class ExecutionContext {
return temp_tensor;
}

template <typename T>
T& GetKernelConfig(int idx) const {
PADDLE_ENFORCE(kernel_configs_ && kernel_configs_->size() > idx,
"%s selected kernel doesn't have kernel config %lu <= %d",
op_.Type().c_str(), kernel_configs_->size(), idx);
return *boost::get<std::shared_ptr<T>>(kernel_configs_->at(idx));
}

private:
const OperatorBase& op_;
const Scope& scope_;
const platform::DeviceContext& device_context_;
const RuntimeContext& ctx_;
mutable std::vector<KernelConfig>* kernel_configs_;
};

template <>
@@ -483,6 +511,8 @@ class OperatorWithKernel : public OperatorBase {

virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;

std::vector<KernelConfig>* GetKernelConfig(const OpKernelType& key) const;

protected:
virtual OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
@@ -508,6 +538,9 @@
void TransferInplaceVarsBack(const Scope& scope,
const std::vector<std::string>& inplace_vars,
const Scope& exec_scope) const;

protected:
mutable OpKernelConfigsMap kernel_configs_map_;
};

extern bool OpSupportGPU(const std::string& op_type);
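A hedged sketch of the producer side: a derived operator seeding the new kernel_configs_map_ so that GetKernelConfig() returns a non-null vector at RunImpl time. The cuDNN convolution kernels are the intended users of the cache types listed in the KernelConfig variant above; the name "HypotheticalConvOp", the FP32 key, and the single cache at index 0 are illustrative assumptions, not code from this PR (InferShape and kernel registration are omitted).

#include <memory>
#include "paddle/fluid/framework/operator.h"

#ifdef PADDLE_WITH_CUDA
class HypotheticalConvOp : public paddle::framework::OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

  paddle::framework::OpKernelType GetExpectedKernelType(
      const paddle::framework::ExecutionContext& ctx) const override {
    paddle::framework::OpKernelType key(
        paddle::framework::proto::VarType::FP32, ctx.GetPlace());
    // kernel_configs_map_ is mutable, so a const member may lazily fill it.
    auto& configs = kernel_configs_map_[key];
    if (configs.empty()) {
      configs.push_back(std::make_shared<
          paddle::framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>());
    }
    return key;
  }
};
#endif  // PADDLE_WITH_CUDA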
118 changes: 118 additions & 0 deletions paddle/fluid/framework/operator_kernel_configs.h
@@ -0,0 +1,118 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <climits>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

namespace paddle {
namespace framework {

// Not thread-safe. Should be owned per-kernel.
template <typename TAlgorithm>
class AlgorithmsCache {
public:
AlgorithmsCache() : search_times_(0) { hash_.clear(); }
// Caches the best algorithm for a given
// combination of tensor dimensions & compute data type.
TAlgorithm GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations,
int algorithmFlags, // can set for different data type
std::function<TAlgorithm()> gen_func);

TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func);

private:
std::unordered_map<int64_t, TAlgorithm> hash_;
int search_times_;
};

template <typename TAlgorithm>
TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations, int algorithmFlags,
std::function<TAlgorithm()> gen_func) {
int64_t seed = 0;
// Hash all of the inputs; the combined hash is used to look up a previously
// discovered algorithm, or to fall back to generating a new one.
std::hash<int64_t> hashFn;
// Combine the hashes the way boost::hash_combine does:
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
for (const auto num : dims1) {
seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

for (const auto num : dims2) {
seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1;
}

for (const auto num : strides) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 2;
}

for (const auto num : paddings) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 3;
}

for (const auto num : dilations) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 4;
}

seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 +
(seed << 6) + (seed >> 2) + 5;

if (seed == 0) return gen_func();

if (hash_.find(seed) == hash_.end()) {
TAlgorithm value = gen_func();
hash_[seed] = value;
}
return hash_[seed];
}

template <typename TAlgorithm>
TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func) {
if (hash_.find(area) != hash_.end()) {
return hash_[area];
}
if (search_times_ < search_times) {
auto algo = gen_func();
hash_[area] = algo;
++search_times_;
return algo;
}
TAlgorithm algo;
int64_t min = static_cast<int64_t>(INT_MAX);
for (const auto& m : hash_) {
if (m.first < min) {
min = m.first;
algo = m.second;
}
}
return algo;
}

} // namespace framework
} // namespace paddle
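Below is a small, self-contained illustration of the shape-keyed GetAlgorithm overload. The tensor shapes and the plain int "algorithm" stand in for the cuDNN descriptors and enums this cache is used with; all values are assumptions for the example only.

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

#include "paddle/fluid/framework/operator_kernel_configs.h"

int main() {
  paddle::framework::AlgorithmsCache<int> cache;
  int searches = 0;
  auto expensive_search = [&searches]() {
    ++searches;  // stands in for an exhaustive cuDNN algorithm search
    return 42;   // the "best" algorithm found
  };

  std::vector<int64_t> x_dims{8, 3, 224, 224}, w_dims{64, 3, 7, 7};
  std::vector<int> strides{2, 2}, paddings{3, 3}, dilations{1, 1};

  // The first call runs the search; the second call with identical shapes
  // hits the hash-combined key and returns the cached algorithm.
  int a = cache.GetAlgorithm(x_dims, w_dims, strides, paddings, dilations,
                             /*algorithmFlags=*/0, expensive_search);
  int b = cache.GetAlgorithm(x_dims, w_dims, strides, paddings, dilations,
                             /*algorithmFlags=*/0, expensive_search);
  std::cout << a << " " << b << " searches=" << searches << std::endl;
  return 0;  // prints: 42 42 searches=1
}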
5 changes: 0 additions & 5 deletions paddle/fluid/framework/var_type_traits.h
@@ -50,8 +50,6 @@ class Scope;
} // namespace framework

namespace operators {
template <typename T>
class AlgorithmsCache;

class CudnnRNNCache;

@@ -144,9 +142,6 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
#ifndef _WIN32
ncclUniqueId, platform::Communicator,
#endif
operators::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>,
operators::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>,
operators::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>,
operators::CudnnRNNCache,
#endif
int, float>;
3 changes: 2 additions & 1 deletion paddle/fluid/imperative/layer.cc
@@ -249,7 +249,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
framework::Scope scope;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
p.op.RuntimeInferShape(scope, place_, ctx);
p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
p.func(
framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr));
}
}

19 changes: 14 additions & 5 deletions paddle/fluid/imperative/layer.h
@@ -44,8 +44,13 @@ class PreparedOp {
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
framework::OperatorWithKernel::OpKernelFunc func,
platform::DeviceContext* dev_ctx)
: op(op), ctx(ctx), func(func), dev_ctx(dev_ctx) {}
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs)
: op(op),
ctx(ctx),
func(func),
dev_ctx(dev_ctx),
kernel_configs(kernel_configs) {}

static PreparedOp Prepare(const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel& op,
@@ -64,8 +69,9 @@

framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second;

auto expected_kernel_key = op.GetExpectedKernelType(
framework::ExecutionContext(op, framework::Scope(), *dev_ctx, ctx));
auto expected_kernel_key =
op.GetExpectedKernelType(framework::ExecutionContext(
op, framework::Scope(), *dev_ctx, ctx, nullptr));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;

auto kernel_iter = kernels.find(expected_kernel_key);
@@ -83,7 +89,9 @@
PADDLE_THROW("op %s does not have kernel for %s", op.Type(),
KernelTypeToString(expected_kernel_key));
}
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
std::vector<framework::KernelConfig>* kernel_configs =
op.GetKernelConfig(expected_kernel_key);
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs);
}

inline platform::DeviceContext* GetDeviceContext() const { return dev_ctx; }
@@ -92,6 +100,7 @@ class PreparedOp {
const framework::RuntimeContext& ctx;
framework::OperatorWithKernel::OpKernelFunc func;
platform::DeviceContext* dev_ctx;
std::vector<framework::KernelConfig>* kernel_configs;
};

class OpBase;
5 changes: 3 additions & 2 deletions paddle/fluid/imperative/tracer.cc
@@ -138,8 +138,9 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
op->place_ = GetExpectedPlace(expected_place, inputs);
PreparedOp prepared_op = PreparedOp::Prepare(ctx, *op_kernel, op->place_);
prepared_op.op.RuntimeInferShape(scope, op->place_, ctx);
prepared_op.func(framework::ExecutionContext(
prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
prepared_op.func(
framework::ExecutionContext(prepared_op.op, scope, *prepared_op.dev_ctx,
prepared_op.ctx, prepared_op.kernel_configs));

if (!stop_gradient) {
std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
2 changes: 1 addition & 1 deletion paddle/fluid/operators/beam_search_decode_op.cc
@@ -123,7 +123,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
auto& dev_ctx = *pool.Get(dev_place);

framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx);
framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx, nullptr);

const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");