Skip to content

Commit

Permalink
[mlir][SliceAnalysis] Add an options object to forward and backward s…
Browse files Browse the repository at this point in the history
…lice.

Add an options object to allow control of the slice computation (for
both forward and backward slice). This makes the ABI stable, and also
allows avoiding an assert that makes the slice analysis unusable for
operations with multiple blocks.

Reviewed By: hanchung, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D151520
  • Loading branch information
Mahesh Ravishankar committed Jun 8, 2023
1 parent c1059dc commit 641b12e
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 66 deletions.
46 changes: 28 additions & 18 deletions mlir/include/mlir/Analysis/SliceAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,27 @@ class BlockArgument;
class Operation;
class Value;

/// Type of the condition to limit the propagation of transitive use-defs.
/// This can be used in particular to limit the propagation to a given Scope or
/// to avoid passing through certain types of operation in a configurable
/// manner.
using TransitiveFilter = llvm::function_ref<bool(Operation *)>;
struct SliceOptions {
/// Type of the condition to limit the propagation of transitive use-defs.
/// This can be used in particular to limit the propagation to a given Scope
/// or to avoid passing through certain types of operation in a configurable
/// manner.
using TransitiveFilter = std::function<bool(Operation *)>;
TransitiveFilter filter = nullptr;

/// Include the top level op in the slice.
bool inclusive = false;
};

struct BackwardSliceOptions : public SliceOptions {
/// When omitBlockArguments is true, the backward slice computation omits
/// traversing any block arguments. When omitBlockArguments is false, the
/// backward slice computation traverses block arguments and asserts that the
/// parent op has a single region with a single block.
bool omitBlockArguments = false;
};

using ForwardSliceOptions = SliceOptions;

/// Fills `forwardSlice` with the computed forward slice (i.e. all
/// the transitive uses of op), **without** including that operation.
Expand Down Expand Up @@ -69,14 +85,12 @@ using TransitiveFilter = llvm::function_ref<bool(Operation *)>;
/// {4, 3, 6, 2, 1, 5, 8, 7, 9}
///
void getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
TransitiveFilter filter = nullptr /* pass-through*/,
bool inclusive = false);
ForwardSliceOptions options = {});

/// Value-rooted version of `getForwardSlice`. Return the union of all forward
/// slices for the uses of the value `root`.
void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
TransitiveFilter filter = nullptr /* pass-through*/,
bool inclusive = false);
ForwardSliceOptions options = {});

/// Fills `backwardSlice` with the computed backward slice (i.e.
/// all the transitive defs of op), **without** including that operation.
Expand Down Expand Up @@ -113,14 +127,12 @@ void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
/// {1, 2, 5, 3, 4, 6}
///
void getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice,
TransitiveFilter filter = nullptr /* pass-through*/,
bool inclusive = false);
BackwardSliceOptions options = {});

/// Value-rooted version of `getBackwardSlice`. Return the union of all backward
/// slices for the op defining or owning the value `root`.
void getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
TransitiveFilter filter = nullptr /* pass-through*/,
bool inclusive = false);
BackwardSliceOptions options = {});

/// Iteratively computes backward slices and forward slices until
/// a fixed point is reached. Returns an `SetVector<Operation *>` which
Expand Down Expand Up @@ -199,11 +211,9 @@ void getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
/// and keep things ordered but this is still hand-wavy and not worth the
/// trouble for now: punt to a simple worklist-based solution.
///
SetVector<Operation *>
getSlice(Operation *op,
TransitiveFilter backwardFilter = nullptr /* pass-through*/,
TransitiveFilter forwardFilter = nullptr /* pass-through*/,
bool inclusive = false);
SetVector<Operation *> getSlice(Operation *op,
BackwardSliceOptions backwardSliceOptions = {},
ForwardSliceOptions forwardSliceOptions = {});

/// Multi-root DAG topological sort.
/// Performs a topological sort of the Operation in the `toSort` SetVector.
Expand Down
55 changes: 30 additions & 25 deletions mlir/lib/Analysis/SliceAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@

using namespace mlir;

