Skip to content

Commit

Permalink
[xla:gpu] Support DUS for AddressComputationFusionRewriter
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 618942742
  • Loading branch information
tyb0807 authored and tensorflower-gardener committed Mar 25, 2024
1 parent 80f5c63 commit c76b77e
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,14 @@ bool IsAlignedSlice(const Shape& src_shape, const Shape& dst_shape,
return true;
}

absl::InlinedVector<HloInstruction*, 8> GetSlicedOperandChains(
const HloInstruction* instr, bool dynamic) {
absl::InlinedVector<HloInstruction*, 8> sliced_operand_chains = {
absl::InlinedVector<HloInstruction*, 8> GetSlicedChains(
const HloInstruction* instr, bool dynamic,
absl::flat_hash_map<const HloInstruction*, const HloInstruction*>&
replacement_map) {
replacement_map[instr] = instr;
absl::InlinedVector<HloInstruction*, 8> dyn_slice_chains = {
const_cast<HloInstruction*>(instr)};
absl::InlinedVector<HloInstruction*, 8> dus_chain;
auto fusion = HloFusionAdaptor::ForComputation(instr->parent());
// This set is used to avoid duplicates in the matched results. It contains
// the matched instructions that we have seen so far.
Expand All @@ -147,12 +151,6 @@ absl::InlinedVector<HloInstruction*, 8> GetSlicedOperandChains(
if (processed_sliced_chain_set.contains(cur)) return true;
maybe_sliced_operand_chain.push_back(
const_cast<HloInstruction*>(cur));
// TODO(vuson): lift the first restriction by considering fusing other
// uses of the operand to reuse the address computation. Only worth it
// if other uses are also custom calls though.
// TODO(vuson): lift the second restriction by considering fusing the
// non-noop instructions to the computation if possible (i.e. for
// dynamic slices).
if (dynamic) {
if (const auto slice_instr =
DynCast<HloDynamicSliceInstruction>(cur)) {
Expand All @@ -171,6 +169,9 @@ absl::InlinedVector<HloInstruction*, 8> GetSlicedOperandChains(
}
}
}
// TODO(vuson): lift the first restriction by considering fusing other
// uses of the operand to reuse the address computation. Only worth it
// if other uses are also custom calls though.
return cur->user_count() > 1 || !IsNoOp(cur);
});
if (maybe_slice_adaptor == std::nullopt) continue;
Expand All @@ -180,14 +181,58 @@ absl::InlinedVector<HloInstruction*, 8> GetSlicedOperandChains(
// Even in the case of stopping at a match that has been processed, we
// still need to add instructions encountered in the sliced operand chain
// during the latest traversal.
sliced_operand_chains.insert(sliced_operand_chains.end(),
maybe_sliced_operand_chain.begin(),
maybe_sliced_operand_chain.end());
dyn_slice_chains.insert(dyn_slice_chains.end(),
maybe_sliced_operand_chain.begin(),
maybe_sliced_operand_chain.end());
processed_sliced_chain_set.insert(maybe_sliced_operand_chain.begin(),
maybe_sliced_operand_chain.end());
}
}
return sliced_operand_chains;

if (dynamic) {
for (auto* user : instr->users()) {
absl::InlinedVector<HloInstruction*, 4> maybe_sliced_user_chain;
bool dus_found = false;
auto maybe_dus_adaptor = HloFindIf(
{HloInstructionAdaptor(*user)}, *fusion,
[&](auto node) {
const HloInstruction* cur = &node.instruction();
// If the node is a match that has been processed, stop the
// traversal.
if (processed_sliced_chain_set.contains(cur)) return true;
maybe_sliced_user_chain.push_back(const_cast<HloInstruction*>(cur));
if (const auto slice_instr =
DynCast<HloDynamicUpdateSliceInstruction>(cur)) {
if (IsAlignedSlice(slice_instr->shape(),
slice_instr->operand(1)->shape(), nullptr)) {
dus_found = true;
replacement_map[instr] = cur;
return dus_found;
}
}
// TODO(vuson): lift the first restriction by considering fusing
// other uses of the user to reuse the address computation. Only
// worth it if other uses are also custom calls though.
return cur->user_count() > 1 || !IsNoOp(cur);
},
/*visit_operands=*/false);
if (maybe_dus_adaptor == std::nullopt) continue;
const auto& maybe_dus_instr = maybe_dus_adaptor->instruction();
if (dus_found || processed_sliced_chain_set.contains(&maybe_dus_instr)) {
// Even in the case of stopping at a match that has been processed, we
// still need to add instructions encountered in the sliced user chain
// during the latest traversal.
dus_chain.insert(dus_chain.end(), maybe_sliced_user_chain.rbegin(),
maybe_sliced_user_chain.rend());
processed_sliced_chain_set.insert(maybe_sliced_user_chain.begin(),
maybe_sliced_user_chain.end());
}
}
}

dus_chain.insert(dus_chain.end(), dyn_slice_chains.begin(),
dyn_slice_chains.end());
return dus_chain;
}

