Skip to content

Commit

Permalink
[XLA] Teach the BF16 normalizer that Sort can have a tuple output.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 206647898
  • Loading branch information
mkuperst authored and tensorflower-gardener committed Jul 30, 2018
1 parent b70eb11 commit c57870c
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 23 deletions.
65 changes: 42 additions & 23 deletions tensorflow/compiler/xla/service/bfloat16_normalization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {

Status DefaultAction(HloInstruction* hlo) override;

// Special handling for cross-replica-sum which can have a tuple output.
// Special handling for cross-replica-sum and sort which can have a tuple
// output.
Status HandleCrossReplicaSum(HloInstruction* crs) override;
Status HandleSort(HloInstruction* sort) override;

static bool Run(HloComputation* computation,
const BFloat16Support* bfloat16_support) {
Expand All @@ -49,6 +51,10 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
// conversions between F32 and BF16 to make it supported.
Status HandleInstruction(HloInstruction* hlo);

// Handle instructions with tuple outputs by examining each output
// independently.
Status HandleMultipleOutputs(HloInstruction* hlo);

// Inserts a conversion HLO that changes the given HLO's output type.
Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to,
HloComputation* computation);
Expand Down Expand Up @@ -148,30 +154,43 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
HloInstruction* crs) {
if (!ShapeUtil::IsTuple(crs->shape())) {
return HandleInstruction(crs);
} else {
return HandleMultipleOutputs(crs);
}
}

// Visitor hook for sort. A sort whose shape is a tuple is routed to
// HandleMultipleOutputs; a sort with any other shape goes through
// HandleInstruction.
Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) {
  if (ShapeUtil::IsTuple(sort->shape())) {
    return HandleMultipleOutputs(sort);
  }
  return HandleInstruction(sort);
}

std::vector<PrimitiveType> operand_types(crs->operand_count());
std::vector<PrimitiveType> output_types(crs->operand_count());
Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
HloInstruction* hlo) {
std::vector<PrimitiveType> operand_types(hlo->operand_count());
std::vector<PrimitiveType> output_types(hlo->operand_count());
int64 f32_count = 0;
int64 bf16_count = 0;
bool has_unsupported_bf16_operand = false;
bool has_unsupported_bf16_output = false;
for (int64 i = 0; i < crs->operand_count(); ++i) {
operand_types[i] = crs->operand(i)->shape().element_type();
output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type();
for (int64 i = 0; i < hlo->operand_count(); ++i) {
operand_types[i] = hlo->operand(i)->shape().element_type();
output_types[i] = ShapeUtil::GetSubshape(hlo->shape(), {i}).element_type();
if (operand_types[i] == F32) {
f32_count += 1;
} else if (operand_types[i] == BF16) {
bf16_count += 1;
if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) {
if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
has_unsupported_bf16_operand = true;
}
}
if (output_types[i] == F32) {
f32_count += 1;
} else if (output_types[i] == BF16) {
bf16_count += 1;
if (!bfloat16_support_->SupportsBF16Output(*crs)) {
if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
has_unsupported_bf16_output = true;
}
}
Expand All @@ -185,55 +204,55 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
if (operand_types[i] != BF16) {
return false;
}
if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) {
if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
return true;
}
if (bfloat16_support_->SupportsMixedPrecisions(*crs)) {
if (bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
return false;
}
return has_unsupported_bf16_operand || has_unsupported_bf16_output ||
f32_count > 0;
};

for (int64 i = 0; i < crs->operand_count(); ++i) {
for (int64 i = 0; i < hlo->operand_count(); ++i) {
if (should_convert_operand(i)) {
TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_));
TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
f32_count += 1;
bf16_count -= 1;
}
}

if (!has_unsupported_bf16_output &&
(bfloat16_support_->SupportsMixedPrecisions(*crs) || f32_count == 0 ||
(bfloat16_support_->SupportsMixedPrecisions(*hlo) || f32_count == 0 ||
bf16_count == 0)) {
return Status::OK();
}

std::vector<HloInstruction*> materialized_users = crs->users();
std::vector<HloInstruction*> output_elements(crs->operand_count());
auto original_shape = crs->shape();
for (int64 i = 0; i < crs->operand_count(); ++i) {
auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i});
std::vector<HloInstruction*> materialized_users = hlo->users();
std::vector<HloInstruction*> output_elements(hlo->operand_count());
auto original_shape = hlo->shape();
for (int64 i = 0; i < hlo->operand_count(); ++i) {
auto subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), {i});
if (output_types[i] != BF16) {
output_elements[i] = computation_->AddInstruction(
HloInstruction::CreateGetTupleElement(*subshape, crs, i));
HloInstruction::CreateGetTupleElement(*subshape, hlo, i));
continue;
}
subshape->set_element_type(F32);
auto gte = computation_->AddInstruction(
HloInstruction::CreateGetTupleElement(*subshape, crs, i));
HloInstruction::CreateGetTupleElement(*subshape, hlo, i));
output_elements[i] =
computation_->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(*subshape, BF16), gte));
}
auto tuple = computation_->AddInstruction(
HloInstruction::CreateTuple(output_elements));

// Use the crs' shape temporarily, in order to pass checks in
// Use the hlo's shape temporarily, in order to pass checks in
// ReplaceUseWith.
*tuple->mutable_shape() = crs->shape();
*tuple->mutable_shape() = hlo->shape();
for (auto* user : materialized_users) {
TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple));
TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple));
}
*tuple->mutable_shape() = original_shape;
return Status::OK();
Expand Down
27 changes: 27 additions & 0 deletions tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,33 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32);
}

// Builds a two-operand sort with a tuple shape whose key operand is F32 while
// the key element of the tuple is declared BF16, runs normalization, and
// checks that:
//   * the pass reports a change,
//   * the get-tuple-element root is preserved and still has a BF16 shape,
//   * the sort's key operand and key output element are both F32.
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
  auto module = CreateNewModule();
  HloComputation::Builder builder(TestName());

  const Shape f32_vec = ShapeUtil::MakeShape(F32, {1024});
  const Shape bf16_vec = ShapeUtil::MakeShape(BF16, {1024});
  // NOTE(review): the value operand uses a BF16 element type. If an S32 value
  // was intended instead, confirm that Normalize() still reports a change
  // before switching the element type.
  const Shape value_vec = ShapeUtil::MakeShape(BF16, {1024});

  HloInstruction* key_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_vec, "key"));
  HloInstruction* value_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, value_vec, "value"));

  // The sort dimension is 0; the tuple shape pairs a BF16 key output with the
  // value shape.
  const Shape sort_shape = ShapeUtil::MakeTupleShape({bf16_vec, value_vec});
  HloInstruction* sort_hlo = builder.AddInstruction(
      HloInstruction::CreateSort(sort_shape, 0, key_param, value_param));
  HloInstruction* sorted_keys = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(bf16_vec, sort_hlo, 0));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), sorted_keys);
  EXPECT_EQ(sorted_keys->shape().element_type(), BF16);
  EXPECT_EQ(sort_hlo->operand(0)->shape().element_type(), F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(sort_hlo->shape(), {0}).element_type(),
            F32);
}

// Tests that the normalization should not cause unsupported mixed precision due
// to resolving unsupported BF16 operand.
TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
Expand Down

0 comments on commit c57870c

Please sign in to comment.