[LLVMGPU] allow multiple m and n dims in contraction distribution (iree-org#16943)

This adjusts the layout generation logic to allow distribution of
contractions with multiple m and n dimensions by greedily assigning the
subgroup/tile counts of the mma_schedule to the outer dims. The
innermost m/n dimensions are still required to be divisible by the
intrinsic shape, and only a single k dimension is supported.
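
A rough sketch of that greedy assignment as described in the message
(hypothetical helper name and plain integer vectors; the actual pass
works on the mma_schedule attributes, not raw vectors):

    #include <cassert>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Sketch only: fold a schedule's subgroup (or tile) count into the m/n
    // dims from outermost inward; the innermost dim takes whatever remains
    // and must still be divisible by the intrinsic's M/N size.
    std::vector<int64_t> distributeToOuterDims(
        const std::vector<int64_t> &dimSizes, int64_t count) {
      assert(!dimSizes.empty() && count >= 1);
      std::vector<int64_t> counts(dimSizes.size(), 1);
      for (size_t i = 0; i + 1 < dimSizes.size() && count > 1; ++i) {
        int64_t used = std::gcd(dimSizes[i], count); // what this dim absorbs
        counts[i] = used;
        count /= used;
      }
      counts.back() = count; // remainder lands on the innermost dim
      return counts;
    }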

This also decouples the ordering logic of the batch/subgroup
distribution from the lane distribution for the intrinsics. The logic
currently assumes an intrinsic specifies only three important sizes: an
M, N, and K size. Supporting distributed batches would require adding a
fourth dim type.
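
For illustration only, the three sizes referred to above as a
hypothetical enum (not part of this patch), with the fourth kind that
distributed batches would require:

    // Dim kinds an MMA intrinsic currently specifies; distributing batches
    // would mean adding a fourth kind alongside M, N, and K.
    enum class IntrinsicDimKind { M, N, K /* , Batch */ };
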
qedawkins authored Apr 19, 2024
1 parent 9f97989 commit e87ff17
Showing 12 changed files with 632 additions and 207 deletions.
@@ -97,8 +97,8 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
     LLVM_DEBUG(llvm::dbgs() << "init tile: " << finalTile << "\n");

     // Offsets into the LHS/RHS batches.
-    SmallVector<int64_t, 2> lhsBatchOffsets(rank, 0);
-    SmallVector<int64_t, 2> rhsBatchOffsets(rank, 0);
+    SmallVector<int64_t, 2> lhsBatchOffsets(lhsLayout.getRank(), 0);
+    SmallVector<int64_t, 2> rhsBatchOffsets(rhsLayout.getRank(), 0);

     // Offsets into the result batches.
     ArrayRef<int64_t> resultBatches = resultLayout.getBatchesPerSubgroup();
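
The old code sized both offset vectors from a single shared `rank`. A
note on why that no longer works (illustrative shapes, not taken from
this patch): with multiple m dims the operands can disagree on rank,

    // acc : vector<M0 x M1 x N>  -> rank 3
    // lhs : vector<M0 x M1 x K>  -> rank 3
    // rhs : vector<K x N>        -> rank 2
    // so each batch-offset vector is now sized from its own operand's
    // NestedLayoutAttr rather than from one shared rank.
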
@@ -183,7 +183,7 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
   std::optional<int64_t> getKBatchSize(const VectorContractOpInfo &opDetail,
                                        NestedLayoutAttr lhsLayout,
                                        NestedLayoutAttr rhsLayout) const {
-    auto [lhsK, rhsK] = *opDetail.getOperandKIndex();
+    auto [lhsK, rhsK] = opDetail.getOperandKIndex();
     int64_t lhsKBatch = lhsLayout.getBatchesPerSubgroup()[lhsK];
     int64_t rhsKBatch = rhsLayout.getBatchesPerSubgroup()[rhsK];
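
Dropping the dereference here suggests `getOperandKIndex` now returns
the index pair directly rather than an `std::optional`; a hedged guess
at the accompanying signature change (not shown in this excerpt):

    // before: std::optional<std::pair<int, int>> getOperandKIndex() const;
    // after : std::pair<int, int> getOperandKIndex() const;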

@@ -201,15 +201,21 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
                               SmallVector<int64_t, 2> &rhsOffsets,
                               NestedLayoutAttr lhsLayout,
                               NestedLayoutAttr rhsLayout) const {
-    auto [lhsM, rhsN] = *opDetail.getOperandMNIndex();
-    auto [lhsK, rhsK] = *opDetail.getOperandKIndex();
-    auto [resultM, resultN] = *opDetail.getResultMNIndex();
-
     // resultOffsets contains batch indices into the C/D vector. It is a 2-D
     // index for both M and N. We need to split out for M and N, and add index
     // for K.
-    lhsOffsets[lhsM] = resultOffsets[resultM];
+    for (auto [lhsM, resultM] :
+         llvm::zip_equal(opDetail.lhsMDims, opDetail.outMDims)) {
+      lhsOffsets[lhsM] = resultOffsets[resultM];
+    }
+    for (auto [rhsN, resultN] :
+         llvm::zip_equal(opDetail.rhsNDims, opDetail.outNDims)) {
+      rhsOffsets[rhsN] = resultOffsets[resultN];
+    }
+
+    auto [lhsK, rhsK] = opDetail.getOperandKIndex();
     lhsOffsets[lhsK] = kOffset;
-    rhsOffsets[rhsN] = resultOffsets[resultN];
     rhsOffsets[rhsK] = kOffset;

     // Now apply permutation on LHS/RHS according to their batch order.
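
`llvm::zip_equal` (llvm/ADT/STLExtras.h) walks two ranges in lockstep
and asserts that they have the same length, so each operand m/n dim is
paired with exactly one result dim. A standalone sketch of the idiom,
with hypothetical names and simplified types:

    #include <cstdint>
    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/STLExtras.h"

    // Copy batch offsets from the result dims to the matching operand dims.
    // zip_equal asserts if the two dim lists differ in length.
    void copyOffsets(llvm::ArrayRef<int64_t> operandDims,
                     llvm::ArrayRef<int64_t> resultDims,
                     llvm::ArrayRef<int64_t> resultOffsets,
                     llvm::MutableArrayRef<int64_t> operandOffsets) {
      for (auto [opDim, resDim] : llvm::zip_equal(operandDims, resultDims))
        operandOffsets[opDim] = resultOffsets[resDim];
    }
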
@@ -243,13 +243,13 @@ builtin.module attributes { transform.with_named_sequence } {
   %c0 = arith.constant 0 : index
   %cst_0 = arith.constant 0.0 : f16
   %cst0_1 = arith.constant dense<0.0> : vector<16xf16>
-  // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [false, true], thread_basis = [4, 16], thread_active_ids= [false, true]}}
+  // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [false, true], thread_basis = [4, 16], thread_active_ids = [false, true]}}
   %root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true], "__vector_layout_test_anchor_result_0" = #layout} : memref<16x16xf16>, vector<16x16xf16>
   // expected-remark @above {{thread_basis = [4, 16]}}
   %root_red = vector.multi_reduction<add>, %root, %cst0_1 [0] : vector<16x16xf16> to vector<16xf16>
-  // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [false, true], thread_basis = [4, 16], thread_active_ids= [false, true]}}
+  // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [false, true], thread_basis = [4, 16], thread_active_ids = [false, true]}}
   %c = arith.mulf %root_red, %a : vector<16xf16>
-  // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [false, true], thread_basis = [4, 16], thread_active_ids= [false, true]}}
+  // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [false, true], thread_basis = [4, 16], thread_active_ids = [false, true]}}
   func.return %c : vector<16xf16>
 }

@@ -281,13 +281,13 @@ builtin.module attributes { transform.with_named_sequence } {
   %c0 = arith.constant 0 : index
   %cst_0 = arith.constant 0.0 : f16
   %cst0_1 = arith.constant dense<0.0> : vector<16xf16>
-  // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [true, false], thread_basis = [4, 16], thread_active_ids= [true, false]}}
+  // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [true, false], thread_basis = [4, 16], thread_active_ids = [true, false]}}
   %root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true], "__vector_layout_test_anchor_result_0" = #layout} : memref<16x16xf16>, vector<16x16xf16>
   // expected-remark @above {{thread_basis = [4, 16]}}
   %root_red = vector.multi_reduction<add>, %root, %cst0_1 [1] : vector<16x16xf16> to vector<16xf16>
-  // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [true, false], thread_basis = [4, 16], thread_active_ids= [true, false]}}
+  // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [true, false], thread_basis = [4, 16], thread_active_ids = [true, false]}}
   %c = arith.mulf %root_red, %a : vector<16xf16>
-  // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [true, false], thread_basis = [4, 16], thread_active_ids= [true, false]}}
+  // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids = [true, false], thread_basis = [4, 16], thread_active_ids = [true, false]}}
   func.return %c : vector<16xf16>
 }

(diffs for the remaining changed files not shown)