absl::InlinedVector<HloInstruction*, 4> GetPatternCaptures(
Expand Down Expand Up @@ -333,14 +378,17 @@ absl::StatusOr<bool> AddressComputationFusionRewriter::Run(
absl::flat_hash_map<HloInstruction*,
absl::InlinedVector<HloInstruction*, 8>>
matches;
absl::flat_hash_map<const HloInstruction*, const HloInstruction*>
replacement_map;

// Collect all potential custom call matches in the non-fusion computations.
for (HloComputation* computation : module->computations()) {
if (computation->IsFusionComputation()) continue;
for (HloInstruction* instr : computation->instructions()) {
if (IsLegacyCublasMatmul(*instr) ||
(!dynamic && IsCustomCall(instr, platform_name_))) {
auto sliced_operand_chains = GetSlicedOperandChains(instr, dynamic);
auto sliced_operand_chains =
GetSlicedChains(instr, dynamic, replacement_map);
if (!(sliced_operand_chains.size() == 1 &&
sliced_operand_chains.front() == instr)) {
matches[instr] = std::move(sliced_operand_chains);
Expand Down Expand Up @@ -373,17 +421,22 @@ absl::StatusOr<bool> AddressComputationFusionRewriter::Run(
sequence.replace_instruction(kv.first, fusion);

// TODO(vuson): handle control dependencies
TF_RETURN_IF_ERROR(parent->ReplaceInstruction(kv.first, fusion));
TF_RETURN_IF_ERROR(parent->ReplaceInstruction(
const_cast<HloInstruction*>(replacement_map[kv.first]), fusion));
}

TF_RETURN_IF_ERROR(module->schedule().Update());

return true;
};

TF_ASSIGN_OR_RETURN(bool static_sliced, process_slices(false));
TF_ASSIGN_OR_RETURN(bool dynamic_sliced, process_slices(true));
return static_sliced || dynamic_sliced;
// TODO(vuson): unify dynamic_address_computation and address_computation
TF_ASSIGN_OR_RETURN(bool processed_pattern_with_static_slices,
process_slices(false));
TF_ASSIGN_OR_RETURN(bool processed_pattern_with_dynamic_slices,
process_slices(true));
return processed_pattern_with_static_slices ||
processed_pattern_with_dynamic_slices;
}

} // namespace gpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1240,4 +1240,147 @@ TEST_F(AddressComputationFusionRewriterTest, DynamicSimpleGemmNotRoot) {
});
}

TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemm) {
const char* hlo = R"(
HloModule test, is_scheduled=true
ENTRY main.9 {
p0 = f16[1,8,8]{2,1,0} parameter(0)
p1 = f16[1,8,8]{2,1,0} parameter(1)
p2 = f16[4,8,8]{2,1,0} parameter(2)
c1_s32 = s32[] constant(1)
c0_s32 = s32[] constant(0)
bitcast.41 = f16[8,8]{1,0} bitcast(p0)
bitcast.42 = f16[8,8]{1,0} bitcast(p1)
custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
custom_call_target="__cublas$gemm",
backend_config={"gemm_backend_config":{
"alpha_real":1,
"beta":0,
"dot_dimension_numbers":{
"lhs_contracting_dimensions":["1"],
"rhs_contracting_dimensions":["0"],
"lhs_batch_dimensions":[],
"rhs_batch_dimensions":[]
},
"alpha_imag":0,
"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
"epilogue":"DEFAULT",
"lhs_stride":"64",
"rhs_stride":"64",
"grad_x":false,
"grad_y":false
}}
bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
ROOT dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
}
)";

