Skip to content

Commit

Permalink
[TF:XLA] Improve conditional_test by using parameters instead of cons…
Browse files Browse the repository at this point in the history
…tants.

Before this change the conditional simplifier would fold away almost all
conditional expressions.

PiperOrigin-RevId: 204790426
  • Loading branch information
tensorflower-gardener committed Jul 16, 2018
1 parent 3ad62d7 commit bf0f48d
Showing 1 changed file with 63 additions and 41 deletions.
104 changes: 63 additions & 41 deletions tensorflow/compiler/xla/tests/conditional_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,88 +172,95 @@ class ConditionalOpTest : public ClientLibraryTestBase {
// Test true and false computations that do not take any parameters.
XLA_TEST_F(ConditionalOpTest, Parameters0) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, true);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
auto operands = Tuple(&builder, {});
auto true_computation = CreateR0ConstantComputation(56.0f);
auto false_computation = CreateR0ConstantComputation(12.0f);
Conditional(pred, operands, true_computation, operands, false_computation);

ComputeAndCompareR0<float>(&builder, 56.0f, {}, error_spec_);
ComputeAndCompareR0<float>(&builder, 56.0f, {pred_arg.get()}, error_spec_);
}

// Test true and false computations that take in 1 parameter.
XLA_TEST_F(ConditionalOpTest, Parameters1) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, false);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.0f);
auto operand2 = ConstantR0<float>(&builder, 12.0f);
auto identity = CreateR0IdentityComputation();
Conditional(pred, operand1, identity, operand2, identity);

ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}

// Test conditional with two different computations in the true and false cases
// that take in different arguments.
XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, false);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.4f);
auto operand2 = ConstantR0<float>(&builder, 12.6f);
Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
CreateR0FloorComputation());

ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}

// Test conditional with two different computations in the true and false cases
// that take in the same arguments.
XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, false);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand = ConstantR0<float>(&builder, 12.6f);
Conditional(pred, operand, CreateR0CeilComputation(), operand,
CreateR0FloorComputation());

ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}

// Test conditional with the same computation in the true and false cases but
// take in different arguments.
XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, false);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.4f);
auto operand2 = ConstantR0<float>(&builder, 12.6f);
auto floor = CreateR0FloorComputation();
Conditional(pred, operand1, floor, operand2, floor);

ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}

// Test conditional with the same computation in the true and false cases that
// take in the same arguments.
XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, false);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand = ConstantR0<float>(&builder, 12.6f);
auto floor = CreateR0FloorComputation();
Conditional(pred, operand, floor, operand, floor);

ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}

// Test conditional with different instances of the same computation in the true
// and false cases.
XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, false);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.4f);
auto operand2 = ConstantR0<float>(&builder, 12.6f);
Conditional(pred, operand1, CreateR0FloorComputation(), operand2,
CreateR0FloorComputation());

ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}

// Test the case when a call invokes a computation that contains a conditional.
Expand All @@ -268,75 +275,83 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
auto inner_builder_result = inner_builder.Build();

XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, false);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.4f);
auto operand2 = ConstantR0<float>(&builder, 12.6f);
Call(&builder, inner_builder_result.ConsumeValueOrDie(),
{pred, operand1, operand2});

ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}

// Test true and false computations that take in 2 parameters and predicate is
// true.
XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, true);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.0f);
auto operand2 = ConstantR0<float>(&builder, 12.0f);
auto operands = Tuple(&builder, {operand1, operand2});
Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
CreateR0TupleSubComputation());

ComputeAndCompareR0<float>(&builder, 68.0f, {}, error_spec_);
ComputeAndCompareR0<float>(&builder, 68.0f, {pred_arg.get()}, error_spec_);
}

// Test true and false computations that take in 2 parameters and predicate is
// false.
XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, false);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.0f);
auto operand2 = ConstantR0<float>(&builder, 12.0f);
auto operands = Tuple(&builder, {operand1, operand2});
Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
CreateR0TupleSubComputation());

ComputeAndCompareR0<float>(&builder, 44.0f, {}, error_spec_);
ComputeAndCompareR0<float>(&builder, 44.0f, {pred_arg.get()}, error_spec_);
}

// Test true and false computations that take in 2 array parameters and
// predicate is true.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, true);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
auto operands = Tuple(&builder, {operand1, operand2});
Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
CreateR1TupleSubComputation());

ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {}, error_spec_);
ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {pred_arg.get()},
error_spec_);
}

// Test true and false computations that take in 2 array parameters and
// predicate is false.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, false);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
auto operands = Tuple(&builder, {operand1, operand2});
Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
CreateR1TupleSubComputation());

ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {}, error_spec_);
ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {pred_arg.get()},
error_spec_);
}

