[MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store (llvm#87637)

This commit relaxes Mem2Reg's type equality requirement for the LLVM
dialect's load and store operations. For now, loads are only promoted if
the reaching definition can be cast into a value of the target type.
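
For example (mirroring the @load_int_from_float test added below), a load of
i32 from a slot whose element type is f32 can now be promoted, since the
reaching f32 definition can be bitcast to i32. A simplified sketch, with
alignment attributes omitted and SSA names chosen for illustration:

  %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
  %2 = llvm.load %1 : !llvm.ptr -> i32

becomes, given that no store reaches the load and the reaching definition is
therefore undef:

  %undef = llvm.mlir.undef : f32
  %2 = llvm.bitcast %undef : f32 to i32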

For stores, the same cast-compatibility check is applied, and the stored
value is cast to the element type of the memory slot. This is necessary to
satisfy the assumptions of the general mem2reg pass, as it creates block
arguments with the type of the memory slot.
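
To see why (this follows the @stores_with_different_types test added below),
consider an i64 slot that is stored to with an i64 value on one path and an
f64 value on another. Promotion creates a block argument of the slot type
i64 at the merge point, so the f64 store operand is bitcast to i64 before
the branch, and the loaded value is bitcast back to f64. A sketch of the
promoted merge, with illustrative SSA names:

  ^bb2:
    %c = llvm.bitcast %arg1 : f64 to i64
    llvm.br ^bb3(%c : i64)
  ^bb3(%v: i64):
    %r = llvm.bitcast %v : i64 to f64
    llvm.return %r : f64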

This relands llvm#87504
Dinistro authored Apr 5, 2024
1 parent 1c683eb commit 974f1ee
Showing 2 changed files with 284 additions and 4 deletions.
41 changes: 37 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -122,8 +122,37 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
   return getAddr() == slot.ptr;
 }
 
+/// Checks that two types are the same or can be cast into one another.
+static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
+  return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
+                        !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
+                        layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
+}
+
+/// Constructs operations that convert `inputValue` into a new value of type
+/// `targetType`. Assumes that this conversion is possible.
+static Value createConversionSequence(RewriterBase &rewriter, Location loc,
+                                      Value inputValue, Type targetType) {
+  if (inputValue.getType() == targetType)
+    return inputValue;
+
+  if (!isa<LLVM::LLVMPointerType>(targetType) &&
+      !isa<LLVM::LLVMPointerType>(inputValue.getType()))
+    return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
+
+  if (!isa<LLVM::LLVMPointerType>(targetType))
+    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
+
+  if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
+    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
+
+  return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+                                                      inputValue);
+}
+
 Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
-  return getValue();
+  return createConversionSequence(rewriter, getLoc(), getValue(),
+                                  slot.elemType);
 }
 
 bool LLVM::LoadOp::canUsesBeRemoved(
@@ -138,15 +167,18 @@ bool LLVM::LoadOp::canUsesBeRemoved(
   // be removed (provided it loads the exact stored value and is not
   // volatile).
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         getResult().getType() == slot.elemType && !getVolatile_();
+         areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
+         !getVolatile_();
 }
 
 DeletionKind LLVM::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
     RewriterBase &rewriter, Value reachingDefinition) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
+  Value newResult = createConversionSequence(
+      rewriter, getLoc(), reachingDefinition, getResult().getType());
+  rewriter.replaceAllUsesWith(getResult(), newResult);
   return DeletionKind::Delete;
 }
 
@@ -161,7 +193,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
   // fine, provided we are currently promoting its target value. Don't allow a
   // store OF the slot pointer, only INTO the slot pointer.
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         getValue() != slot.ptr && getValue().getType() == slot.elemType &&
+         getValue() != slot.ptr &&
+         areCastCompatible(dataLayout, slot.elemType, getValue().getType()) &&
          !getVolatile_();
 }
 
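For reference, the kind of cast createConversionSequence emits for each
combination of source and target types, as exercised by the tests below
(SSA names here are purely illustrative):

  %a = llvm.bitcast %f : f32 to i32                          // neither type is a pointer
  %b = llvm.ptrtoint %p : !llvm.ptr to i64                   // only the source is a pointer
  %c = llvm.inttoptr %i : i64 to !llvm.ptr                   // only the target is a pointer
  %d = llvm.addrspacecast %q : !llvm.ptr<1> to !llvm.ptr<2>  // both types are pointers
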
247 changes: 247 additions & 0 deletions mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -697,3 +697,250 @@ llvm.func @transitive_reaching_def() -> !llvm.ptr {
%3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
llvm.return %3 : !llvm.ptr
}

// -----

// CHECK-LABEL: @load_int_from_float
llvm.func @load_int_from_float() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : f32 to i32
// CHECK: llvm.return %[[BITCAST:.*]]
llvm.return %2 : i32
}

// -----

// CHECK-LABEL: @load_float_from_int
llvm.func @load_float_from_int() -> f32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : i32 to f32
// CHECK: llvm.return %[[BITCAST:.*]]
llvm.return %2 : f32
}

// -----

// CHECK-LABEL: @load_int_from_vector
llvm.func @load_int_from_vector() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x vector<2xi16> : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : vector<2xi16> to i32
// CHECK: llvm.return %[[BITCAST:.*]]
llvm.return %2 : i32
}

// -----

// LLVM arrays cannot be bitcast, so the following cannot be promoted.

// CHECK-LABEL: @load_int_from_array
llvm.func @load_int_from_array() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.alloca
%1 = llvm.alloca %0 x !llvm.array<2 x i16> : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK-NOT: llvm.bitcast
llvm.return %2 : i32
}

// -----

// CHECK-LABEL: @store_int_to_float
// CHECK-SAME: %[[ARG:.*]]: i32
llvm.func @store_int_to_float(%arg: i32) -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: llvm.return %[[ARG]]
llvm.return %2 : i32
}

// -----

// CHECK-LABEL: @store_float_to_int
// CHECK-SAME: %[[ARG:.*]]: f32
llvm.func @store_float_to_int(%arg: f32) -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
llvm.store %arg, %1 {alignment = 4 : i64} : f32, !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
// CHECK: llvm.return %[[BITCAST]]
llvm.return %2 : i32
}

// -----

// CHECK-LABEL: @store_int_to_vector
// CHECK-SAME: %[[ARG:.*]]: i32
llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<2xi16>
// CHECK: %[[BITCAST1:.*]] = llvm.bitcast %[[BITCAST0]] : vector<2xi16> to vector<4xi8>
// CHECK: llvm.return %[[BITCAST1]]
llvm.return %2 : vector<4xi8>
}

// -----

// CHECK-LABEL: @load_ptr_from_int
llvm.func @load_ptr_from_int() -> !llvm.ptr {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
// CHECK: %[[CAST:.*]] = llvm.inttoptr %[[UNDEF]] : i64 to !llvm.ptr
// CHECK: llvm.return %[[CAST:.*]]
llvm.return %2 : !llvm.ptr
}

// -----

// CHECK-LABEL: @load_int_from_ptr
llvm.func @load_int_from_ptr() -> i64 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x !llvm.ptr {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
// CHECK: %[[CAST:.*]] = llvm.ptrtoint %[[UNDEF]] : !llvm.ptr to i64
// CHECK: llvm.return %[[CAST:.*]]
llvm.return %2 : i64
}

// -----

// CHECK-LABEL: @load_ptr_addrspace_cast
llvm.func @load_ptr_addrspace_cast() -> !llvm.ptr<2> {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
// CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[UNDEF]] : !llvm.ptr<1> to !llvm.ptr<2>
// CHECK: llvm.return %[[CAST:.*]]
llvm.return %2 : !llvm.ptr<2>
}

// -----

// CHECK-LABEL: @stores_with_different_types
// CHECK-SAME: %[[ARG0:.*]]: i64
// CHECK-SAME: %[[ARG1:.*]]: f64
llvm.func @stores_with_different_types(%arg0: i64, %arg1: f64, %cond: i1) -> f64 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
llvm.cond_br %cond, ^bb1, ^bb2
^bb1:
llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
// CHECK: llvm.br ^[[BB3:.*]](%[[ARG0]]
llvm.br ^bb3
^bb2:
llvm.store %arg1, %1 {alignment = 4 : i64} : f64, !llvm.ptr
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG1]] : f64 to i64
// CHECK: llvm.br ^[[BB3]](%[[BITCAST]]
llvm.br ^bb3
// CHECK: ^[[BB3]](%[[BLOCK_ARG:.*]]: i64)
^bb3:
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[BLOCK_ARG]] : i64 to f64
// CHECK: llvm.return %[[BITCAST]]
llvm.return %2 : f64
}

// -----

// Verifies that stores of a value smaller than the slot type do not lead to
// promotion. A trivial implementation would be incorrect due to endianness
// considerations.

// CHECK-LABEL: @stores_with_different_type_sizes
llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.alloca
%1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
llvm.cond_br %cond, ^bb1, ^bb2
^bb1:
llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
llvm.br ^bb3
^bb2:
llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
llvm.br ^bb3
^bb3:
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
llvm.return %2 : f64
}

// -----

// CHECK-LABEL: @load_smaller_int
llvm.func @load_smaller_int() -> i16 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.alloca
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i16
llvm.return %2 : i16
}

// -----

// CHECK-LABEL: @load_different_type_smaller
llvm.func @load_different_type_smaller() -> f32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.alloca
%1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
llvm.return %2 : f32
}

// -----

// This alloca is too small for the load; even so, mem2reg should not touch it.

// CHECK-LABEL: @impossible_load
llvm.func @impossible_load() -> f64 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.alloca
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
llvm.return %2 : f64
}

// -----

// Verifies that mem2reg does not introduce address space casts between
// pointers of different bit widths.

module attributes { dlti.dl_spec = #dlti.dl_spec<
#dlti.dl_entry<!llvm.ptr<1>, dense<[32, 64, 64]> : vector<3xi64>>,
#dlti.dl_entry<!llvm.ptr<2>, dense<[64, 64, 64]> : vector<3xi64>>
>} {

// CHECK-LABEL: @load_ptr_addrspace_cast_different_size
llvm.func @load_ptr_addrspace_cast_different_size() -> !llvm.ptr<2> {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.alloca
%1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
llvm.return %2 : !llvm.ptr<2>
}
}