const char* expected = R"(
; CHECK-DAG: [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(3)
; CHECK-DAG: [[P1:%[^ ]+]] = f16[8,8]{1,0} parameter(4)
; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(0)
; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
; CHECK-DAG: [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]]),
; CHECK-DAG: custom_call_target="__cublas$gemm"
; CHECK-DAG: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[CC]])
; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
; CHECK: }
; CHECK: ENTRY %main{{.*}} {
; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[4,8,8]{2,1,0} fusion
; CHECK: kind=kCustom, calls=%address-computation,
; CHECK: backend_config={
; CHECK: "kind":"__custom_fusion",
; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"}
; CHECK: }
; CHECK: }
)";

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM),
expected, [](HloModule* module) {
EXPECT_TRUE(module->has_schedule());
TF_CHECK_OK(module->schedule().Verify());
});
}

TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemmNotRoot) {
const char* hlo = R"(
HloModule test, is_scheduled=true
ENTRY main.9 {
p0 = f16[2,8,8]{2,1,0} parameter(0)
p1 = f16[2,8,8]{2,1,0} parameter(1)
p2 = f16[4,8,8]{2,1,0} parameter(2)
c1_s32 = s32[] constant(1)
c0_s32 = s32[] constant(0)
slice.13 = f16[1,8,8]{2,1,0} dynamic-slice(p0, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
bitcast.41 = f16[8,8]{1,0} bitcast(slice.13)
slice.14 = f16[1,8,8]{2,1,0} dynamic-slice(p1, c1_s32, c0_s32, c0_s32), dynamic_slice_sizes={1,8,8}
bitcast.42 = f16[8,8]{1,0} bitcast(slice.14)
custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42),
custom_call_target="__cublas$gemm",
backend_config={"gemm_backend_config":{
"alpha_real":1,
"beta":0,
"dot_dimension_numbers":{
"lhs_contracting_dimensions":["1"],
"rhs_contracting_dimensions":["0"],
"lhs_batch_dimensions":[],
"rhs_batch_dimensions":[]
},
"alpha_imag":0,
"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},
"epilogue":"DEFAULT",
"lhs_stride":"64",
"rhs_stride":"64",
"grad_x":false,
"grad_y":false
}}
bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1)
dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, c1_s32, c0_s32, c0_s32)
ROOT res = f16[4,8,8]{2,1,0} log(dus)
}
)";

const char* expected = R"(
; CHECK: address-computation {{.*}} {
; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(3)
; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(4)
; CHECK-DAG: [[P2:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(0)
; CHECK-DAG: [[C1:%[^ ]+]] = s32[] parameter(1)
; CHECK-DAG: [[C0:%[^ ]+]] = s32[] parameter(2)
; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P0]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]])
; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} dynamic-slice([[P1]], [[C1]], [[C0]], [[C0]]), dynamic_slice_sizes={1,8,8}
; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]])
; CHECK-DAG: [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[B0]], [[B1]]),
; CHECK-DAG: custom_call_target="__cublas$gemm"
; CHECK-DAG: [[BC:%[^ ]+]] = f16[1,8,8]{2,1,0} bitcast([[CC]])
; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} dynamic-update-slice([[P2]], [[BC]], [[C1]], [[C0]], [[C0]])
; CHECK: }
; CHECK: ENTRY %main{{.*}} {
; CHECK: [[FUSION:%[^ ]+]] = f16[4,8,8]{2,1,0} fusion
; CHECK: kind=kCustom, calls=%address-computation,
; CHECK: backend_config={
; CHECK: "kind":"__custom_fusion",
; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"}
; CHECK: }
; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} log([[FUSION]])
; CHECK: }
)";

auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo();
RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM),
expected, [](HloModule* module) {
EXPECT_TRUE(module->has_schedule());
TF_CHECK_OK(module->schedule().Verify());
});
}

} // namespace xla::gpu

0 comments on commit c76b77e

Please sign in to comment.