Skip to content

Commit

Permalink
[mlir][sparse] Adding new STEA::{with,without}DimSlices factories
Browse files Browse the repository at this point in the history
(These factories are used in downstream code, despite not being used within the MLIR codebase.)

Depends On D151513

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D151518
  • Loading branch information
wrengr committed May 30, 2023
1 parent 540d5e0 commit af2bec7
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,14 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// reset to the default, and all other fields inherited from `this`.
SparseTensorEncodingAttr withoutBitWidths() const;

/// Constructs a new encoding with the given dimSlices, and all
/// other fields inherited from `this`.
SparseTensorEncodingAttr withDimSlices(ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;

/// Constructs a new encoding with the dimSlices reset to the default,
/// and all other fields inherited from `this`.
SparseTensorEncodingAttr withoutDimSlices() const;

//
// Rank methods.
//
Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ class SparseTensorType {
return withEncoding(enc.withoutBitWidths());
}

SparseTensorType
withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
return withEncoding(enc.withDimSlices(dimSlices));
}

SparseTensorType withoutDimSlices() const {
return withEncoding(enc.withoutDimSlices());
}

//
// Other methods.
//
Expand Down
11 changes: 11 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,17 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
return withBitWidths(0, 0);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
getDimToLvl(), getPosWidth(),
getCrdWidth(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}

bool SparseTensorEncodingAttr::isAllDense() const {
return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1138,10 +1138,7 @@ class SparseExtractSliceConverter
// TODO: We should check these in ExtractSliceOp::verify.
if (!srcEnc || !dstEnc || !dstEnc.isSlice())
return failure();
assert(srcEnc.getLvlTypes() == dstEnc.getLvlTypes());
assert(srcEnc.getDimToLvl() == dstEnc.getDimToLvl());
assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth());
assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());

SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
Expand Down

0 comments on commit af2bec7

Please sign in to comment.