
Commit

Refine program cache (PaddlePaddle#45005)
* add cached_serialize_str_

* support program hash

* add sha

* add ut

* use hash_str only for new_exe

* fix attr order
zhiqiu authored Aug 13, 2022
1 parent 3f5c405 commit e96dae8
Showing 11 changed files with 174 additions and 9 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/framework/CMakeLists.txt
@@ -505,6 +505,11 @@ cc_library(
SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc process_mesh_desc.cc
DEPS attribute shape_inference op_info operator glog version)

if(WITH_CRYPTO)
add_dependencies(proto_desc cryptopp)
target_link_libraries(proto_desc cryptopp)
endif()

cc_library(
op_registry
SRCS op_registry.cc
4 changes: 3 additions & 1 deletion paddle/fluid/framework/block_desc.cc
@@ -197,17 +197,19 @@ void BlockDesc::Flush() {
var_names.emplace_back(var.name());
var_names_set.insert(var.name());
}

VLOG(4) << "vars in desc " << this->desc_->vars().size();
this->desc_->mutable_vars()->Clear();
for (const auto &name : var_names) {
if (vars_.count(name)) {
VLOG(4) << "Flush " << name;
this->desc_->mutable_vars()->Add()->CopyFrom(*vars_[name]->Proto());
vars_[name]->SetNeedUpdate(false);
}
}

for (auto &var_desc : vars_) {
if (var_names_set.count(var_desc.first) != 1) {
VLOG(4) << "Flush " << var_desc.first;
this->desc_->mutable_vars()->Add()->CopyFrom(*var_desc.second->Proto());
var_desc.second->SetNeedUpdate(false);
}
2 changes: 1 addition & 1 deletion paddle/fluid/framework/block_desc.h
@@ -122,7 +122,7 @@ class BlockDesc {
// vars_

std::deque<std::unique_ptr<OpDesc>> ops_;
std::unordered_map<std::string, std::unique_ptr<VarDesc>> vars_;
std::map<std::string, std::unique_ptr<VarDesc>> vars_;

DISABLE_COPY_AND_ASSIGN(BlockDesc);
};
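A minimal standalone sketch of the ordering guarantee this container change relies on (illustrative, not part of the commit): std::map iterates its keys in sorted order, so walking vars_ during Flush() now serializes variables deterministically, which is what keeps the SHA-1 cache key below stable.

#include <iostream>
#include <map>
#include <string>

int main() {
  // std::map visits keys in sorted order on every run and every platform;
  // std::unordered_map makes no such guarantee, so hashing bytes produced
  // from it could yield different cache keys for identical programs.
  std::map<std::string, int> vars{{"w", 0}, {"b", 1}, {"x", 2}};
  for (const auto& kv : vars) {
    std::cout << kv.first << " ";  // always prints: b w x
  }
  std::cout << std::endl;
  return 0;
}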
44 changes: 44 additions & 0 deletions paddle/fluid/framework/io/crypto/sha.h
@@ -0,0 +1,44 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cryptopp/cryptlib.h>
#include <cryptopp/filters.h>
#include <cryptopp/hex.h>
#include <cryptopp/sha.h>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {

std::string GetSha1(std::string msg) {
std::string digest;
CryptoPP::SHA1 hash;
hash.Update(reinterpret_cast<unsigned char*>(&msg.at(0)), msg.size());
digest.resize(hash.DigestSize());
hash.Final(reinterpret_cast<unsigned char*>(&digest.at(0)));
return digest;
}

std::string HexEncoding(std::string bytes) {
std::string encoded;
// Everything newed is destroyed when the StringSource is destroyed
CryptoPP::StringSource ss(
bytes, true, new CryptoPP::HexEncoder(new CryptoPP::StringSink(encoded)));
return encoded;
}

} // namespace framework
} // namespace paddle
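A minimal usage sketch of the two helpers above (illustrative, not part of the commit; the input string is a placeholder):

#include <iostream>
#include <string>

#include "paddle/fluid/framework/io/crypto/sha.h"

int main() {
  std::string bytes = "serialized program bytes";  // placeholder input
  // GetSha1 returns the 20 raw digest bytes; HexEncoding renders them as
  // 40 hex characters (CryptoPP's HexEncoder emits uppercase by default).
  std::string key =
      paddle::framework::HexEncoding(paddle::framework::GetSha1(bytes));
  std::cout << key << std::endl;
  return 0;
}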
18 changes: 17 additions & 1 deletion paddle/fluid/framework/new_executor/interpretercore.cc
@@ -114,6 +114,11 @@ interpreter::CostInfo InterpreterCore::DryRun(
// until the second step run.
async_work_queue_ = GetWorkQueue();

// lazy initialization of gc, do not create gc if the program only runs once
if (!gc_) {
gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);
}

ExecuteInstructionList(vec_instruction_);
platform::DeviceContextPool::Instance().Get(place_)->Wait();
}
@@ -144,6 +149,12 @@ paddle::framework::FetchList InterpreterCore::Run(
// create work_queue, so the async_work_queue_ is created
// until the second step run.
async_work_queue_ = GetWorkQueue();

// lazy initialization of gc, do not create gc if the program only runs once
if (!gc_) {
gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);
}

ExecuteInstructionList(vec_instruction_);
#ifdef PADDLE_WITH_ASCEND_CL
platform::DeviceContextPool::Instance().Get(place_)->Wait();
@@ -193,6 +204,11 @@ paddle::framework::FetchList InterpreterCore::Run(
// until the second step run.
async_work_queue_ = GetWorkQueue();

// lazy initialization of gc, do not create gc if the program only runs once
if (!gc_) {
gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);
}

ExecuteInstructionList(vec_instruction_);
#ifdef PADDLE_WITH_ASCEND_CL
platform::DeviceContextPool::Instance().Get(place_)->Wait();
@@ -495,7 +511,7 @@ void InterpreterCore::Convert(
}

BuildSkipShareLoDInfo();
gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);

bool inplaced = false;
for (auto inst : vec_instruction_) {
if (inst.OpBase()->Type() == "share_buffer" ||
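All three hunks above make the same change: the garbage collector moves from eager construction in Convert() to lazy construction on first execution, so a program that is built but never run pays no GC setup cost. A simplified sketch of the idiom (types are stand-ins, not Paddle's real classes):

#include <memory>

struct GarbageCollector {};  // stand-in for the real collector

struct Interpreter {
  std::unique_ptr<GarbageCollector> gc_;

  void Run() {
    // Construct the collector on the first Run() and reuse it afterwards.
    if (!gc_) {
      gc_ = std::make_unique<GarbageCollector>();
    }
    // ... execute instructions ...
  }
};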
9 changes: 8 additions & 1 deletion paddle/fluid/framework/op_desc.cc
@@ -901,7 +901,14 @@ void OpDesc::Flush() {
}

this->desc_.mutable_attrs()->Clear();
for (auto &attr : attrs_) {
std::vector<std::pair<std::string, Attribute>> sorted_attrs{attrs_.begin(),
attrs_.end()};
std::sort(
sorted_attrs.begin(),
sorted_attrs.end(),
[](std::pair<std::string, Attribute> a,
std::pair<std::string, Attribute> b) { return a.first < b.first; });
for (auto &attr : sorted_attrs) {
auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(
17 changes: 17 additions & 0 deletions paddle/fluid/framework/program_desc.cc
@@ -17,6 +17,9 @@ limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/version.h"
#ifdef PADDLE_WITH_CRYPTO
#include "paddle/fluid/framework/io/crypto/sha.h"
#endif

namespace paddle {
namespace framework {
@@ -249,6 +252,20 @@ void ProgramDesc::SetFetchHolderName(const std::string &fetch_holder_name) {
fetch_holder->SetPersistable(true);
}

std::string ProgramDesc::CachedHashString() {
std::string serialize_str;
if (cached_hash_str_.size() == 0 || NeedUpdate()) {
Flush();
desc_.SerializePartialToString(&serialize_str);
#ifdef PADDLE_WITH_CRYPTO
cached_hash_str_ = HexEncoding(GetSha1(serialize_str));
#else
cached_hash_str_ = serialize_str;
#endif
}
return cached_hash_str_;
}

bool ProgramDesc::NeedUpdate() const {
bool need = false;
for (auto &block : blocks_) {
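A sketch of the intended call pattern (illustrative; assumes only what the diff above shows): the digest is computed once, memoized in cached_hash_str_, and recomputed only after the desc is mutated; without WITH_CRYPTO the raw serialized string itself serves as the key.

#include <cassert>
#include <string>

#include "paddle/fluid/framework/program_desc.h"

void CacheKeySketch(paddle::framework::ProgramDesc* program) {
  std::string key1 = program->CachedHashString();  // serializes and hashes
  std::string key2 = program->CachedHashString();  // served from the cache
  assert(key1 == key2);
  // Any edit that sets a need-update flag (e.g. changing an op attribute)
  // invalidates the cache, so the next call recomputes the digest.
}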
4 changes: 4 additions & 0 deletions paddle/fluid/framework/program_desc.h
@@ -85,6 +85,8 @@ class ProgramDesc {
// This function is used to change or unify the fetch_holder variables' name.
void SetFetchHolderName(const std::string &fetch_holder_name);

std::string CachedHashString();

bool NeedUpdate() const;

private:
@@ -93,6 +95,8 @@
proto::ProgramDesc desc_;

std::vector<std::unique_ptr<BlockDesc>> blocks_;

std::string cached_hash_str_;
};
} // namespace framework
} // namespace paddle
10 changes: 8 additions & 2 deletions paddle/fluid/pybind/protobuf.cc
@@ -102,8 +102,14 @@ void BindProgramDesc(pybind11::module *m) {
pybind11::arg("version") = pd::kCurProgramVersion)
.def("_version",
[](pd::ProgramDesc &self) -> int64_t { return self.Version(); })
.def("get_op_deps", [](const framework::ProgramDesc &program) {
return framework::ir::GetOpDependencies(program);
.def("get_op_deps",
[](const framework::ProgramDesc &program) {
return framework::ir::GetOpDependencies(program);
})
.def("need_update", &pd::ProgramDesc::NeedUpdate)
.def("cached_hash_str", [](pd::ProgramDesc &self) {
return self.CachedHashString();
// return pybind11::bytes(self.CachedHashString());
});
}

11 changes: 8 additions & 3 deletions python/paddle/fluid/executor.py
@@ -426,8 +426,13 @@ def _prepare_fleet_executor():
return fleet_exe


def _get_strong_program_cache_key_for_new_exe(program, feed, fetch_list):
return program.desc.cached_hash_str() + _get_program_cache_key(
feed, fetch_list)


def _get_strong_program_cache_key(program, feed, fetch_list):
# NOTE(xiongkun): id(program) may be duplicated, so add additional var_names to the cache key.
# TODO(zhiqiu): use hash_str to generate cache key as above
def _get_varname_from_block(block):
block_str = []
for var_name in list(block.vars.keys()):
@@ -1455,8 +1460,8 @@ def _can_use_interpreter_core(program, place):
% (type(feed)))
feed = self._update_feed(program, feed)

key = _get_strong_program_cache_key(inner_program, feed,
fetch_list)
key = _get_strong_program_cache_key_for_new_exe(
inner_program, feed, fetch_list)

# a little bit tricky here: use inner_program before _add_feed_fetch_ops to get the key,
# while using program to get the _StandaloneExecutor
59 changes: 59 additions & 0 deletions python/paddle/fluid/tests/unittests/test_program.py
@@ -241,5 +241,64 @@ def test_update_var_attr(self):
self.assertTrue(a == b) # not affected


class TestProgramHash(unittest.TestCase):

def build_program(self):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(name='x', shape=[3, 2, 1])
out = paddle.static.nn.fc(x=x, size=1, num_flatten_dims=2)
return main_program

def test_program_need_update(self):
program = self.build_program()
self.assertTrue(program.desc.need_update())
program.desc.flush()
self.assertFalse(program.desc.need_update())

def test_program_hash_equal(self):
programs = []
for i in range(2):
programs.append(self.build_program())
program1, program2 = programs[0], programs[1]
# Why not write it as below? Because the callstack attributes
# of the two programs would not be equal:
# program1 = self.build_program()
# program2 = self.build_program()

self.assertTrue(program1.desc.need_update())
self.assertTrue(program2.desc.need_update())
# two programs with the same content
self.assertFalse(id(program1) == id(program2))
# print(program1, program2)
self.assertTrue(
program1.desc.cached_hash_str() == program2.desc.cached_hash_str())

self.assertFalse(program1.desc.need_update())
self.assertFalse(program2.desc.need_update())

def test_program_clone(self):
program = self.build_program()
program_clone = program.clone()

self.assertFalse(id(program) == id(program_clone))
self.assertTrue(program.desc.cached_hash_str() ==
program_clone.desc.cached_hash_str())

def test_program_update(self):
program = self.build_program()
hash1 = program.desc.cached_hash_str()
id1 = id(program)
# change mul's attr
program.current_block().ops[0]._set_attr('use_mkldnn', True)
program.current_block().ops[0]._set_attr('scale_x', 2.0)
hash2 = program.desc.cached_hash_str()
id2 = id(program)
self.assertTrue(id1 == id2)
self.assertFalse(hash1 == hash2)


if __name__ == '__main__':
unittest.main()
