Skip to content

Commit

Permalink
[mlir] LLVM dialect: use addressof instead of constant to create func…
Browse files Browse the repository at this point in the history
…tion pointers

`llvm.mlir.constant` was originally introduced as an LLVM dialect counterpart
to `std.constant`. As such, it was supporting "function pointer" constants
derived from the symbol name. This is different from `std.constant` that allows
for creation of a "function" constant since MLIR, unlike LLVM IR, supports
this. Later, `llvm.mlir.addressof` was introduced as an Op that obtains a
constant pointer to a global in the LLVM dialect. It naturally extends to
functions (in LLVM IR, functions are globals) and should be used for defining
"function pointer" values instead.

Fixes PR46344.

Differential Revision: https://reviews.llvm.org/D82667
  • Loading branch information
ftynse committed Jun 29, 2020
1 parent e503851 commit cba733e
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 35 deletions.
34 changes: 29 additions & 5 deletions mlir/docs/Dialects/LLVM.md
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,28 @@ Bitwise reinterpretation: `bitcast <value>`.

Selection: `select <condition>, <lhs>, <rhs>`.

### Auxiliary MLIR operations

These operations do not have LLVM IR counterparts but are necessary to map LLVM
IR into MLIR. They should be prefixed with `llvm.mlir`.
### Auxiliary MLIR Operations for Constants and Globals

LLVM IR has broad support for first-class constants, which is not the case for
MLIR. Instead, constants are defined in MLIR as regular SSA values produced by
operations with specific traits. The LLVM dialect provides a set of operations
that model LLVM IR constants. These operations do not correspond to LLVM IR
instructions and are therefore prefixed with `llvm.mlir`.

Inline constants can be created by `llvm.mlir.constant`, which currently
supports integer, float, string or elements attributes (constant sturcts are not
currently supported). LLVM IR constant expressions are expected to be
constructed as sequences of regular operations on SSA values produced by
`llvm.mlir.constant`. Additionally, MLIR provides semantically-charged
operations `llvm.mlir.undef` and `llvm.mlir.null` for the corresponding
constants.

LLVM IR globals can be defined using `llvm.mlir.global` at the module level,
except for functions that are defined with `llvm.func`. Globals, both variables
and functions, can be accessed by taking their address with the
`llvm.mlir.addressof` operation, which produces a pointer to the named global,
unlike the `llvm.mlir.constant` that produces the value of the same type as the
constant.

#### `llvm.mlir.addressof`

Expand All @@ -328,11 +346,17 @@ Examples:

