Merge pull request tensorflow#14123 from andrewharp/branch_174023371
Branch 174023371
andrewharp authored Oct 31, 2017
2 parents 123749f + 648993e commit e64bc92
Showing 236 changed files with 14,650 additions and 2,688 deletions.
1 change: 1 addition & 0 deletions tensorflow/BUILD
@@ -371,6 +371,7 @@ filegroup(
"//tensorflow/contrib/crf:all_files",
"//tensorflow/contrib/cudnn_rnn:all_files",
"//tensorflow/contrib/data:all_files",
"//tensorflow/contrib/data/kernels:all_files",
"//tensorflow/contrib/data/python/kernel_tests:all_files",
"//tensorflow/contrib/data/python/ops:all_files",
"//tensorflow/contrib/decision_trees/proto:all_files",
7 changes: 2 additions & 5 deletions tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -103,20 +103,17 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
DeviceNameUtils::ParseFullName(op_kernel->requested_device(), &parsed),
errors::Internal("Unable to parse device name: ",
op_kernel->requested_device()));
xla::OpDeviceAssignment assignment;
// If no device ID assignment is found, XLA is free to use whatever device it
// wants. In practice this usually has the effect of placing things on
// device 0.
if (parsed.has_id) {
assignment.set_has_device(true);
assignment.set_device(parsed.id);
b->SetSharding(xla::ShardingBuilder::AssignDevice(parsed.id));
}
b->SetDeviceAssignment(assignment);

op_kernel->Compute(context);

b->ClearOpMetadata();
b->ClearDeviceAssignment();
b->ClearSharding();
VLOG(4) << "Done";
}

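For readers skimming this hunk: xla::ShardingBuilder::AssignDevice is introduced later in this same commit (in computation_builder.h) and builds a "maximal" OpSharding that pins every instruction emitted while it is set to a single device. A hedged sketch of roughly what that call produces, based only on the implementation shown further down:

// Sketch only: approximates ShardingBuilder::AssignDevice(device), mirroring
// the implementation added in computation_builder.h later in this commit.
xla::OpSharding PinToDevice(int device) {
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
  sharding.add_tile_assignment_dimensions(1);    // a single "tile" covering the whole shape
  sharding.add_tile_assignment_devices(device);  // ...assigned to this one device
  return sharding;
}

The Compute() path above sets this sharding only when the requested device name carries an explicit id, and clears it after the kernel has emitted its ops, so subsequent instructions fall back to XLA's default placement.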
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -56,7 +56,7 @@ class XlaCompiledCpuFunction {
const void** args, void** temps)>;

// StaticData represents the state necessary to run an XLA-compiled
// function. For JIT this is backed by data in XlaCompiledCpuFunctionJit; for
// function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for
// AOT this is backed by data compiled into the object file.
struct StaticData {
// The raw function to call.
24 changes: 21 additions & 3 deletions tensorflow/compiler/xla/array.h
@@ -40,12 +40,13 @@ template <typename T>
class Array {
public:
// Creates a new array with the specified dimensions.
explicit Array(const std::vector<int64>& sizes) : Array(sizes, T()) {}
explicit Array(tensorflow::gtl::ArraySlice<int64> sizes)
: Array(sizes, T()) {}

// Creates a new array with the specified dimensions and specified value for
// every cell.
Array(const std::vector<int64>& sizes, T value)
: sizes_(sizes), values_(new T[num_elements()]) {
Array(tensorflow::gtl::ArraySlice<int64> sizes, T value)
: sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) {
Fill(value);
}

@@ -192,6 +193,18 @@ class Array {
return values_[calculate_index(indexes)];
}

// Returns the value at the cell specified by the indexes. The number of
// indexes has to match the number of dimensions of the array.
const T& operator()(tensorflow::gtl::ArraySlice<int64> indexes) const {
return values_[calculate_index(indexes)];
}

// Returns the value at the cell specified by the indexes. The number of
// indexes has to match the number of dimensions of the array.
T& operator()(tensorflow::gtl::ArraySlice<int64> indexes) {
return values_[calculate_index(indexes)];
}

// Low-level accessor for stuff like memcmp, handle with care. Returns pointer
// to the underlying storage of the array (similarly to std::vector::data()).
T* data() const {
@@ -218,6 +231,11 @@
std::multiplies<int64>());
}

const T* begin() const { return &values_[0]; }
T* begin() { return &values_[0]; }
const T* end() const { return &values_[num_elements()]; }
T* end() { return &values_[num_elements()]; }

bool operator==(const Array<T>& other) const {
if (sizes_.size() != other.sizes_.size()) {
return false;
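Taken together, the array.h additions above let callers construct an Array from any ArraySlice-compatible container, index it with a runtime-sized index list, and iterate over its flat storage. A minimal usage sketch, assuming only the members shown in this hunk (local names are illustrative):

#include <vector>

#include "tensorflow/compiler/xla/array.h"

void ArraySliceUsageSketch() {
  using tensorflow::int64;

  // ArraySlice constructor: a std::vector of sizes converts implicitly.
  std::vector<int64> sizes = {2, 3};
  xla::Array<float> a(sizes, /*value=*/0.0f);

  // ArraySlice operator(): the index list length must equal the array's rank.
  std::vector<int64> index = {1, 2};
  a(index) = 42.0f;

  // begin()/end(): flat iteration over the underlying storage.
  float sum = 0.0f;
  for (float v : a) sum += v;  // sum == 42.0f
}

The begin()/end() pair is what lets ShardingBuilder::Tile, added in computation_builder.h below, walk a tile assignment with `for (uint32 device : tile_assignment)`.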
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/client/BUILD
@@ -170,6 +170,7 @@ cc_library(
":computation",
":global_data",
":padding",
"//tensorflow/compiler/xla:array",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
11 changes: 3 additions & 8 deletions tensorflow/compiler/xla/client/computation_builder.cc
@@ -1794,14 +1794,9 @@ StatusOr<Computation> ComputationBuilder::Build() {

void ComputationBuilder::AddCommonFieldsToOpRequest(OpRequest* request) const {
*request->mutable_metadata() = metadata_;
*request->mutable_device_assignment() = device_assignment_;
}

void ComputationBuilder::ClearDeviceAssignment() { device_assignment_.Clear(); }

void ComputationBuilder::SetDeviceAssignment(
const OpDeviceAssignment& assignment) {
device_assignment_ = assignment;
if (sharding_) {
*request->mutable_sharding() = *sharding_;
}
}

/* static */ ConvolutionDimensionNumbers
62 changes: 58 additions & 4 deletions tensorflow/compiler/xla/client/computation_builder.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <utility>

#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -42,6 +43,58 @@ limitations under the License.

namespace xla {

class ShardingBuilder {
public:
// A shaped array used to describe the assignment of tiles to devices.
using TileAssignment = Array<int64>;

// Creates a replicated sharding - replicate a tensor on every device.
static OpSharding Replicate() {
OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
return result;
}
// Creates a sharding that assigns a tensor to just one device.
static OpSharding AssignDevice(int device) {
OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
result.add_tile_assignment_dimensions(1);
result.add_tile_assignment_devices(device);
return result;
}
// Creates a tiled sharding with the given tile shape and assignment of tiles
// to devices.
static OpSharding Tile(Shape tile_shape,
const TileAssignment& tile_assignment) {
OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
for (int64 dim : tile_assignment.dimensions()) {
result.add_tile_assignment_dimensions(dim);
}
for (uint32 device : tile_assignment) {
result.add_tile_assignment_devices(device);
}
return result;
}
// Creates a sharding in one dimension, with the given tile shape which must
// be rank 1 and using devices 0..num_tiles.
static OpSharding Tile1D(Shape tile_shape, int64 num_tiles) {
OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_OTHER);

CHECK_EQ(ShapeUtil::Rank(tile_shape), 1);
std::vector<int64> dimensions(1, num_tiles);
auto& tile_dimension = (*tile_shape.mutable_dimensions())[0];
tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles);
*result.mutable_tile_shape() = tile_shape;
result.add_tile_assignment_dimensions(num_tiles);
for (int64 i = 0; i < num_tiles; ++i) {
result.add_tile_assignment_devices(i);
}
return result;
}
};

// Wraps an XLA client with a convenient interface for building up
// computations. Any errors encountered in building up the computation are
// deferred from being handled until Build() is called.
@@ -78,11 +131,11 @@ class ComputationBuilder {

// Sets an OpDeviceAssignment that will be attached to all instructions
// until cleared.
void SetDeviceAssignment(const OpDeviceAssignment& assignment);
void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

// Clears the device assignment. Ops will be placed according to the default
// placement policy.
void ClearDeviceAssignment();
void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }

// Sets the builder to a mode where it will die immediately when an error is
// encountered, rather than producing it in a deferred fashion when Build() is
@@ -894,8 +947,9 @@ class ComputationBuilder {
// throughout the TensorFlow op kernel implementations).
OpMetadata metadata_;

// Device assignment for the operator.
OpDeviceAssignment device_assignment_;
// Sharding for this operator. This is structured as a "model"-like operation,
// in order to simplify client code, similar to metadata_.
tensorflow::gtl::optional<OpSharding> sharding_;

TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
};
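As a quick illustration of the new client surface in this hunk, here is a hedged sketch (hypothetical function and argument names; assumes an already-constructed ComputationBuilder) that pins one constant to device 0, tiles a rank-1 constant across two devices with Tile1D, and then clears the sharding so later instructions revert to the default placement:

// Sketch only; relies on the ShardingBuilder, SetSharding, and ClearSharding
// declarations shown above. Shapes and values are illustrative.
void BuildWithShardingSketch(xla::ComputationBuilder* b, const xla::Shape& r1_shape) {
  // Every instruction built while this sharding is set is assigned to device 0.
  b->SetSharding(xla::ShardingBuilder::AssignDevice(0));
  auto x = b->ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f});

  // Tile a rank-1 value across devices 0 and 1; r1_shape must be rank 1.
  b->SetSharding(xla::ShardingBuilder::Tile1D(r1_shape, /*num_tiles=*/2));
  auto y = b->ConstantR1<float>({5.0f, 6.0f, 7.0f, 8.0f});

  // Everything that follows uses the default placement policy again.
  b->ClearSharding();
  b->Add(x, y);
}

Because sharding_ is a gtl::optional, the computation_builder.cc change above only copies a sharding into an OpRequest when one is actually set, so ops built with no sharding carry no sharding field at all.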
19 changes: 19 additions & 0 deletions tensorflow/compiler/xla/service/BUILD
@@ -131,6 +131,7 @@ cc_library(
"hlo_instruction.cc",
"hlo_module.cc",
"hlo_opcode.cc",
"hlo_sharding.cc",
],
hdrs = [
"dfs_hlo_visitor.h",
@@ -139,13 +140,15 @@ cc_library(
"hlo_instruction.h",
"hlo_module.h",
"hlo_opcode.h",
"hlo_sharding.h",
],
deps = [
":hlo_module_config",
":hlo_proto",
":hlo_reachability",
":name_uniquer",
":versioned_computation_handle",
"//tensorflow/compiler/xla:array",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_tree",
@@ -236,6 +239,22 @@ tf_cc_test(
],
)

tf_cc_test(
name = "hlo_sharding_test",
srcs = ["hlo_sharding_test.cc"],
deps = [
":hlo",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)

cc_library(
name = "call_graph",
srcs = ["call_graph.cc"],

