Skip to content

Commit

Permalink
Build fully connected graph which edges across called computations.
Browse files Browse the repository at this point in the history
Restructured sharding passes to propagate sharding on pass-through instructions which now the placer does not assign anymore (GTEs, tuples, bitcast, parameters, ...).

PiperOrigin-RevId: 203591020
  • Loading branch information
tensorflower-gardener committed Jul 8, 2018
1 parent 6d5b8b7 commit 35287be
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tensorflow/compiler/xla/service/hlo_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,34 @@ HloSharding HloSharding::Tuple(
const Shape& tuple_shape,
tensorflow::gtl::ArraySlice<HloSharding> shardings) {
CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
for (auto& sharding : shardings) {
CHECK(!sharding.IsTuple()) << sharding.ToString();
}
std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
<< "Flat list has " << flattened_list.size() << ", required "
<< RequiredLeaves(tuple_shape);
return HloSharding(flattened_list);
}

HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
const HloSharding& sharding) {
CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
CHECK(!sharding.IsTuple()) << sharding.ToString();
int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape);
std::vector<HloSharding> flattened_list;
flattened_list.reserve(leaf_count);
for (int64 i = 0; i < leaf_count; ++i) {
flattened_list.push_back(sharding);
}
return HloSharding(flattened_list);
}

HloSharding HloSharding::Single(const Shape& shape,
const HloSharding& sharding) {
return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding;
}

string HloSharding::ToString() const {
if (IsTuple()) {
std::vector<string> parts;
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/compiler/xla/service/hlo_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ class HloSharding {
static HloSharding Tuple(const Shape& tuple_shape,
tensorflow::gtl::ArraySlice<HloSharding> shardings);

// Creates a new sharding for a tuple type, with a single input sharding
// repeated on each leaf.
static HloSharding SingleTuple(const Shape& tuple_shape,
const HloSharding& sharding);

// If shape is an array, returns sharding, otherwise returns the tuple shaped
// sharding with all the leaf nodes having the same input sharding.
static HloSharding Single(const Shape& shape, const HloSharding& sharding);

// Create a new sharding from a protobuf OpSharding.
static StatusOr<HloSharding> FromProto(const OpSharding& proto);

Expand Down

0 comments on commit 35287be

Please sign in to comment.