```mlir
func @foo() {
// Get the address of a global.
// Get the address of a global variable.
%0 = llvm.mlir.addressof @const : !llvm<"i32*">
// Use it as a regular pointer.
%1 = llvm.load %0 : !llvm<"i32*">
// Get the address of a function.
%2 = llvm.mlir.addressof @foo : !llvm<"void ()*">
// The function address can be used for indirect calls.
llvm.call %2() : () -> ()
}
// Define the global.
Expand Down
13 changes: 12 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,8 @@ def Linkage : LLVM_EnumAttr<
def LLVM_AddressOfOp
: LLVM_OneResultOp<"mlir.addressof">,
Arguments<(ins FlatSymbolRefAttr:$global_name)> {
let summary = "Creates a pointer pointing to a global or a function";

let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, LLVMType resType, "
"StringRef name, ArrayRef<NamedAttribute> attrs = {}", [{
Expand All @@ -586,13 +588,21 @@ def LLVM_AddressOfOp
"ArrayRef<NamedAttribute> attrs = {}", [{
build(builder, result,
global.getType().getPointerTo(global.addr_space().getZExtValue()),
global.sym_name(), attrs);}]>
global.sym_name(), attrs);}]>,

OpBuilder<"OpBuilder &builder, OperationState &result, LLVMFuncOp func, "
"ArrayRef<NamedAttribute> attrs = {}", [{
build(builder, result,
func.getType().getPointerTo(), func.getName(), attrs);}]>
];

let extraClassDeclaration = [{
/// Return the llvm.mlir.global operation that defined the value referenced
/// here.
GlobalOp getGlobal();

/// Return the llvm.func operation that is referenced here.
LLVMFuncOp getFunction();
}];

let assemblyFormat = "$global_name attr-dict `:` type($res)";
Expand Down Expand Up @@ -733,6 +743,7 @@ def LLVM_ConstantOp
LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);">
{
let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)";
let verifier = [{ return ::verify(*this); }];
}

def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]>,
Expand Down
37 changes: 34 additions & 3 deletions mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1366,8 +1366,6 @@ using AddFOpLowering = VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp>;
using AddIOpLowering = VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp>;
using AndOpLowering = VectorConvertToLLVMPattern<AndOp, LLVM::AndOp>;
using CeilFOpLowering = VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp>;
using ConstLLVMOpLowering =
OneToOneConvertToLLVMPattern<ConstantOp, LLVM::ConstantOp>;
using CopySignOpLowering =
VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp>;
using CosOpLowering = VectorConvertToLLVMPattern<CosOp, LLVM::CosOp>;
Expand Down Expand Up @@ -1541,6 +1539,39 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
}
};

struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = cast<ConstantOp>(operation);
// If constant refers to a function, convert it to "addressof".
if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
auto type = typeConverter.convertType(op.getResult().getType())
.dyn_cast_or_null<LLVM::LLVMType>();
if (!type)
return rewriter.notifyMatchFailure(op, "failed to convert result type");

MutableDictionaryAttr attrs(op.getAttrs());
attrs.remove(rewriter.getIdentifier("value"));
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
op, type.cast<LLVM::LLVMType>(), symbolRef.getValue(),
attrs.getAttrs());
return success();
}

// Calling into other scopes (non-flat reference) is not supported in LLVM.
if (op.getValue().isa<SymbolRefAttr>())
return rewriter.notifyMatchFailure(
op, "referring to a symbol outside of the current module");

return LLVM::detail::oneToOneRewrite(op,
LLVM::ConstantOp::getOperationName(),
operands, typeConverter, rewriter);
}
};

// Check if the MemRefType `type` is supported by the lowering. We currently
// only support memrefs with identity maps.
static bool isSupportedMemRefType(MemRefType type) {
Expand Down Expand Up @@ -3129,7 +3160,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
CondBranchOpLowering,
CopySignOpLowering,
CosOpLowering,
ConstLLVMOpLowering,
ConstantOpLowering,
CreateComplexOpLowering,
DialectCastOpLowering,
DivFOpLowering,
Expand Down
45 changes: 36 additions & 9 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -857,25 +857,40 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
// Verifier for LLVM::AddressOfOp.
//===----------------------------------------------------------------------===//

GlobalOp AddressOfOp::getGlobal() {
Operation *module = getParentOp();
template <typename OpTy>
static OpTy lookupSymbolInModule(Operation *parent, StringRef name) {
Operation *module = parent;
while (module && !satisfiesLLVMModule(module))
module = module->getParentOp();
assert(module && "unexpected operation outside of a module");
return dyn_cast_or_null<LLVM::GlobalOp>(
mlir::SymbolTable::lookupSymbolIn(module, global_name()));
return dyn_cast_or_null<OpTy>(
mlir::SymbolTable::lookupSymbolIn(module, name));
}

GlobalOp AddressOfOp::getGlobal() {
return lookupSymbolInModule<LLVM::GlobalOp>(getParentOp(), global_name());
}

LLVMFuncOp AddressOfOp::getFunction() {
return lookupSymbolInModule<LLVM::LLVMFuncOp>(getParentOp(), global_name());
}

static LogicalResult verify(AddressOfOp op) {
auto global = op.getGlobal();
if (!global)
auto function = op.getFunction();
if (!global && !function)
return op.emitOpError(
"must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");

if (global &&
global.getType().getPointerTo(global.addr_space().getZExtValue()) !=
op.getResult().getType())
return op.emitOpError(
"must reference a global defined by 'llvm.mlir.global'");
"the type must be a pointer to the type of the referenced global");

if (global.getType().getPointerTo(global.addr_space().getZExtValue()) !=
op.getResult().getType())
if (function && function.getType().getPointerTo() != op.getResult().getType())
return op.emitOpError(
"the type must be a pointer to the type of the referred global");
"the type must be a pointer to the type of the referenced function");

return success();
}
Expand Down Expand Up @@ -1395,6 +1410,18 @@ static LogicalResult verify(LLVM::NullOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// Verification for LLVM::ConstantOp.
//===----------------------------------------------------------------------===//

static LogicalResult verify(LLVM::ConstantOp op) {
if (!(op.value().isa<IntegerAttr>() || op.value().isa<FloatAttr>() ||
op.value().isa<ElementsAttr>() || op.value().isa<StringAttr>()))
return op.emitOpError()
<< "only supports integer, float, string or elements attributes";
return success();
}

//===----------------------------------------------------------------------===//
// Utility functions for parsing atomic ops
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,9 @@ Value Importer::processConstant(llvm::Constant *c) {
LLVMType type = processType(c->getType());
if (!type)
return nullptr;
if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type,
symbolRef.getValue());
return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
}
if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
Expand Down
9 changes: 7 additions & 2 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,15 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
// emit any LLVM instruction.
if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
LLVM::GlobalOp global = addressOfOp.getGlobal();
LLVM::LLVMFuncOp function = addressOfOp.getFunction();

// The verifier should not have allowed this.
assert(global && "referencing an undefined global");
assert((global || function) &&
"referencing an undefined global or function");

valueMapping[addressOfOp.getResult()] = globalsMapping.lookup(global);
valueMapping[addressOfOp.getResult()] =
global ? globalsMapping.lookup(global)
: functionMapping.lookup(function.getName());
return success();
}

Expand Down
9 changes: 5 additions & 4 deletions mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func @pass_through(%arg0: () -> ()) -> (() -> ()) {
// CHECK-NEXT: llvm.br ^bb1(%arg0 : !llvm<"void ()*">)
br ^bb1(%arg0 : () -> ())

//CHECK-NEXT: ^bb1(%0: !llvm<"void ()*">): // pred: ^bb0
//CHECK-NEXT: ^bb1(%0: !llvm<"void ()*">):
^bb1(%bbarg: () -> ()):
// CHECK-NEXT: llvm.return %0 : !llvm<"void ()*">
return %bbarg : () -> ()
Expand All @@ -40,11 +40,12 @@ func @pass_through(%arg0: () -> ()) -> (() -> ()) {
// CHECK-LABEL: llvm.func @body(!llvm.i32)
func @body(i32)

// CHECK-LABEL: llvm.func @indirect_const_call(%arg0: !llvm.i32) {
// CHECK-LABEL: llvm.func @indirect_const_call
// CHECK-SAME: (%[[ARG0:.*]]: !llvm.i32) {
func @indirect_const_call(%arg0: i32) {
// CHECK-NEXT: %0 = llvm.mlir.constant(@body) : !llvm<"void (i32)*">
// CHECK-NEXT: %[[ADDR:.*]] = llvm.mlir.addressof @body : !llvm<"void (i32)*">
%0 = constant @body : (i32) -> ()
// CHECK-NEXT: llvm.call %0(%arg0) : (!llvm.i32) -> ()
// CHECK-NEXT: llvm.call %[[ADDR]](%[[ARG0:.*]]) : (!llvm.i32) -> ()
call_indirect %0(%arg0) : (i32) -> ()
// CHECK-NEXT: llvm.return
return
Expand Down
15 changes: 12 additions & 3 deletions mlir/test/Dialect/LLVMIR/global.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,21 @@ func @foo() {
llvm.mlir.global internal @foo(0: i32) : !llvm.i32

func @bar() {
// expected-error @+1 {{the type must be a pointer to the type of the referred global}}
// expected-error @+1 {{the type must be a pointer to the type of the referenced global}}
llvm.mlir.addressof @foo : !llvm<"i64*">
}

// -----

llvm.func @foo()

llvm.func @bar() {
// expected-error @+1 {{the type must be a pointer to the type of the referenced function}}
llvm.mlir.addressof @foo : !llvm<"i8*">
}

// -----

// expected-error @+2 {{'llvm.mlir.global' op expects regions to end with 'llvm.return', found 'llvm.mlir.constant'}}
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'llvm.return'}}
llvm.mlir.global internal @g() : !llvm.i64 {
Expand All @@ -172,14 +181,14 @@ llvm.mlir.global internal @g(43 : i64) : !llvm.i64 {

llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : !llvm.i64
func @mismatch_addr_space_implicit_global() {
// expected-error @+1 {{op the type must be a pointer to the type of the referred global}}
// expected-error @+1 {{op the type must be a pointer to the type of the referenced global}}
llvm.mlir.addressof @g : !llvm<"i64*">
}

// -----

llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : !llvm.i64
func @mismatch_addr_space() {
// expected-error @+1 {{op the type must be a pointer to the type of the referred global}}
// expected-error @+1 {{op the type must be a pointer to the type of the referenced global}}
llvm.mlir.addressof @g : !llvm<"i64 addrspace(4)*">
}
7 changes: 7 additions & 0 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ func @call_non_llvm_input(%callee : (i32) -> (), %arg : i32) {

// -----

func @constant_wrong_type() {
// expected-error@+1 {{only supports integer, float, string or elements attributes}}
llvm.mlir.constant(@constant_wrong_type) : !llvm<"void ()*">
}

// -----

func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
// expected-error@+1 {{expected LLVM IR Dialect type}}
llvm.insertvalue %a, %b[0] : i32
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/LLVMIR/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ func @ops(%arg0: !llvm.i32, %arg1: !llvm.float,
// CHECK: %[[STRUCT:.*]] = llvm.call @foo(%[[I32]]) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
// CHECK: %[[VALUE:.*]] = llvm.extractvalue %[[STRUCT]][0] : !llvm<"{ i32, double, i32 }">
// CHECK: %[[NEW_STRUCT:.*]] = llvm.insertvalue %[[VALUE]], %[[STRUCT]][2] : !llvm<"{ i32, double, i32 }">
// CHECK: %[[FUNC:.*]] = llvm.mlir.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*">
// CHECK: %[[FUNC:.*]] = llvm.mlir.addressof @foo : !llvm<"{ i32, double, i32 } (i32)*">
// CHECK: %{{.*}} = llvm.call %[[FUNC]](%[[I32]]) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
%17 = llvm.call @foo(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
%18 = llvm.extractvalue %17[0] : !llvm<"{ i32, double, i32 }">
%19 = llvm.insertvalue %18, %17[2] : !llvm<"{ i32, double, i32 }">
%20 = llvm.mlir.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*">
%20 = llvm.mlir.addressof @foo : !llvm<"{ i32, double, i32 } (i32)*">
%21 = llvm.call %20(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">


Expand Down Expand Up @@ -130,8 +130,8 @@ func @ops(%arg0: !llvm.i32, %arg1: !llvm.float,
}

// An larger self-contained function.
// CHECK-LABEL: func @foo(%{{.*}}: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> {
func @foo(%arg0: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> {
// CHECK-LABEL: llvm.func @foo(%{{.*}}: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> {
llvm.func @foo(%arg0: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> {
// CHECK: %[[V0:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i32
// CHECK: %[[V1:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i32
// CHECK: %[[V2:.*]] = llvm.mlir.constant(4.200000e+01 : f64) : !llvm.double
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Target/import.ll
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ define void @FPArithmetic(float %a, float %b, double %c, double %d) {
; CHECK-LABEL: @precaller
define i32 @precaller() {
%1 = alloca i32 ()*
; CHECK: %[[func:.*]] = llvm.mlir.constant(@callee) : !llvm<"i32 ()*">
; CHECK: %[[func:.*]] = llvm.mlir.addressof @callee : !llvm<"i32 ()*">
; CHECK: llvm.store %[[func]], %[[loc:.*]]
store i32 ()* @callee, i32 ()** %1
; CHECK: %[[indir:.*]] = llvm.load %[[loc]]
Expand All @@ -252,7 +252,7 @@ define i32 @callee() {
; CHECK-LABEL: @postcaller
define i32 @postcaller() {
%1 = alloca i32 ()*
; CHECK: %[[func:.*]] = llvm.mlir.constant(@callee) : !llvm<"i32 ()*">
; CHECK: %[[func:.*]] = llvm.mlir.addressof @callee : !llvm<"i32 ()*">
; CHECK: llvm.store %[[func]], %[[loc:.*]]
store i32 ()* @callee, i32 ()** %1
; CHECK: %[[indir:.*]] = llvm.load %[[loc]]
Expand Down Expand Up @@ -317,4 +317,4 @@ define i32 @useFenceInst() {
;CHECK: llvm.fence seq_cst
fence syncscope("") seq_cst
ret i32 0
}
}
2 changes: 1 addition & 1 deletion mlir/test/Target/llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ llvm.func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3:
// CHECK-LABEL: define void @indirect_const_call(i64 {{%.*}})
llvm.func @indirect_const_call(%arg0: !llvm.i64) {
// CHECK-NEXT: call void @body(i64 %0)
%0 = llvm.mlir.constant(@body) : !llvm<"void (i64)*">
%0 = llvm.mlir.addressof @body : !llvm<"void (i64)*">
llvm.call %0(%arg0) : (!llvm.i64) -> ()
// CHECK-NEXT: ret void
llvm.return
Expand Down

0 comments on commit cba733e

Please sign in to comment.