Skip to content

Commit

Permalink
[mlir][sparse] Using SparseTensorType in SparsePackOpConverter
Browse files Browse the repository at this point in the history
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D147465
  • Loading branch information
wrengr committed Apr 3, 2023
1 parent b0ba8fe commit 34c9c59
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1235,29 +1235,28 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
matchAndRewrite(PackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

const auto rtp = getRankedTensorType(op.getResult());
assert(isUniqueCOOType(rtp));
const auto stt = getSparseTensorType(op.getResult());
assert(isUniqueCOOType(stt));

SmallVector<Value> fields;
Location loc = op.getLoc();

foreachFieldAndTypeInSparseTensor(
rtp,
[&rewriter, &fields, &op, rtp,
stt,
[&rewriter, &fields, &op, stt,
loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
Level /*lvl*/, DimLevelType /*dlt*/) -> bool {
assert(fields.size() == fIdx);
auto enc = getSparseTensorEncoding(rtp);
Value field;
switch (fKind) {
case SparseTensorFieldKind::StorageSpec:
field = SparseTensorSpecifier::getInitValue(rewriter, loc, rtp);
field = SparseTensorSpecifier::getInitValue(rewriter, loc, stt);
break;
case SparseTensorFieldKind::PosMemRef: {
// TACO-style COO starts with a PosBuffer
// By creating a constant value for it, we avoid the complexity of
// memory management.
const auto posTp = enc.getPosType();
const auto posTp = stt.getPosType();
auto tensorType = RankedTensorType::get({2}, posTp);
auto memrefType = MemRefType::get(tensorType.getShape(),
tensorType.getElementType());
Expand Down Expand Up @@ -1306,13 +1305,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
return true;
});

MutSparseTensorDescriptor desc(rtp, fields);
MutSparseTensorDescriptor desc(stt, fields);
auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getValues(), 0);
// FIXME: should use `SparseTensorType::getLvlRank` in lieu of
// `RankedTensorType::getRank`, because the latter introduces dim/lvl
// ambiguity.
for (Level lvl = 0, lvlRank = rtp.getRank(); lvl < lvlRank; lvl++) {
const auto sh = rtp.getShape()[lvl];
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
// FIXME: dim/lvl confusion!
const auto sh = stt.getDimShape()[lvl];
assert(!ShapedType::isDynamic(sh));
desc.setLvlSize(rewriter, loc, lvl, constantIndex(rewriter, loc, sh));
if (lvl == 0)
Expand Down

0 comments on commit 34c9c59

Please sign in to comment.