Skip to content

Commit

Permalink
[mlir][sparse] Updating Merger::foreachTensorLoopId to take `LatPointId`
Browse files Browse the repository at this point in the history

Since all callsites of `foreachTensorLoopId` would simply look up the `LatPointId` to extract its `BitVector`, it's cleaner to let the `Merger` handle that instead.  This seems to better capture the intent of the `foreachTensorLoopId` method, and improves decoupling (since it removes a place that leaks the implementation detail that we use `BitVector`).

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D146082
  • Loading branch information
wrengr committed Mar 15, 2023
1 parent 73a0195 commit b60de1d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,11 +437,11 @@ class Merger {
/// for each `TensorLoopId` and passing it the corresponding tensor
/// identifier, level, and level-type.
void
foreachTensorLoopId(const BitVector &bits,
foreachTensorLoopId(LatPointId p,
function_ref<void(TensorLoopId, TensorId,
std::optional<Level>, DimLevelType)>
callback) const {
for (const TensorLoopId b : bits.set_bits())
for (const TensorLoopId b : latPoints[p].bits.set_bits())
callback(b, tensor(b), getLvl(b), getDimLevelType(b));
}

Expand Down
29 changes: 14 additions & 15 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1273,18 +1273,18 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,

SmallVector<TensorId> tids;
SmallVector<Level> lvls;
env.merger().foreachTensorLoopId(
env.lat(l0).bits, [&](TensorLoopId b, TensorId tid,
std::optional<Level> lvl, DimLevelType dlt) {
assert(env.merger().loop(b) == idx);
if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
needsUniv = true;
} else {
// sparse/singleton levels.
tids.push_back(tid);
lvls.push_back(*lvl);
}
});
env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
std::optional<Level> lvl,
DimLevelType dlt) {
assert(env.merger().loop(b) == idx);
if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
needsUniv = true;
} else {
// sparse/singleton levels.
tids.push_back(tid);
lvls.push_back(*lvl);
}
});

env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, lvls);

Expand Down Expand Up @@ -1342,16 +1342,15 @@ static bool translateBitsToTidLvlPairs(
CodegenEnv &env, LatPointId li, LoopId ldx, SmallVectorImpl<TensorId> &tids,
SmallVectorImpl<Level> &lvls, SmallVectorImpl<TensorId> &affineTids,
SmallVectorImpl<Level> &affineLvls, SmallVectorImpl<AffineExpr> &exps) {
const BitVector &all = env.lat(li).bits;
const BitVector &simple = env.lat(li).simple;
const TensorId outTid = env.merger().getOutTensorID();
const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);

unsigned numloopCond = 0;
bool hasNonUnique = false;
env.merger().foreachTensorLoopId(
all, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
DimLevelType dlt) {
li, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
DimLevelType dlt) {
if (simple.test(b)) {
if (isUndefDLT(dlt)) {
// An undefined dlt in the lattices, we probably mean to
Expand Down

0 comments on commit b60de1d

Please sign in to comment.