Skip to content

Commit

Permalink
Free MasterSession::ReffedClientGraph::client_graph_ after it is no…
Browse files Browse the repository at this point in the history
… longer used.

The `ClientGraph` represents the post-pruning, pre-partitioning subgraph. It is not used at the master after partitioning concludes. A redundant copy is stored in the workers' memories (post-partitioning). This change releases the memory for the `ClientGraph` after partitioning concludes.

PiperOrigin-RevId: 221008316
  • Loading branch information
mrry authored and tensorflower-gardener committed Nov 11, 2018
1 parent 982ddb4 commit 30cf6f1
Showing 1 changed file with 42 additions and 31 deletions.
73 changes: 42 additions & 31 deletions tensorflow/core/distributed_runtime/master_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/core/distributed_runtime/master_session.h"

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down Expand Up @@ -64,26 +65,28 @@ namespace tensorflow {
class MasterSession::ReffedClientGraph : public core::RefCounted {
public:
ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
std::unique_ptr<ClientGraph> cg,
std::unique_ptr<ClientGraph> client_graph,
const SessionOptions& session_opts,
const StatsPublisherFactory& stats_publisher_factory,
bool is_partial, WorkerCacheInterface* worker_cache,
bool should_deregister)
: session_handle_(handle),
bg_opts_(bopts),
client_graph_(std::move(cg)),
client_graph_before_register_(std::move(client_graph)),
session_opts_(session_opts),
is_partial_(is_partial),
callable_opts_(bopts.callable_options),
worker_cache_(worker_cache),
should_deregister_(should_deregister) {
should_deregister_(should_deregister),
collective_graph_key_(
client_graph_before_register_->collective_graph_key) {
VLOG(1) << "Created ReffedClientGraph for node with "
<< client_graph()->graph.num_node_ids();
<< client_graph_before_register_->graph.num_node_ids();

stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);

// Initialize a name to node map for processing device stats.
for (Node* n : client_graph_->graph.nodes()) {
for (Node* n : client_graph_before_register_->graph.nodes()) {
name_to_node_.insert({n->name(), n});
}
}
Expand All @@ -98,12 +101,12 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
}
}

const ClientGraph* client_graph() { return client_graph_.get(); }

const CallableOptions& callable_options() { return callable_opts_; }

const BuildGraphOptions& build_graph_options() { return bg_opts_; }

int64 collective_graph_key() { return collective_graph_key_; }

