[SE,XLA] Switch to using multiple streams in xla_device_context
Instead of having one stream shared by compute, host-to-device, and device-to-host transfers, switch to separate streams, just as the GPU backend already does.
Add an se::Event field to XlaTensor so that accurate inter-stream dependencies can be created.
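A minimal sketch of the dependency pattern this enables, assuming a StreamExecutor environment with separate transfer and compute streams already borrowed; the helper `TransferThenCompute`, its parameters, and the event placement are illustrative, not code from this commit:

```cpp
#include <cstdint>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/stream.h"

namespace se = ::stream_executor;

// Hypothetical helper: enqueue a host-to-device copy on the transfer
// stream, then make the compute stream wait for it through an event
// instead of serializing against the entire transfer stream.
void TransferThenCompute(se::StreamExecutor* executor,
                         se::Stream* host_to_device_stream,
                         se::Stream* compute_stream,
                         se::DeviceMemoryBase* dst, const void* src,
                         uint64_t size) {
  // Stands in for the se::Event field added to XlaTensor.
  se::Event definition_event(executor);
  CHECK(definition_event.Init());

  // Producer: the copy and the event record are enqueued in order on
  // the host-to-device stream.
  host_to_device_stream->ThenMemcpy(dst, src, size);
  host_to_device_stream->ThenRecordEvent(&definition_event);

  // Consumer: the compute stream waits on the device until the event
  // fires, i.e. until the tensor's contents are defined.
  compute_stream->ThenWaitFor(&definition_event);
}
```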

As part of this:
 - Fix TransferManager::TransferLiteralFrom/ToDevice to correctly make generated substreams wait on their master stream.
 - Fix Stream::BlockHostUntilDone() to not block on or return substreams. The previous behavior was broken: it nondeterministically returned substreams to the pool and caused indefinite hangs with the HostStream. (A sketch of the corrected substream discipline follows this list.)
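A minimal sketch of that corrected discipline, assuming a master stream with work already enqueued; `RunOnSubStream` and the `enqueue` callback are illustrative names, not part of this change:

```cpp
#include <functional>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/stream.h"

namespace se = ::stream_executor;

// Hypothetical helper: borrow a substream, synchronize it with its
// master, enqueue work, block the host, and return it to the pool.
void RunOnSubStream(se::Stream* master,
                    const std::function<void(se::Stream*)>& enqueue) {
  se::Stream* substream = master->GetOrCreateSubStream();

  // First fix: a freshly borrowed substream knows nothing about work
  // already enqueued on its master, so it must wait for it explicitly.
  substream->ThenWaitFor(master);

  enqueue(substream);

  // Second fix: BlockHostUntilDone() waits only for work enqueued on
  // this stream; it no longer blocks on, or returns, substreams.
  TF_CHECK_OK(substream->BlockHostUntilDone());

  // Returning the substream to the pool stays the caller's job.
  master->ReturnSubStream(substream);
}
```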

PiperOrigin-RevId: 203726543
tensorflower-gardener committed Jul 9, 2018
1 parent caf711b commit 955e356
Showing 17 changed files with 341 additions and 146 deletions.
5 changes: 3 additions & 2 deletions tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -115,6 +115,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   const XlaDevice::Metadata* metadata = nullptr;
   Status s = XlaDevice::GetMetadata(ctx, &metadata);
   bool allocate_xla_tensors = s.ok();
+  bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams();
 
   // Get the platform_id_ for XLA_* devices.
   if (platform_id_ == nullptr) {
@@ -180,8 +181,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
 
   VLOG(1) << "Executing XLA Computation...";
 
-  XlaComputationLaunchContext launch_context(client, xla_allocator,
-                                             allocate_xla_tensors);
+  XlaComputationLaunchContext launch_context(
+      client, xla_allocator, allocate_xla_tensors, use_multiple_streams);
   launch_context.PopulateInputs(ctx, kernel, variables);
 
   // Execute the computation.
4 changes: 3 additions & 1 deletion tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -53,7 +53,9 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
 
   // Builds an XLA allocator for the device.
   XlaComputationLaunchContext launch_context(
-      client, client->backend().memory_allocator(), true);
+      client, client->backend().memory_allocator(),
+      /*allocate_xla_tensors=*/true,
+      /*use_multiple_streams=*/metadata.UseMultipleStreams());
 
   launch_context.PopulateInputs(ctx, result, variables);
 
1 change: 1 addition & 0 deletions tensorflow/compiler/jit/xla_cpu_device.cc
@@ -54,6 +54,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
                                        DEVICE_CPU_XLA_JIT, options, name_prefix,
                                        registration,
                                        /*transfer_as_literal=*/false,
+                                       /*use_multiple_streams=*/false,
                                        /*shape_representation_fn=*/{},
                                        /*padded_shape_fn=*/{}, &device));
   devices->push_back(device.release());
70 changes: 55 additions & 15 deletions tensorflow/compiler/jit/xla_device.cc
@@ -130,7 +130,7 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
     const string& jit_device_name, const SessionOptions& options,
     const string& name_prefix,
     const XlaOpRegistry::DeviceRegistration& registration,
-    bool transfer_as_literal,
+    bool transfer_as_literal, bool use_multiple_streams,
     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
     const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) {
   VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
@@ -151,22 +151,24 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
       DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
       strings::StrCat("device: ", device_name, " device"));
 
-  device->reset(new XlaDevice(
-      options, attrs, device_ordinal, DeviceType(jit_device_name),
-      platform.ValueOrDie(), transfer_as_literal, shape_representation_fn,
-      padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
+  device->reset(
+      new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name),
+                    platform.ValueOrDie(), transfer_as_literal,
+                    use_multiple_streams, shape_representation_fn,
+                    padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
   return Status::OK();
 }
 
 XlaDevice::Metadata::Metadata(
     int device_ordinal, se::Platform* platform, const DeviceType& device_type,
     XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    PaddedShapeFn padded_shape_fn)
+    PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
     : device_ordinal_(device_ordinal),
       device_type_(device_type),
       platform_(platform),
       shape_representation_fn_(std::move(shape_representation_fn)),
-      padded_shape_fn_(std::move(padded_shape_fn)) {}
+      padded_shape_fn_(std::move(padded_shape_fn)),
+      use_multiple_streams_(use_multiple_streams) {}
 
 int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
 
@@ -200,16 +202,18 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
 XlaDevice::XlaDevice(
     const SessionOptions& options, const DeviceAttributes& attrs,
     int device_ordinal, const DeviceType& jit_device_name,
-    se::Platform* platform, bool transfer_as_literal,
+    se::Platform* platform, bool transfer_as_literal, bool use_multiple_streams,
     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
     const PaddedShapeFn& padded_shape_fn)
     : LocalDevice(options, attrs),
       xla_metadata_(device_ordinal, platform, jit_device_name,
-                    shape_representation_fn, padded_shape_fn),
+                    shape_representation_fn, padded_shape_fn,
+                    use_multiple_streams),
       device_ordinal_(device_ordinal),
       jit_device_name_(jit_device_name),
       xla_allocator_(nullptr),
       platform_(platform),
+      use_multiple_streams_(use_multiple_streams),
       transfer_as_literal_(transfer_as_literal),
       shape_representation_fn_(shape_representation_fn) {
   VLOG(1) << "Created XLA device " << jit_device_name;
@@ -253,6 +257,30 @@ xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
   return stream_.get();
 }
 
+xla::StatusOr<se::Stream*> XlaDevice::GetDeviceToHostStream() {
+  if (!use_multiple_streams_) {
+    return GetStream();
+  }
+  if (!device_to_host_stream_) {
+    xla::Backend* backend = client()->mutable_backend();
+    TF_ASSIGN_OR_RETURN(device_to_host_stream_,
+                        backend->BorrowStream(device_ordinal_));
+  }
+  return device_to_host_stream_.get();
+}
+
+xla::StatusOr<se::Stream*> XlaDevice::GetHostToDeviceStream() {
+  if (!use_multiple_streams_) {
+    return GetStream();
+  }
+  if (!host_to_device_stream_) {
+    xla::Backend* backend = client()->mutable_backend();
+    TF_ASSIGN_OR_RETURN(host_to_device_stream_,
+                        backend->BorrowStream(device_ordinal_));
+  }
+  return host_to_device_stream_.get();
+}
+
 Status XlaDevice::CreateAndSetGpuDeviceInfo() {
   if (gpu_device_info_ == nullptr) {
     TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
@@ -263,8 +291,9 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
     // gpu_device_info_->default_context.
     gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
     gpu_device_info_->stream = stream;
-    gpu_device_info_->default_context = new XlaDeviceContext(
-        stream, client(), transfer_as_literal_, shape_representation_fn_);
+    gpu_device_info_->default_context =
+        new XlaDeviceContext(stream, stream, stream, client(),
+                             transfer_as_literal_, shape_representation_fn_);
     set_tensorflow_gpu_device_info(gpu_device_info_.get());
   }
 
@@ -276,10 +305,16 @@ Status XlaDevice::FillContextMap(const Graph* graph,
   VLOG(1) << "XlaDevice::FillContextMap";
   device_context_map->resize(graph->num_node_ids());
   TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
+  TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
+                      GetDeviceToHostStream());
+  TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
+                      GetHostToDeviceStream());
+
   // Call GetAllocator for the side-effect of ensuring the allocator is created.
   GetAllocator({});
-  auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
-                                  shape_representation_fn_);
+  auto ctx = new XlaDeviceContext(
+      stream, host_to_device_stream, device_to_host_stream, client(),
+      transfer_as_literal_, shape_representation_fn_);
   for (Node* n : graph->nodes()) {
     VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
     ctx->Ref();
@@ -326,8 +361,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
     Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
     Notification n;
     TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
-    XlaTransferManager manager(stream, client(), transfer_as_literal_,
-                               shape_representation_fn_);
+    TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
+                        GetDeviceToHostStream());
+    TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
+                        GetHostToDeviceStream());
+    XlaTransferManager manager(stream, host_to_device_stream,
+                               device_to_host_stream, client(),
+                               transfer_as_literal_, shape_representation_fn_);
     manager.CopyCPUTensorToDevice(&parsed, this, &copy,
                                   [&n, &status](const Status& s) {
                                     status = s;
22 changes: 20 additions & 2 deletions tensorflow/compiler/jit/xla_device.h
@@ -57,7 +57,7 @@ class XlaDevice : public LocalDevice {
     Metadata(int device_ordinal, se::Platform* platform,
              const DeviceType& device_type,
             XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-             PaddedShapeFn padded_shape_fn);
+             PaddedShapeFn padded_shape_fn, bool use_multiple_streams);
 
     // The index of the device on this host.
     int device_ordinal() const;
@@ -70,12 +70,15 @@ class XlaDevice : public LocalDevice {
     }
     const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }
 
+    bool UseMultipleStreams() const { return use_multiple_streams_; }
+
    private:
     const int device_ordinal_;
     const DeviceType device_type_;
     se::Platform* platform_;  // Not owned.
     XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
     PaddedShapeFn padded_shape_fn_;
+    const bool use_multiple_streams_;
 
     TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
   };
@@ -89,14 +92,16 @@ class XlaDevice : public LocalDevice {
   // 'transfer_as_literal' is true if device<->host transfers must be done using
   // XLA's TransferLiteral{To,From}Device interface. If false, we can use
   // ThenMemcpy instead.
+  // If 'use_multiple_streams' is true, we create separate streams for
+  // host-to-device and device-to-host communication.
   // If padded_shape_fn is empty, a default implementation that returns
   // the on-host shape is used.
   static Status Create(
       const string& platform_name, const string& device_name,
       int device_ordinal, const string& jit_device_name,
       const SessionOptions& options, const string& name_prefix,
       const XlaOpRegistry::DeviceRegistration& registration,
-      bool transfer_as_literal,
+      bool transfer_as_literal, bool use_multiple_streams,
       const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
       const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device);
 
@@ -106,6 +111,7 @@ class XlaDevice : public LocalDevice {
   XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
             int device_ordinal, const DeviceType& jit_device_name,
             se::Platform* platform, bool transfer_as_literal,
+            bool use_multiple_streams,
             const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
             const PaddedShapeFn& padded_shape_fn);
   ~XlaDevice() override;
@@ -126,6 +132,8 @@ class XlaDevice : public LocalDevice {
   xla::LocalClient* client() const;
   const Metadata& metadata() { return xla_metadata_; }
   xla::StatusOr<se::Stream*> GetStream();
+  xla::StatusOr<se::Stream*> GetHostToDeviceStream();
+  xla::StatusOr<se::Stream*> GetDeviceToHostStream();
 
   // If not already set, create and set GpuDeviceInfo.
   // Not thread-safe
@@ -146,6 +154,16 @@ class XlaDevice : public LocalDevice {
   // copying back and forth between CPU and the device, and
   // computations enqueued by XLA.
   xla::Backend::StreamPtr stream_;
+  // If false, only stream_ is valid and all computation and transfers use
+  // stream_. If true, computation is performed by stream_ and transfers are
+  // performed by host_to_device/device_to_host_stream.
+  bool use_multiple_streams_;
+  // If use_multiple_streams_, host to device transfers are performed using this
+  // stream.
+  xla::Backend::StreamPtr host_to_device_stream_;
+  // If use_multiple_streams_, device to host transfers are performed using this
+  // stream.
+  xla::Backend::StreamPtr device_to_host_stream_;
   // Must we use XLA's transfer manager for correct host<->device transfers? if
   // false, we can use ThenMemcpy() instead.
   bool transfer_as_literal_;