static void getForwardSliceImpl(Operation *op,
SetVector<Operation *> *forwardSlice,
TransitiveFilter filter) {
static void
getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
SliceOptions::TransitiveFilter filter = nullptr) {
if (!op)
return;

Expand All @@ -51,9 +51,9 @@ static void getForwardSliceImpl(Operation *op,
}

void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
TransitiveFilter filter, bool inclusive) {
getForwardSliceImpl(op, forwardSlice, filter);
if (!inclusive) {
ForwardSliceOptions options) {
getForwardSliceImpl(op, forwardSlice, options.filter);
if (!options.inclusive) {
// Don't insert the top level operation, we just queried on it and don't
// want it in the results.
forwardSlice->remove(op);
Expand All @@ -67,9 +67,9 @@ void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
}

void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
TransitiveFilter filter, bool inclusive) {
SliceOptions options) {
for (Operation *user : root.getUsers())
getForwardSliceImpl(user, forwardSlice, filter);
getForwardSliceImpl(user, forwardSlice, options.filter);

// Reverse to get back the actual topological order.
// std::reverse does not work out of the box on SetVector and I want an
Expand All @@ -80,22 +80,25 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,

static void getBackwardSliceImpl(Operation *op,
SetVector<Operation *> *backwardSlice,
TransitiveFilter filter) {
BackwardSliceOptions options) {
if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
return;

// Evaluate whether we should keep this def.
// This is useful in particular to implement scoping; i.e. return the
// transitive backwardSlice in the current scope.
if (filter && !filter(op))
if (options.filter && !options.filter(op))
return;

for (const auto &en : llvm::enumerate(op->getOperands())) {
auto operand = en.value();
if (auto *definingOp = operand.getDefiningOp()) {
if (backwardSlice->count(definingOp) == 0)
getBackwardSliceImpl(definingOp, backwardSlice, filter);
getBackwardSliceImpl(definingOp, backwardSlice, options);
} else if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
if (options.omitBlockArguments)
continue;

Block *block = blockArg.getOwner();
Operation *parentOp = block->getParentOp();
// TODO: determine whether we want to recurse backward into the other
Expand All @@ -104,7 +107,7 @@ static void getBackwardSliceImpl(Operation *op,
if (parentOp && backwardSlice->count(parentOp) == 0) {
assert(parentOp->getNumRegions() == 1 &&
parentOp->getRegion(0).getBlocks().size() == 1);
getBackwardSliceImpl(parentOp, backwardSlice, filter);
getBackwardSliceImpl(parentOp, backwardSlice, options);
}
} else {
llvm_unreachable("No definingOp and not a block argument.");
Expand All @@ -116,30 +119,29 @@ static void getBackwardSliceImpl(Operation *op,

void mlir::getBackwardSlice(Operation *op,
SetVector<Operation *> *backwardSlice,
TransitiveFilter filter, bool inclusive) {
getBackwardSliceImpl(op, backwardSlice, filter);
BackwardSliceOptions options) {
getBackwardSliceImpl(op, backwardSlice, options);

if (!inclusive) {
if (!options.inclusive) {
// Don't insert the top level operation, we just queried on it and don't
// want it in the results.
backwardSlice->remove(op);
}
}

void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
TransitiveFilter filter, bool inclusive) {
BackwardSliceOptions options) {
if (Operation *definingOp = root.getDefiningOp()) {
getBackwardSlice(definingOp, backwardSlice, filter, inclusive);
getBackwardSlice(definingOp, backwardSlice, options);
return;
}
Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
getBackwardSlice(bbAargOwner, backwardSlice, filter, inclusive);
getBackwardSlice(bbAargOwner, backwardSlice, options);
}

SetVector<Operation *> mlir::getSlice(Operation *op,
TransitiveFilter backwardFilter,
TransitiveFilter forwardFilter,
bool inclusive) {
BackwardSliceOptions backwardSliceOptions,
ForwardSliceOptions forwardSliceOptions) {
SetVector<Operation *> slice;
slice.insert(op);

Expand All @@ -150,12 +152,12 @@ SetVector<Operation *> mlir::getSlice(Operation *op,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
getBackwardSlice(currentOp, &backwardSlice, backwardFilter, inclusive);
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
slice.insert(backwardSlice.begin(), backwardSlice.end());

// Compute and insert the forwardSlice starting from currentOp.
forwardSlice.clear();
getForwardSlice(currentOp, &forwardSlice, forwardFilter, inclusive);
getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
slice.insert(forwardSlice.begin(), forwardSlice.end());
++currentIndex;
}
Expand Down Expand Up @@ -225,8 +227,11 @@ static bool dependsOnCarriedVals(Value value,
Operation *ancestorOp) {
// Compute the backward slice of the value.
SetVector<Operation *> slice;
getBackwardSlice(value, &slice,
[&](Operation *op) { return !ancestorOp->isAncestor(op); });
BackwardSliceOptions sliceOptions;
sliceOptions.filter = [&](Operation *op) {
return !ancestorOp->isAncestor(op);
};
getBackwardSlice(value, &slice, sliceOptions);

// Check that none of the operands of the operations in the backward slice are
// loop iteration arguments, and neither is the value itself.
Expand Down
22 changes: 14 additions & 8 deletions mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,9 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
/// Return an unsorted slice handling scf.for region differently than
/// `getSlice`. In scf.for we only want to include as part of the slice elements
/// that are part of the use/def chain.
static SetVector<Operation *> getSliceContract(Operation *op,
TransitiveFilter backwardFilter,
TransitiveFilter forwardFilter) {
static SetVector<Operation *>
getSliceContract(Operation *op, BackwardSliceOptions backwardSliceOptions,
ForwardSliceOptions forwardSliceOptions) {
SetVector<Operation *> slice;
slice.insert(op);
unsigned currentIndex = 0;
Expand All @@ -315,7 +315,7 @@ static SetVector<Operation *> getSliceContract(Operation *op,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
slice.insert(backwardSlice.begin(), backwardSlice.end());

// Compute and insert the forwardSlice starting from currentOp.
Expand All @@ -326,11 +326,11 @@ static SetVector<Operation *> getSliceContract(Operation *op,
// converted to matrix type.
if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
for (Value forOpResult : forOp.getResults())
getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
for (BlockArgument &arg : forOp.getRegionIterArgs())
getForwardSlice(arg, &forwardSlice, forwardFilter);
getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
} else {
getForwardSlice(currentOp, &forwardSlice, forwardFilter);
getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
}
slice.insert(forwardSlice.begin(), forwardSlice.end());
++currentIndex;
Expand All @@ -346,16 +346,22 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
return llvm::any_of(op->getResultTypes(),
[](Type t) { return isa<VectorType>(t); });
};
BackwardSliceOptions backwardSliceOptions;
backwardSliceOptions.filter = hasVectorDest;

auto hasVectorSrc = [](Operation *op) {
return llvm::any_of(op->getOperandTypes(),
[](Type t) { return isa<VectorType>(t); });
};
ForwardSliceOptions forwardSliceOptions;
forwardSliceOptions.filter = hasVectorSrc;

SetVector<Operation *> opToConvert;
op->walk([&](vector::ContractionOp contract) {
if (opToConvert.contains(contract.getOperation()))
return;
SetVector<Operation *> dependentOps =
getSliceContract(contract, hasVectorDest, hasVectorSrc);
getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
// If any instruction cannot use MMA matrix type drop the whole
// chain. MMA matrix are stored in an opaque type so they cannot be used
// by all operations.
Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,22 @@ static void computeBackwardSlice(tensor::PadOp padOp,
scf::ForOp outermostEnclosingForOp,
SetVector<Operation *> &backwardSlice) {
DominanceInfo domInfo(outermostEnclosingForOp);
auto filter = [&](Operation *op) {
BackwardSliceOptions sliceOptions;
sliceOptions.filter = [&](Operation *op) {
return domInfo.dominates(outermostEnclosingForOp, op) &&
!padOp->isProperAncestor(op);
};
sliceOptions.inclusive = true;

// First, add the ops required to compute the region to the backwardSlice.
SetVector<Value> valuesDefinedAbove;
getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
valuesDefinedAbove);
for (Value v : valuesDefinedAbove) {
getBackwardSlice(v, &backwardSlice, filter, /*inclusive=*/true);
getBackwardSlice(v, &backwardSlice, sliceOptions);
}
// Then, add the backward slice from padOp itself.
getBackwardSlice(padOp.getOperation(), &backwardSlice, filter,
/*inclusive=*/true);
getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
}

//===----------------------------------------------------------------------===//
Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,11 @@ void mlir::collapseParallelLoops(
// Return failure when any op fails to hoist.
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
SetVector<Operation *> forwardSlice;
getForwardSlice(
outer.getInductionVar(), &forwardSlice,
[&inner](Operation *op) { return op != inner.getOperation(); });
ForwardSliceOptions options;
options.filter = [&inner](Operation *op) {
return op != inner.getOperation();
};
getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
LogicalResult status = success();
SmallVector<Operation *, 8> toHoist;
for (auto &op : outer.getBody()->without_terminator()) {
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/IR/slice_multiple_blocks.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(slice-analysis-test{omit-block-arguments=true})" %s | FileCheck %s

func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
%a = memref.alloc(%arg0, %arg2) : memref<?x?xf32>
%b = memref.alloc(%arg2, %arg1) : memref<?x?xf32>
cf.br ^bb1
^bb1() :
%c = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
%d = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
outs(%c : memref<?x?xf32>)
linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
outs(%d : memref<?x?xf32>)
memref.dealloc %c : memref<?x?xf32>
memref.dealloc %b : memref<?x?xf32>
memref.dealloc %a : memref<?x?xf32>
memref.dealloc %d : memref<?x?xf32>
return
}
// CHECK-LABEL: func @slicing_linalg_op__backward_slice__0
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[A:.+]] = memref.alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
// CHECK-DAG: %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
// CHECK-DAG: %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
// CHECK: return

// CHECK-LABEL: func @slicing_linalg_op__backward_slice__1
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[A:.+]] = memref.alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
// CHECK-DAG: %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
// CHECK-DAG: %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
// CHECK: return
Loading

0 comments on commit 641b12e

Please sign in to comment.