[XLA] Modify the function that determines whether an instruction can change
layout so that it can be used by the HLO verifier.

Change the function to a static member function of the LayoutAssignment class.

Add an std::function member to LayoutAssignment to store the function object
passed down from the backend compiler class and use it to decide whether an
instruction can change layouts.

Fix affected test cases.

PiperOrigin-RevId: 215515611
bixia1 authored and tensorflower-gardener committed Oct 3, 2018
1 parent 65b5190 commit bbe15ee
Showing 10 changed files with 59 additions and 26 deletions.
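Every hunk below follows one pattern: the backend hands a layout-mutability predicate to the LayoutAssignment constructor, either directly or through HloPassPipeline::AddPass. As a minimal sketch of the resulting call shape (it mirrors the interpreter/compiler.cc hunk further down; the AssignLayouts wrapper is illustrative, not part of the commit):

#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"

namespace xla {

// Illustrative wrapper: run layout assignment with the default predicate.
// Because the predicate is a plain static function passed as an argument,
// the HLO verifier can invoke the very same function without a pass object.
Status AssignLayouts(HloModule* module) {
  HloPassPipeline pipeline("layout assignment");
  pipeline.AddPass<LayoutAssignment>(
      module->mutable_entry_computation_layout(),
      LayoutAssignment::InstructionCanChangeLayout);
  return pipeline.Run(module).status();
}

}  // namespace xla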
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -308,7 +308,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
       ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
 
   pipeline.AddPass<CpuLayoutAssignment>(
-      module->mutable_entry_computation_layout(), target_machine_features);
+      module->mutable_entry_computation_layout(),
+      LayoutAssignment::InstructionCanChangeLayout, target_machine_features);
   return pipeline.Run(module).status();
 }

5 changes: 4 additions & 1 deletion tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
@@ -30,8 +30,11 @@ class CpuLayoutAssignment : public LayoutAssignment {
  public:
   explicit CpuLayoutAssignment(
       ComputationLayout* entry_computation_layout,
+      std::function<bool(const HloInstruction*)>
+          instruction_can_change_layout_func,
       const TargetMachineFeatures* target_machine_features)
-      : LayoutAssignment(entry_computation_layout),
+      : LayoutAssignment(entry_computation_layout,
+                         std::move(instruction_can_change_layout_func)),
         target_machine_features_(*target_machine_features) {}
   ~CpuLayoutAssignment() override {}

10 changes: 6 additions & 4 deletions tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -54,8 +54,9 @@ class CpuLayoutAssignmentTest : public HloTestBase {
         [](int64 shape_size) {
           return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
         });
-    cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout,
-                                               &target_machine_features);
+    cpu::CpuLayoutAssignment layout_assignment(
+        entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+        &target_machine_features);
     EXPECT_IS_OK(layout_assignment.Run(module).status());
   }
 };
@@ -321,8 +322,9 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
         [](int64 shape_size) {
           return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
         });
-  cpu::CpuLayoutAssignment layout_assignment(&computation_layout,
-                                             &target_machine_features);
+  cpu::CpuLayoutAssignment layout_assignment(
+      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+      &target_machine_features);
   TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
                       layout_assignment.Run(module));

5 changes: 4 additions & 1 deletion tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -30,8 +30,11 @@ namespace gpu {
 class GpuLayoutAssignment : public LayoutAssignment {
  public:
   explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout,
+                               std::function<bool(const HloInstruction*)>
+                                   instruction_can_change_layout_func,
                                se::StreamExecutor* stream_executor)
-      : LayoutAssignment(entry_computation_layout),
+      : LayoutAssignment(entry_computation_layout,
+                         std::move(instruction_can_change_layout_func)),
         stream_executor_(stream_executor) {}
   ~GpuLayoutAssignment() override {}

17 changes: 11 additions & 6 deletions tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -75,7 +75,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
       ShapeLayout(result_shape_with_layout);
 
   GpuLayoutAssignment layout_assignment(
-      &computation_layout, backend().default_stream_executor());
+      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+      backend().default_stream_executor());
   EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
   for (const HloInstruction* operand : add->operands()) {
@@ -163,7 +164,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
   }
 
   GpuLayoutAssignment layout_assignment(
-      &computation_layout, backend().default_stream_executor());
+      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+      backend().default_stream_executor());
   EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
   // The first operand to batchnorm should have the same layout as the
@@ -233,7 +235,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
   }
 
   GpuLayoutAssignment layout_assignment(
-      &computation_layout, backend().default_stream_executor());
+      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+      backend().default_stream_executor());
   EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
   // The first operand to batchnorm should have the same layout as the
@@ -314,7 +317,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
   }
 
   GpuLayoutAssignment layout_assignment(
-      &computation_layout, backend().default_stream_executor());
+      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+      backend().default_stream_executor());
   EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
   // The first and fourth operands to the batchnorm call should have the
@@ -348,8 +352,9 @@ TEST_F(LayoutAssignmentTest, DotLayout) {
 
   ComputationLayout computation_layout(
       module->entry_computation()->ComputeProgramShape());
-  GpuLayoutAssignment layout_assignment(&computation_layout,
-                                        backend().default_stream_executor());
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+      backend().default_stream_executor());
   EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
   Shape expected_shape =

3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -232,7 +232,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
     // a layout-sensitive verifier!
     HloPassPipeline pipeline("layout assignment");
     pipeline.AddPass<GpuLayoutAssignment>(
-        hlo_module->mutable_entry_computation_layout(), stream_exec);
+        hlo_module->mutable_entry_computation_layout(),
+        LayoutAssignment::InstructionCanChangeLayout, stream_exec);
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }

3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -44,7 +44,8 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
   HloPassPipeline pipeline("Interpreter");
 
   pipeline.AddPass<LayoutAssignment>(
-      hlo_module->mutable_entry_computation_layout());
+      hlo_module->mutable_entry_computation_layout(),
+      LayoutAssignment::InstructionCanChangeLayout);
   return pipeline.Run(hlo_module).status();
 }

18 changes: 12 additions & 6 deletions tensorflow/compiler/xla/service/layout_assignment.cc
@@ -974,10 +974,15 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
 
 LayoutAssignment::LayoutAssignment(
     ComputationLayout* entry_computation_layout,
+    std::function<bool(const HloInstruction*)>
+        instruction_can_change_layout_func,
     ChannelLayoutConstraints* channel_constraints)
     : entry_computation_layout_(entry_computation_layout),
       saved_entry_computation_layout_(*entry_computation_layout),
-      channel_layout_constraints_(channel_constraints) {
+      channel_layout_constraints_(channel_constraints),
+      instruction_can_change_layout_func_(
+          std::move(instruction_can_change_layout_func)) {
   if (channel_layout_constraints_ != nullptr) {
     // Save a copy of the input ChannelLayoutConstraints so that we can reset it
     // if we have to undo previous operations (ClearPreviousPassSideEffects()).
@@ -998,7 +1003,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
   if (!ShapeUtil::IsScalar(operand->shape()) &&
       ShapeUtil::Rank(operand->shape()) ==
           ShapeUtil::Rank(instruction->shape()) &&
-      InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) {
+      !instruction_can_change_layout_func_(instruction)) {
     // Propagate the result layout to the operand layout if the instruction
     // requires the same layout for the result and the operand.
     //
@@ -1076,7 +1081,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
 
   if (!ShapeUtil::IsScalar(operand->shape()) &&
       ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) &&
-      InstructionRequiresInputLayoutEqualToOutputLayout(user)) {
+      !instruction_can_change_layout_func_(user)) {
     // Assign users the same layout as the operand.
     return absl::make_unique<Layout>(operand_layout);
   }
@@ -1842,7 +1847,8 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
   return true;
 }
 
-bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
+/* static */
+bool LayoutAssignment::InstructionCanChangeLayout(
     const HloInstruction* instruction) {
   switch (instruction->opcode()) {
     case HloOpcode::kAbs:
@@ -1908,7 +1914,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
     case HloOpcode::kTanh:
     case HloOpcode::kTupleSelect:
     case HloOpcode::kWhile:
-      return true;
+      return false;
     case HloOpcode::kBatchNormGrad:
     case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
@@ -1939,7 +1945,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
     case HloOpcode::kTrace:
     case HloOpcode::kTranspose:
     case HloOpcode::kTuple:
-      return false;
+      return true;
   }
 }

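Note the flipped polarity running through this file: the old virtual member InstructionRequiresInputLayoutEqualToOutputLayout returned true for the elementwise-style opcodes, whereas the renamed static InstructionCanChangeLayout returns false for them, so every call site now negates the predicate. A minimal sketch of the equivalence, assuming only the opcode lists visible in the hunks above (the helper name is hypothetical):

// Hypothetical helper equivalent to the removed virtual member function.
bool RequiresSameLayoutAsOutput(const xla::HloInstruction* instruction) {
  // Opcodes such as kAbs, kTanh, and kWhile now yield false from the new
  // predicate (operand and result layouts must match); kTranspose, kTuple,
  // and kTrace yield true (the layout may change).
  return !xla::LayoutAssignment::InstructionCanChangeLayout(instruction);
}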
18 changes: 14 additions & 4 deletions tensorflow/compiler/xla/service/layout_assignment.h
@@ -286,6 +286,11 @@ class LayoutAssignment : public HloModulePass {
   // entry_computation_layout is modified to populate a layout for the result in
   // the case that no particular layout is requested.
   //
+  // instruction_can_change_layout_func is a function object that determines
+  // whether an instruction can change layouts. An instruction not being able to
+  // change layout means that it requires operands with the same rank as the
+  // output to have the same layout as the output.
+  //
   // channel_constraints is both an input and output. Any sends or recvs that
   // are present in channel_constraints will be laid out as constrained. Any
   // unconstrained sends or recvs will be laid out as locally optimal and their
@@ -295,6 +300,8 @@
   // within any module passed to `Run`.
   explicit LayoutAssignment(
       ComputationLayout* entry_computation_layout,
+      std::function<bool(const HloInstruction*)>
+          instruction_can_change_layout_func = InstructionCanChangeLayout,
       ChannelLayoutConstraints* channel_constraints = nullptr);
   ~LayoutAssignment() override {}
   absl::string_view name() const override { return "layout-assignment"; }
@@ -303,10 +310,10 @@
   // (any layouts were changed).
   StatusOr<bool> Run(HloModule* module) override;
 
-  // Returns true if the instruction requires that operands with the same rank
-  // as the output have to have the same layout as the output.
-  virtual bool InstructionRequiresInputLayoutEqualToOutputLayout(
-      const HloInstruction* instruction);
+  // Determines whether an instruction can change layouts. An instruction not
+  // being able to change layout means that it requires operands with the same
+  // rank as the output to have the same layout as the output.
+  static bool InstructionCanChangeLayout(const HloInstruction* instruction);
 
  protected:
   // These methods, invoked by PropagateConstraints, propagate a layout
@@ -522,6 +529,9 @@ class LayoutAssignment : public HloModulePass {
   // The set of HLO instructions which lacked any layout constraint, thus
   // receiving propagated default layouts.
   absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_;
+
+  std::function<bool(const HloInstruction*)>
+      instruction_can_change_layout_func_;
 };
 
 } // namespace xla

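Since the constructor now accepts the predicate as a std::function defaulting to the static InstructionCanChangeLayout, a backend can substitute a stricter policy without overriding a virtual method. A hypothetical sketch, not part of this commit (the kReverse constraint is invented purely for illustration):

#include <functional>

#include "tensorflow/compiler/xla/service/layout_assignment.h"

namespace xla {

// Hypothetical backend predicate: pin one extra opcode to its operand
// layout, deferring to the default policy for everything else.
bool BackendInstructionCanChangeLayout(const HloInstruction* instruction) {
  if (instruction->opcode() == HloOpcode::kReverse) {
    return false;  // Invented constraint: treat kReverse as layout-preserving.
  }
  return LayoutAssignment::InstructionCanChangeLayout(instruction);
}

// Usage sketch:
//   LayoutAssignment pass(entry_computation_layout,
//                         BackendInstructionCanChangeLayout);

}  // namespace xla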
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -55,7 +55,8 @@ class LayoutAssignmentTest : public HloVerifiedTestBase {
       ComputationLayout* entry_computation_layout,
       ChannelLayoutConstraints* channel_constraints = nullptr) {
     LayoutAssignment layout_assignment(
-        entry_computation_layout, /*channel_constraints=*/channel_constraints);
+        entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+        /*channel_constraints=*/channel_constraints);
     EXPECT_IS_OK(layout_assignment.Run(module).status());
   }

