Skip to content

Commit

Permalink
Convert a few more tests to hlo_verified_test_base.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 212730899
  • Loading branch information
dimvar authored and tensorflower-gardener committed Sep 13, 2018
1 parent eff48d0 commit 20192a9
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 37 deletions.
4 changes: 4 additions & 0 deletions tensorflow/compiler/xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
Expand Down Expand Up @@ -1401,6 +1402,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
Expand Down Expand Up @@ -1787,6 +1789,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
Expand Down Expand Up @@ -2625,6 +2628,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
Expand Down
14 changes: 7 additions & 7 deletions tensorflow/compiler/xla/service/batchnorm_expander_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

using BatchNormExpanderTest = HloTestBase;
using BatchNormExpanderTest = HloVerifiedTestBase;

// Test that we expand BatchNormTraining.
TEST_F(BatchNormExpanderTest, BatchNormTraining) {
Expand Down Expand Up @@ -66,7 +66,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) {
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
root = computation->root_instruction();
// Make sure this operation is expanded.
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
Expand Down Expand Up @@ -108,7 +108,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) {
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
root = computation->root_instruction();
// Make sure this operation is expanded.
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
Expand All @@ -126,13 +126,13 @@ ENTRY entry {
epsilon=0.001, feature_index=1, sharding={maximal device=1}
})";

TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str));
ParseAndVerifyModule(module_str);
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie());

for (auto* instruction : module->entry_computation()->instructions()) {
for (auto* instruction : module().entry_computation()->instructions()) {
if (instruction->opcode() == HloOpcode::kParameter) {
continue;
}
Expand Down
12 changes: 7 additions & 5 deletions tensorflow/compiler/xla/service/call_inliner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
Expand All @@ -40,7 +40,7 @@ namespace {

// Tests for call inlining that are most tractable at the HLO level (vs
// ComputationBuilder API in call_test.cc).
using CallInlinerTest = HloTestBase;
using CallInlinerTest = HloVerifiedTestBase;

TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
// "inner" computation just has a control dependency from the "zero" value to
Expand All @@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
auto computation = module->AddEntryComputation(outer.Build());

CallInliner call_inliner;
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
Expand All @@ -91,6 +91,8 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
module->AddEmbeddedComputation(just_false.Build());

HloComputation::Builder call_false_builder(TestName() + ".call_false");
call_false_builder.AddInstruction(
HloInstruction::CreateParameter(0, pred, "param"));
call_false_builder.AddInstruction(
HloInstruction::CreateCall(pred, {}, false_computation));
HloComputation* call_false =
Expand All @@ -105,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
auto computation = module->AddEntryComputation(outer.Build());

CallInliner call_inliner;
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
EXPECT_THAT(
computation->root_instruction()->while_condition()->root_instruction(),
Expand Down Expand Up @@ -161,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
module->AddEntryComputation(outer.Build());

CallInliner call_inliner;
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
}

Expand Down
33 changes: 16 additions & 17 deletions tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"

Expand All @@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {

using HloConstantFoldingTest = HloTestBase;
using HloConstantFoldingTest = HloVerifiedTestBase;

TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
HloComputation::Builder builder(TestName());
Expand All @@ -52,7 +52,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));

HloConstantFolding const_folder;
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);

EXPECT_THAT(computation->root_instruction(), op::Constant());
Expand All @@ -73,7 +73,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));

HloConstantFolding const_folder;
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);

EXPECT_THAT(computation->root_instruction(), op::Constant());
Expand All @@ -94,7 +94,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));

HloConstantFolding const_folder;
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);

EXPECT_THAT(computation->root_instruction(), op::Constant());
Expand Down Expand Up @@ -134,7 +134,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
auto computation = module->AddEntryComputation(builder.Build());

HloConstantFolding const_folder;
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);

HloInstruction* root = computation->root_instruction();
Expand All @@ -161,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
auto computation = module->AddEntryComputation(builder.Build());

HloConstantFolding const_folder;
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);

HloInstruction* root = computation->root_instruction();
Expand All @@ -186,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
auto computation = module->AddEntryComputation(builder.Build());

HloConstantFolding const_folder;
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);

HloInstruction* root = computation->root_instruction();
Expand Down Expand Up @@ -219,28 +219,27 @@ const char* const kConstantFoldReduce = R"(
})";

TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(kConstantFoldReduce));
ParseAndVerifyModule(kConstantFoldReduce);
HloConstantFolding const_folder;
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
EXPECT_TRUE(result);

EXPECT_EQ(6, module->entry_computation()
EXPECT_EQ(6, module()
.entry_computation()
->root_instruction()
->literal()
.GetFirstElement<int32>());
}

TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(kConstantFoldReduce));
HloInstruction* add = module->computations().begin()->root_instruction();
ParseAndVerifyModule(kConstantFoldReduce);
HloInstruction* add = module().computations().begin()->root_instruction();
LayoutUtil::ClearLayout(add->mutable_shape());
HloConstantFolding const_folder;
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
EXPECT_FALSE(result);

EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce());
}

} // namespace
Expand Down
16 changes: 8 additions & 8 deletions tensorflow/compiler/xla/service/inliner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

Expand All @@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {

using InlinerTest = HloTestBase;
using InlinerTest = HloVerifiedTestBase;

// Test that `map` with `max` is transformed to `max`
TEST_F(InlinerTest, MapMax) {
Expand Down Expand Up @@ -64,12 +64,12 @@ TEST_F(InlinerTest, MapMax) {
hlo_module->AddEntryComputation(std::move(computation));

Inliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Maximum(lhs, rhs));

// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
Expand Down Expand Up @@ -98,12 +98,12 @@ TEST_F(InlinerTest, MapConstant) {
hlo_module->AddEntryComputation(std::move(computation));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
Inliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
root = hlo_module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Broadcast(op::Constant()));

// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
Expand Down Expand Up @@ -136,12 +136,12 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
hlo_module->AddEntryComputation(std::move(computation));

Inliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Subtract(rhs, lhs));

// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
Expand Down

0 comments on commit 20192a9

Please sign in to comment.