fix gpu load model
Otherwise the parameters are loaded onto CPUPlace, which keeps copying data
between the CPU and GPU places.

test=develop
Superjomn committed Nov 16, 2018
1 parent 1722678 commit 4bf6817
Showing 4 changed files with 26 additions and 11 deletions.
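
For context, a minimal sketch of how the fix is exercised from the inference API. The use_gpu and device config fields are taken from the diff below (config_.use_gpu, config_.device); the header path, the contrib::AnalysisConfig spelling, and the CreatePaddlePredictor call are assumptions based on the API of this era and may differ.

#include "paddle/fluid/inference/api/paddle_inference_api.h"  // assumed header path

void RunGpuInference() {
  paddle::contrib::AnalysisConfig config;
  config.model_dir = "/path/to/model";  // hypothetical model directory
  config.use_gpu = true;  // forwarded to argument_.SetUseGPU() below
  config.device = 0;      // forwarded to argument_.SetGPUDeviceId() below
  // With this commit, the load pass builds its Executor on
  // CUDAPlace(config.device), so parameters are created in GPU memory
  // instead of being loaded onto CPUPlace and copied on every run.
  auto predictor =
      paddle::CreatePaddlePredictor<paddle::contrib::AnalysisConfig>(config);
  // ... feed PaddleTensor inputs and call predictor->Run(...) as usual.
}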
1 change: 1 addition & 0 deletions paddle/fluid/inference/analysis/argument.h
@@ -116,6 +116,7 @@ struct Argument {
                       std::vector<std::string>);
 
   DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
+  DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
   DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
   DECL_ARGUMENT_FIELD(tensorrt_node_teller, TensorRtNodeTeller,
                       std::function<bool(const framework::ir::Node*)>);
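
DECL_ARGUMENT_FIELD is defined elsewhere in argument.h; judging from the accessors this commit uses (gpu_device_id(), gpu_device_id_valid(), SetGPUDeviceId()), it plausibly expands to a getter, a setter, and a validity check, roughly like this hypothetical sketch:

// Hypothetical expansion of DECL_ARGUMENT_FIELD(field, Field, type);
// the real macro may track validity differently.
#define DECL_ARGUMENT_FIELD(field__, Field, type__)          \
 public:                                                     \
  type__& field__() { return field__##_; }                   \
  bool field__##_valid() const { return field__##_valid_; }  \
  void Set##Field(const type__& x) {                         \
    field__##_ = x;                                          \
    field__##_valid_ = true;                                 \
  }                                                          \
                                                             \
 private:                                                    \
  type__ field__##_;                                         \
  bool field__##_valid_{false};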
24 changes: 18 additions & 6 deletions paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
@@ -30,15 +30,28 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
   if (!argument->scope_valid()) {
     argument->SetScope(new framework::Scope);
   }
+  PADDLE_ENFORCE(argument->use_gpu_valid());
+
+  // The load program should run on the same device as the inference program,
+  // so that the parameters will be on the same device; otherwise they will
+  // keep copying between different devices.
+  platform::Place place;
+  if (argument->use_gpu()) {
+    PADDLE_ENFORCE(argument->gpu_device_id_valid());
+    place = platform::CUDAPlace(argument->gpu_device_id());
+  } else {
+    place = platform::CPUPlace();
+  }
 
   if (argument->model_dir_valid()) {
-    auto program = LoadModel(argument->model_dir(), argument->scope_ptr());
+    auto program =
+        LoadModel(argument->model_dir(), argument->scope_ptr(), place);
     argument->SetMainProgram(program.release());
   } else if (argument->model_program_path_valid() &&
              argument->model_params_path_valid()) {
     auto program =
         LoadModel(argument->model_program_path(), argument->model_params_path(),
-                  argument->scope_ptr());
+                  argument->scope_ptr(), place);
     argument->SetMainProgram(program.release());
   } else {
     PADDLE_THROW(
@@ -52,16 +65,15 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
 }
 
 std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
-    const std::string &path, framework::Scope *scope) {
-  platform::CPUPlace place;
+    const std::string &path, framework::Scope *scope,
+    const platform::Place &place) {
   framework::Executor exe(place);
   return Load(&exe, scope, path);
 }
 
 std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
     const std::string &program_path, const std::string &params_path,
-    framework::Scope *scope) {
-  platform::CPUPlace place;
+    framework::Scope *scope, const platform::Place &place) {
   framework::Executor exe(place);
   return Load(&exe, scope, program_path, params_path);
 }
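
Why threading place through LoadModel matters: framework::Executor runs the program's load ops on the place it was constructed with, so the parameter tensors are materialized directly on that device. A minimal sketch of the pattern, using the same Load helper the pass calls (the paths are placeholders):

// Minimal sketch: an Executor bound to a CUDAPlace makes the load ops create
// their output tensors in GPU memory, so no later CPU->GPU copy is needed.
framework::Scope scope;
platform::CUDAPlace place(0);
framework::Executor exe(place);
auto program = Load(&exe, &scope, "/path/to/__model__", "/path/to/params");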
8 changes: 5 additions & 3 deletions paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h
@@ -32,11 +32,13 @@ class IrGraphBuildPass : public AnalysisPass {
   std::string repr() const override;
 
  private:
-  std::unique_ptr<framework::ProgramDesc> LoadModel(const std::string &path,
-                                                    framework::Scope *scope);
+  std::unique_ptr<framework::ProgramDesc> LoadModel(
+      const std::string &path, framework::Scope *scope,
+      const boost::variant<CUDAPlace, CPUPlace, CUDAPinnedPlace> &place);
   std::unique_ptr<framework::ProgramDesc> LoadModel(
       const std::string &program_path, const std::string &params_path,
-      framework::Scope *scope);
+      framework::Scope *scope,
+      const boost::variant<CUDAPlace, CPUPlace, CUDAPinnedPlace> &place);
 
   std::string model_binary_str_;
 };
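
Note that the boost::variant spelled out in these declarations is the same type the .cc file refers to as platform::Place; presumably Place is an alias along these lines (stated here as an assumption, likely from platform/place.h), and the header spells the variant out to avoid pulling that header into this file:

// Presumed alias; the actual definition may differ in detail.
using Place = boost::variant<CUDAPlace, CPUPlace, CUDAPinnedPlace>;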
4 changes: 2 additions & 2 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -285,6 +285,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
   status_program_optimized_ = true;
 
   argument_.SetUseGPU(config_.use_gpu);
+  argument_.SetGPUDeviceId(config_.device);
   // Analyze inference_program
   if (!config_.model_dir.empty()) {
     argument_.SetModelDir(config_.model_dir);
@@ -491,8 +492,7 @@ bool AnalysisPredictor::LoadParameters() {
   }
 
   // Use NaiveExecutor to Load parameters.
-  platform::CPUPlace place;
-  framework::NaiveExecutor e(place);
+  framework::NaiveExecutor e(place_);
   e.Prepare(scope_.get(), *load_program, 0, false);
   e.Run();
   VLOG(3) << "get " << scope_->LocalVarNames().size() << " vars after load";
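
place_ is an existing member of AnalysisPredictor; presumably it is derived from the config during predictor initialization, along the lines of this hypothetical sketch:

// Hypothetical sketch of how place_ is set from the config; the actual
// logic lives in the predictor's initialization and may differ.
if (config_.use_gpu) {
  place_ = paddle::platform::CUDAPlace(config_.device);
} else {
  place_ = paddle::platform::CPUPlace();
}

With the NaiveExecutor constructed on place_, the load ops now run on the same device as inference, matching the change in ir_graph_build_pass.cc.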