std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
int64 execution_count,
const RunOptions& ropts) {
Expand Down Expand Up @@ -187,7 +190,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {

// Partitions the graph into subgraphs and registers them on
// workers.
Status RegisterPartitions(const PartitionOptions& popts);
Status RegisterPartitions(PartitionOptions popts);

// Runs one step of all partitions.
Status RunPartitions(const MasterEnv* env, int64 step_id,
Expand Down Expand Up @@ -230,13 +233,16 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
private:
const string session_handle_;
const BuildGraphOptions bg_opts_;
const std::unique_ptr<ClientGraph> client_graph_;

// NOTE(mrry): This pointer will be null after `RegisterPartitions()` returns.
std::unique_ptr<ClientGraph> client_graph_before_register_ GUARDED_BY(mu_);
const SessionOptions session_opts_;
const bool is_partial_;
const CallableOptions callable_opts_;
WorkerCacheInterface* const worker_cache_; // Not owned.
std::unordered_map<StringPiece, Node*, StringPieceHasher> name_to_node_;
const bool should_deregister_;
const int64 collective_graph_key_;
std::atomic<int64> execution_count_ = {0};

// Graph partitioned into per-location subgraphs.
Expand Down Expand Up @@ -268,9 +274,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
mutable mutex mu_;

// Partition initialization and registration only needs to happen
// once. init_started_ && !init_done_ indicates the initialization
// is on going.
bool init_started_ GUARDED_BY(mu_) = false;
// once. `!client_graph_before_register_ && !init_done_.HasBeenNotified()`
// indicates the initialization is ongoing.
Notification init_done_;

// init_result_ remembers the initialization error if any.
Expand All @@ -286,7 +291,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {

// The actual graph partitioning and registration implementation.
Status DoBuildPartitions(
PartitionOptions pots,
PartitionOptions popts, ClientGraph* client_graph,
std::unordered_map<string, GraphDef>* out_partitions);
Status DoRegisterPartitions(
const PartitionOptions& popts,
Expand All @@ -311,14 +316,20 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
};

Status MasterSession::ReffedClientGraph::RegisterPartitions(
const PartitionOptions& popts) {
PartitionOptions popts) {
{ // Ensure register once.
mu_.lock();
if (!init_started_) {
init_started_ = true;
if (client_graph_before_register_) {
// The `ClientGraph` is no longer needed after partitions are registered.
// Since it can account for a large amount of memory, we consume it here,
// and it will be freed after concluding with registration.

std::unique_ptr<ClientGraph> client_graph;
std::swap(client_graph_before_register_, client_graph);
mu_.unlock();
std::unordered_map<string, GraphDef> graph_defs;
Status s = DoBuildPartitions(popts, &graph_defs);
popts.flib_def = client_graph->flib_def.get();
Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs);
if (s.ok()) {
// NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
// valid after the call to DoRegisterPartitions begins, so
Expand Down Expand Up @@ -394,19 +405,19 @@ void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
}

Status MasterSession::ReffedClientGraph::DoBuildPartitions(
PartitionOptions popts,
PartitionOptions popts, ClientGraph* client_graph,
std::unordered_map<string, GraphDef>* out_partitions) {
if (popts.need_to_record_start_times) {
CostModel cost_model(true);
cost_model.InitFromGraph(client_graph()->graph);
cost_model.InitFromGraph(client_graph->graph);
// TODO(yuanbyu): Use the real cost model.
// execution_state_->MergeFromGlobal(&cost_model);
SlackAnalysis sa(&client_graph()->graph, &cost_model);
SlackAnalysis sa(&client_graph->graph, &cost_model);
sa.ComputeAsap(&popts.start_times);
}

// Partition the graph.
return Partition(popts, &client_graph_->graph, out_partitions);
return Partition(popts, &client_graph->graph, out_partitions);
}

Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
Expand All @@ -415,7 +426,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
partitions_.reserve(graph_partitions.size());
Status s;
for (auto& name_def : graph_partitions) {
partitions_.resize(partitions_.size() + 1);
partitions_.emplace_back();
Part* part = &partitions_.back();
part->name = name_def.first;
TrackFeedsAndFetches(part, name_def.second, popts);
Expand Down Expand Up @@ -449,7 +460,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() =
callable_opts_.run_options().debug_options();
c->req.set_collective_graph_key(client_graph()->collective_graph_key);
c->req.set_collective_graph_key(collective_graph_key_);
VLOG(2) << "Register " << c->req.graph_def().DebugString();
auto cb = [c, &done](const Status& s) {
c->status = s;
Expand Down Expand Up @@ -1545,14 +1556,13 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
// Registers subgraphs if haven't done so.
PartitionOptions popts;
popts.node_to_loc = SplitByWorker;
// The closures potps.{new_name,get_incarnation} are called synchronously in
// The closures popts.{new_name,get_incarnation} are called synchronously in
// RegisterPartitions() below, so do not need a Ref()/Unref() pair to keep
// "this" alive during the closure.
popts.new_name = [this](const string& prefix) {
mutex_lock l(mu_);
return strings::StrCat(prefix, "_S", next_node_id_++);
};
popts.flib_def = rcg->client_graph()->flib_def.get();
popts.get_incarnation = [this](const string& name) -> int64 {
Device* d = devices_->FindDeviceByName(name);
if (d == nullptr) {
Expand Down Expand Up @@ -1580,7 +1590,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
popts.need_to_record_start_times = true;
}

TF_RETURN_IF_ERROR(rcg->RegisterPartitions(popts));
TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts)));

return Status::OK();
}
Expand Down Expand Up @@ -1784,10 +1794,10 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
Status s = run_status;
if (s.ok()) {
pss->end_micros = Env::Default()->NowMicros();
if (rcg->client_graph()->collective_graph_key !=
if (rcg->collective_graph_key() !=
BuildGraphOptions::kNoCollectiveGraphKey) {
env_->collective_executor_mgr->RetireStepId(
rcg->client_graph()->collective_graph_key, step_id);
env_->collective_executor_mgr->RetireStepId(rcg->collective_graph_key(),
step_id);
}
// Schedule post-processing and cleanup to be done asynchronously.
rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
Expand Down Expand Up @@ -1846,14 +1856,15 @@ Status MasterSession::DoRunWithLocalExecution(

// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
uint64 step_id = NewStepId(rcg->collective_graph_key());
TRACEPRINTF("stepid %llu", step_id);

std::unique_ptr<ProfileHandler> ph;
FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);

Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
&cancellation_manager_, false);

cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
resp->mutable_metadata());
Expand Down Expand Up @@ -1910,7 +1921,7 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
// Prepare.
int64 count = rcg->get_and_increment_execution_count();

const uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
const uint64 step_id = NewStepId(rcg->collective_graph_key());
TRACEPRINTF("stepid %llu", step_id);

const RunOptions& run_options = rcg->callable_options().run_options();
Expand Down

0 comments on commit 30cf6f1

Please sign in to comment.