// Test true and false computations that return a tuple of scalars.
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, false);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operands = Tuple(&builder, {ConstantR0<float>(&builder, 12.2f),
ConstantR0<float>(&builder, 25.6f)});
Conditional(pred, operands, CreateR0TupleCeilComputation(), operands,
Expand All @@ -346,13 +361,14 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
&builder,
*LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
LiteralUtil::CreateR0<float>(25.0f).get()}),
{}, error_spec_);
{pred_arg.get()}, error_spec_);
}

// Test true and false computations that return a tuple of arrays.
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, true);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
auto operands =
Tuple(&builder, {ConstantR1<float>(&builder, {12.2f, 15.8f}),
ConstantR1<float>(&builder, {25.6f, 29.2f})});
Expand All @@ -364,7 +380,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
*LiteralUtil::MakeTuple(
{LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
{}, error_spec_);
{pred_arg.get()}, error_spec_);
}

// Test true and false computations that return a tuple of a predicate, a
Expand Down Expand Up @@ -393,7 +409,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
EXPECT_IS_OK(false_builder_result.status());

XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, true);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
auto operands = Tuple(&builder, {});
Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
false_builder_result.ConsumeValueOrDie());
Expand All @@ -404,7 +421,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
{LiteralUtil::CreateR0<bool>(true).get(),
LiteralUtil::CreateR0<float>(12.2f).get(),
LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
{}, error_spec_);
{pred_arg.get()}, error_spec_);
}

// Test true and false computations that return a nested tuple.
Expand Down Expand Up @@ -438,7 +455,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
EXPECT_IS_OK(false_builder_result.status());

XlaBuilder builder(TestName());
auto pred = ConstantR0<bool>(&builder, false);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operands = Tuple(&builder, {});
Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
false_builder_result.ConsumeValueOrDie());
Expand All @@ -454,7 +472,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
{LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
LiteralUtil::CreateR0<float>(9.3f).get()})
.get()}),
{}, error_spec_);
{pred_arg.get()}, error_spec_);
}

// Test conditional that takes in scalar operands in the form of external
Expand Down Expand Up @@ -515,16 +533,18 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
EXPECT_IS_OK(inner_builder_result.status());

XlaBuilder builder(TestName());
auto pred1 = ConstantR0<bool>(&builder, true);
auto pred2 = ConstantR0<bool>(&builder, false);
XlaOp pred1, pred2;
auto pred1_arg = CreateR0Parameter<bool>(true, 0, "pred1", &builder, &pred1);
auto pred2_arg = CreateR0Parameter<bool>(false, 1, "pred2", &builder, &pred2);
auto operand1 = ConstantR0<float>(&builder, 1.1f);
auto operand2 = ConstantR0<float>(&builder, 12.2f);
auto operand3 = ConstantR0<float>(&builder, 43.3f);
auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2});
Conditional(pred1, tuple_operand, inner_builder_result.ConsumeValueOrDie(),
operand3, CreateR0IdentityComputation());

ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
ComputeAndCompareR0<float>(&builder, 12.0f,
{pred1_arg.get(), pred2_arg.get()}, error_spec_);
}

XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
Expand All @@ -543,13 +563,14 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
EXPECT_IS_OK(inner_builder_result.status());

XlaBuilder builder(TestName());
auto pred2 = ConstantR0<bool>(&builder, false);
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 1.1f);
auto operand2 = ConstantR0<float>(&builder, 12.2f);
auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2});
auto tuple_operand = Tuple(&builder, {pred, operand1, operand2});
Call(&builder, inner_builder_result.ConsumeValueOrDie(), {tuple_operand});

ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}

// Test a mismatch in the shape of the true operand and true computation.
Expand Down Expand Up @@ -604,16 +625,17 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {

auto test_swap = [&](float a, float b) {
XlaBuilder builder(TestName());
auto x = ConstantR0<float>(&builder, a);
auto y = ConstantR0<float>(&builder, b);
XlaOp x, y;
auto x_arg = CreateR0Parameter<float>(a, 0, "x", &builder, &x);
auto y_arg = CreateR0Parameter<float>(b, 1, "y", &builder, &y);
auto tuple_operand = Tuple(&builder, {x, y});
Call(&builder, main, {tuple_operand});

ComputeAndCompareTuple(
&builder,
*LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
LiteralUtil::CreateR0<float>(b).get()}),
{}, error_spec_);
{x_arg.get(), y_arg.get()}, error_spec_);
};

test_swap(3.11f, 9.4f);
Expand Down

0 comments on commit bf0f48d

Please sign in to comment.