Skip to content

Commit

Permalink
[FLR] Consolidate logic for configuring Executor::Args in one place.
Browse files Browse the repository at this point in the history
As a bonus, avoid heap-allocating `Executor::Args` for all local function calls. This slightly improves the performance of the callback when a function completes.

PiperOrigin-RevId: 220180860
  • Loading branch information
mrry authored and tensorflower-gardener committed Nov 5, 2018
1 parent 8014c26 commit 5babd0d
Showing 1 changed file with 37 additions and 45 deletions.
82 changes: 37 additions & 45 deletions tensorflow/core/common_runtime/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,11 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
AttrValueMap FixAttrs(const AttrSlice& attrs);
void RunRemote(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
Executor::Args* exec_args, Item* item, DoneCallback done);
Item* item, DoneCallback done);

void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
CallFrameInterface* frame,
Executor::Args* exec_args);

TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
Expand Down Expand Up @@ -858,41 +862,50 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
return CreateItem(handle, item);
}

void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame,
Executor::Args* exec_args) {
// Inherit the step_id from the caller.
exec_args->step_id = run_opts.step_id;
exec_args->rendezvous = run_opts.rendezvous;
exec_args->stats_collector = run_opts.stats_collector;
exec_args->cancellation_manager = run_opts.cancellation_manager;
exec_args->step_container = run_opts.step_container;
if (run_opts.runner) {
exec_args->runner = *run_opts.runner;
} else {
exec_args->runner = default_runner_;
}
exec_args->collective_executor = run_opts.collective_executor;
exec_args->call_frame = frame;
}

void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
Executor::Args* exec_args,
Item* item, DoneCallback done) {
DCHECK(exec_args->call_frame == nullptr);
string target_device = parent_->GetDeviceName(handle);
string source_device = opts.source_device;
Rendezvous* rendezvous = opts.rendezvous;
DeviceContext* device_context;
Status s = parent_->GetDeviceContext(target_device, &device_context);
if (!s.ok()) {
delete exec_args;
done(s);
return;
}
int64 src_incarnation, target_incarnation;
s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
if (!s.ok()) {
delete exec_args;
done(s);
return;
}

const FunctionBody* fbody = GetFunctionBody(handle);
FunctionCallFrame* frame =
new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
exec_args->call_frame = frame;
if (!s.ok()) {
delete frame;
delete exec_args;
done(s);
return;
}
Executor::Args* exec_args = new Executor::Args;
ExecutorArgsFromOptions(opts, frame, exec_args);

std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
args_alloc_attrs.reserve(fbody->arg_types.size());
Expand Down Expand Up @@ -938,28 +951,27 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
return;
}
item->exec->RunAsync(
*exec_args, [frame, rets, done, source_device, target_device,
target_incarnation, rendezvous, device_context,
remote_args, exec_args, rets_alloc_attrs,
allow_dead_tensors](const Status& status) {
*exec_args,
[frame, rets, done, source_device, target_device,
target_incarnation, rendezvous, device_context, remote_args,
rets_alloc_attrs, allow_dead_tensors](const Status& status) {
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets, allow_dead_tensors);
}
delete frame;
if (!s.ok()) {
delete remote_args;
delete exec_args;
done(s);
return;
}
s = ProcessFunctionLibraryRuntime::SendTensors(
target_device, source_device, "ret_", target_incarnation,
*rets, device_context, rets_alloc_attrs, rendezvous);
delete remote_args;
delete exec_args;
done(s);
});
delete exec_args;
});
}

Expand Down Expand Up @@ -992,54 +1004,43 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
}
DCHECK(run_opts.runner != nullptr);

Executor::Args* exec_args = new Executor::Args;
// Inherit the step_id from the caller.
exec_args->step_id = run_opts.step_id;
exec_args->rendezvous = run_opts.rendezvous;
exec_args->stats_collector = run_opts.stats_collector;
exec_args->cancellation_manager = run_opts.cancellation_manager;
exec_args->step_container = run_opts.step_container;
exec_args->runner = *run_opts.runner;
exec_args->collective_executor = run_opts.collective_executor;

Item* item = nullptr;
Status s = GetOrCreateItem(handle, &item);
if (!s.ok()) {
delete exec_args;
done(s);
return;
}

if (run_opts.remote_execution) {
// NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
RunRemote(run_opts, handle, args, rets, exec_args, item, done);
RunRemote(run_opts, handle, args, rets, item, done);
return;
}

const FunctionBody* fbody = GetFunctionBody(handle);
FunctionCallFrame* frame =
new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
exec_args->call_frame = frame;
s = frame->SetArgs(args);
if (!s.ok()) {
delete frame;
delete exec_args;
done(s);
return;
}

Executor::Args exec_args;
ExecutorArgsFromOptions(opts, frame, &exec_args);

bool allow_dead_tensors = opts.allow_dead_tensors;
item->exec->RunAsync(
// Executor args
*exec_args,
exec_args,
// Done callback.
[frame, rets, done, exec_args, allow_dead_tensors](const Status& status) {
[frame, rets, done, allow_dead_tensors](const Status& status) {
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets, allow_dead_tensors);
}
delete frame;
delete exec_args;
done(s);
});
}
Expand Down Expand Up @@ -1084,16 +1085,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
DCHECK(run_opts.runner != nullptr);

Executor::Args exec_args;
// Inherit the step_id from the caller.
exec_args.step_id = run_opts.step_id;
exec_args.rendezvous = run_opts.rendezvous;
exec_args.stats_collector = run_opts.stats_collector;
exec_args.cancellation_manager = run_opts.cancellation_manager;
exec_args.collective_executor = run_opts.collective_executor;
exec_args.step_container = run_opts.step_container;
exec_args.runner = *run_opts.runner;
exec_args.call_frame = frame;

ExecutorArgsFromOptions(opts, frame, &exec_args);
item->exec->RunAsync(exec_args, std::move(done));
}

Expand Down

0 comments on commit 5babd0d

Please sign in to comment.