Skip to content

Commit

Permalink
[Coroutines] Part 10: Add coroutine promise support.
Browse files Browse the repository at this point in the history
Summary:
1) CoroEarly now lowers the llvm.coro.promise intrinsic, which makes it possible
to obtain a coroutine promise pointer from a coroutine frame pointer and vice versa.

2) CoroFrame now interprets the Promise argument of llvm.coro.begin to
place the coroutine promise alloca at a deterministic offset from the coroutine frame.

Now, the coroutine promise example from docs\Coroutines.rst compiles and produces expected result (see test/Transform/Coroutines/ex4.ll).

Reviewers: majnemer

Subscribers: llvm-commits, mehdi_amini

Differential Revision: https://reviews.llvm.org/D23993

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@280184 91177308-0d34-0410-b5e6-96231b3b80d8
  • Loading branch information
GorNishanov committed Aug 31, 2016
1 parent 1a31aef commit b6a1398
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 9 deletions.
42 changes: 41 additions & 1 deletion lib/Transforms/Coroutines/CoroEarly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "CoroInternal.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
Expand All @@ -24,10 +25,18 @@ using namespace llvm;
namespace {
// Created on demand if CoroEarly pass has work to do.
// Constructed from a Module; bundles the state used by the early-lowering
// helpers declared below.
class Lowerer : public coro::LowererBase {
// Builder used to emit replacement instructions. It is constructed from
// Context and positioned at the intrinsic being lowered by lowerCoroPromise.
IRBuilder<> Builder;
// Pointer to a non-vararg function type that returns void and takes Int8Ptr
// (built in the constructor; Int8Ptr is presumably the i8* type supplied by
// LowererBase -- TODO confirm). lowerCoroPromise uses it as the type of the
// two leading fields of its mock-up coroutine frame.
PointerType *AnyResumeFnPtrTy;

// Lowers a resume/destroy call site. The second argument selects the
// sub-function kind (e.g. CoroSubFnInst::DestroyIndex for coro.destroy).
void lowerResumeOrDestroy(CallSite CS, CoroSubFnInst::ResumeKind);
// Replaces an llvm.coro.promise call with a byte-offset GEP on its operand
// and erases the intrinsic.
void lowerCoroPromise(CoroPromiseInst *Intrin);

public:
// NOTE(review): two constructors with the same signature appear below. The
// one-line form looks like the pre-change version that the multi-line form
// replaces -- confirm that only the latter is meant to remain.
Lowerer(Module &M) : LowererBase(M) {}
Lowerer(Module &M)
: LowererBase(M), Builder(Context),
AnyResumeFnPtrTy(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
/*isVarArg=*/false)
->getPointerTo()) {}
// Walks F and lowers the coroutine intrinsics handled by this pass
// (presumably returns true when F was changed -- TODO confirm).
bool lowerEarlyIntrinsics(Function &F);
};
}
Expand All @@ -44,6 +53,34 @@ void Lowerer::lowerResumeOrDestroy(CallSite CS,
CS.setCallingConv(CallingConv::Fast);
}

// Lowers i8* coro.promise(i8* ptr, i32 align, i1 from) into plain pointer
// arithmetic.
//
// The promise lives at a fixed distance from the start of every coroutine
// frame, so the intrinsic is just "ptr + distance" (frame -> promise) or
// "ptr - distance" (promise -> frame, when `from` is set). The concrete frame
// type is unknown at this point, so the distance is computed from a mock-up
// frame: two function pointers followed by a byte, with the offset of that
// byte rounded up to the alignment requested by the intrinsic.
// TODO: Handle the case when coroutine promise alloca has align override.
void Lowerer::lowerCoroPromise(CoroPromiseInst *Intrin) {
  const DataLayout &DL = TheModule.getDataLayout();
  Type *ByteTy = Builder.getInt8Ty();

  // Stand-in for the header of a real coroutine frame; the trailing byte marks
  // where the first field after the two function pointers would start.
  auto *MockFrameTy =
      StructType::get(Context, {AnyResumeFnPtrTy, AnyResumeFnPtrTy, ByteTy});
  const uint64_t HeaderEnd =
      DL.getStructLayout(MockFrameTy)->getElementOffset(2);

  const unsigned Alignment = Intrin->getAlignment();
  const int64_t Distance = alignTo(HeaderEnd, Alignment);
  // Going from the promise back to the frame walks in the opposite direction.
  const int64_t Offset = Intrin->isFromPromise() ? -Distance : Distance;

  // Emit the GEP right where the intrinsic was, then retire the intrinsic.
  Builder.SetInsertPoint(Intrin);
  Value *Adjusted = Builder.CreateConstInBoundsGEP1_32(
      ByteTy, Intrin->getArgOperand(0), Offset);
  Intrin->replaceAllUsesWith(Adjusted);
  Intrin->eraseFromParent();
}

// Prior to CoroSplit, calls to coro.begin needs to be marked as NoDuplicate,
// as CoroSplit assumes there is exactly one coro.begin. After CoroSplit,
// NoDuplicate attribute will be removed from coro.begin otherwise, it will
Expand Down Expand Up @@ -91,6 +128,9 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) {
case Intrinsic::coro_destroy:
lowerResumeOrDestroy(CS, CoroSubFnInst::DestroyIndex);
break;
case Intrinsic::coro_promise:
lowerCoroPromise(cast<CoroPromiseInst>(&I));
break;
}
Changed = true;
}
Expand Down
29 changes: 25 additions & 4 deletions lib/Transforms/Coroutines/CoroFrame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,11 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape,

// Figure out how wide should be an integer type storing the suspend index.
unsigned IndexBits = std::max(1U, Log2_64_Ceil(Shape.CoroSuspends.size()));

SmallVector<Type *, 8> Types{FnPtrTy, FnPtrTy, Type::getIntNTy(C, IndexBits)};
Type *PromiseType = Shape.PromiseAlloca
? Shape.PromiseAlloca->getType()->getElementType()
: Type::getInt1Ty(C);
SmallVector<Type *, 8> Types{FnPtrTy, FnPtrTy, PromiseType,
Type::getIntNTy(C, IndexBits)};
Value *CurrentDef = nullptr;

// Create an entry for every spilled value.
Expand All @@ -321,6 +324,9 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape,
continue;

CurrentDef = S.def();
// PromiseAlloca was already added to Types array earlier.
if (CurrentDef == Shape.PromiseAlloca)
continue;

Type *Ty = nullptr;
if (auto *AI = dyn_cast<AllocaInst>(CurrentDef))
Expand Down Expand Up @@ -376,6 +382,9 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
// we remember allocas and their indices to be handled once we processed
// all the spills.
SmallVector<std::pair<AllocaInst *, unsigned>, 4> Allocas;
// Promise alloca (if present) has a fixed field number (Shape::PromiseField)
if (Shape.PromiseAlloca)
Allocas.emplace_back(Shape.PromiseAlloca, coro::Shape::PromiseField);

// Create a load instruction to reload the spilled value from the coroutine
// frame.
Expand All @@ -400,7 +409,7 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
++Index;

if (auto *AI = dyn_cast<AllocaInst>(CurrentValue)) {
// Spiled AllocaInst will be replaced with GEP from the coroutine frame
// Spilled AllocaInst will be replaced with GEP from the coroutine frame
// there is no spill required.
Allocas.emplace_back(AI, Index);
if (!AI->isStaticAlloca())
Expand Down Expand Up @@ -444,7 +453,11 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
for (auto &P : Allocas) {
auto *G =
Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, P.second);
ReplaceInstWithInst(P.first, cast<Instruction>(G));
// We are not using ReplaceInstWithInst(P.first, cast<Instruction>(G)) here,
// as we are changing location of the instruction.
G->takeName(P.first);
P.first->replaceAllUsesWith(G);
P.first->eraseFromParent();
}
return FramePtr;
}
Expand Down Expand Up @@ -568,6 +581,10 @@ static void splitAround(Instruction *I, const Twine &Name) {
}

void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
Shape.PromiseAlloca = Shape.CoroBegin->getId()->getPromise();
if (Shape.PromiseAlloca) {
Shape.CoroBegin->getId()->clearPromise();
}

// Make sure that all coro.saves and the fallthrough coro.end are in their
// own block to simplify the logic of building up SuspendCrossing data.
Expand Down Expand Up @@ -621,6 +638,10 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
// in a coroutine. It should not be saved to the coroutine frame.
if (isa<CoroIdInst>(&I))
continue;
// The coroutine promise is always included in the coroutine frame, so there
// is no need to check it for suspend crossing.
if (Shape.PromiseAlloca == &I)
continue;

for (User *U : I.users())
if (Checker.isDefinitionAcrossSuspend(I, U)) {
Expand Down
54 changes: 54 additions & 0 deletions lib/Transforms/Coroutines/CoroInstr.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,39 @@ class LLVM_LIBRARY_VISIBILITY CoroIdInst : public IntrinsicInst {
enum { AlignArg, PromiseArg, CoroutineArg, InfoArg };

public:
// Returns the first user of this coro.id that is an llvm.coro.begin call.
// Reaching the end of the user list without finding one is treated as
// unreachable.
IntrinsicInst *getCoroBegin() {
  for (User *Candidate : users()) {
    auto *Call = dyn_cast<IntrinsicInst>(Candidate);
    if (Call && Call->getIntrinsicID() == Intrinsic::coro_begin)
      return Call;
  }
  llvm_unreachable("no coro.begin associated with coro.id");
}

// Returns the alloca designated by the promise argument of this coro.id, or
// nullptr when that argument is a null pointer constant. Pointer casts
// between the argument and the alloca are looked through.
AllocaInst *getPromise() const {
  Value *PromiseOperand = getArgOperand(PromiseArg);
  if (isa<ConstantPointerNull>(PromiseOperand))
    return nullptr;
  return cast<AllocaInst>(PromiseOperand->stripPointerCasts());
}

// Detaches the promise from this coro.id: the promise argument is replaced
// with a null i8*, and the instruction that used to feed it (if it was a cast
// or GEP rather than the alloca itself) is either erased or moved.
// NOTE(review): if the argument is already a null constant, the assert below
// fires; callers are expected to check getPromise() first -- confirm.
void clearPromise() {
Value *Arg = getArgOperand(PromiseArg);
// Reset the operand first so that coro.id no longer counts as a user of the
// old argument; the use_empty() check further down relies on this.
setArgOperand(PromiseArg,
ConstantPointerNull::get(Type::getInt8PtrTy(getContext())));
// The alloca was passed directly: there is no intermediate instruction to
// clean up.
if (isa<AllocaInst>(Arg))
return;
assert((isa<BitCastInst>(Arg) || isa<GetElementPtrInst>(Arg)) &&
"unexpected instruction designating the promise");
// TODO: Add a check that any remaining users of Inst are after coro.begin
// or add code to move the users after coro.begin.
auto *Inst = cast<Instruction>(Arg);
// Nothing else uses the cast/GEP any more, so it can simply be removed.
if (Inst->use_empty()) {
Inst->eraseFromParent();
return;
}
// Otherwise relocate it so that it sits immediately after coro.begin.
Inst->moveBefore(getCoroBegin()->getNextNode());
}

// Info argument of coro.id is
// fresh out of the frontend: null ;
// outlined : {Init, Return, Susp1, Susp2, ...} ;
Expand Down Expand Up @@ -198,6 +231,27 @@ class LLVM_LIBRARY_VISIBILITY CoroSaveInst : public IntrinsicInst {
}
};

/// This represents the llvm.coro.promise instruction.
///
/// Operands, in order: the pointer to adjust, the promise alignment, and a
/// flag telling whether the pointer is a promise pointer (as opposed to a
/// frame pointer).
class LLVM_LIBRARY_VISIBILITY CoroPromiseInst : public IntrinsicInst {
  enum { FrameArg, AlignArg, FromArg };

public:
  /// True when the `from` operand is the constant one, i.e. the operand
  /// pointer is to be converted from promise to frame.
  bool isFromPromise() const {
    auto *FromFlag = cast<Constant>(getArgOperand(FromArg));
    return FromFlag->isOneValue();
  }

  /// The alignment operand, zero-extended.
  unsigned getAlignment() const {
    auto *AlignValue = cast<ConstantInt>(getArgOperand(AlignArg));
    return AlignValue->getZExtValue();
  }

  // Methods to support type inquiry through isa, cast, and dyn_cast:
  static inline bool classof(const IntrinsicInst *I) {
    return I->getIntrinsicID() == Intrinsic::coro_promise;
  }
  static inline bool classof(const Value *V) {
    if (const auto *II = dyn_cast<IntrinsicInst>(V))
      return classof(II);
    return false;
  }
};

/// This represents the llvm.coro.suspend instruction.
class LLVM_LIBRARY_VISIBILITY CoroSuspendInst : public IntrinsicInst {
enum { SaveArg, FinalArg };
Expand Down
10 changes: 6 additions & 4 deletions lib/Transforms/Coroutines/CoroInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,19 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
enum {
ResumeField,
DestroyField,
PromiseField,
IndexField,
LastKnownField = IndexField
};

StructType *FrameTy;
Instruction *FramePtr;
BasicBlock* AllocaSpillBlock;
SwitchInst* ResumeSwitch;
BasicBlock *AllocaSpillBlock;
SwitchInst *ResumeSwitch;
AllocaInst *PromiseAlloca;
bool HasFinalSuspend;

IntegerType* getIndexType() const {
IntegerType *getIndexType() const {
assert(FrameTy && "frame type not assigned");
return cast<IntegerType>(FrameTy->getElementType(IndexField));
}
Expand All @@ -97,7 +99,7 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
void buildFrom(Function &F);
};

void buildCoroutineFrame(Function& F, Shape& Shape);
void buildCoroutineFrame(Function &F, Shape &Shape);

} // End namespace coro.
} // End namespace llvm
Expand Down
1 change: 1 addition & 0 deletions lib/Transforms/Coroutines/Coroutines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ static void clear(coro::Shape &Shape) {
Shape.FramePtr = nullptr;
Shape.AllocaSpillBlock = nullptr;
Shape.ResumeSwitch = nullptr;
Shape.PromiseAlloca = nullptr;
Shape.HasFinalSuspend = false;
}

Expand Down
71 changes: 71 additions & 0 deletions test/Transforms/Coroutines/ex4.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
; Fourth example from Doc/Coroutines.rst (coroutine promise)
; RUN: opt < %s -O2 -enable-coroutines -S | FileCheck %s

; Coroutine under test. Each trip through %loop stores the current counter
; (starting at %n and growing by one per iteration) into the promise and then
; suspends. The function returns the handle produced by llvm.coro.begin.
define i8* @f(i32 %n) {
entry:
; The promise is a plain i32 alloca; its address, cast to i8*, is passed as
; the second argument of llvm.coro.id.
%promise = alloca i32
%pv = bitcast i32* %promise to i8*
%id = call token @llvm.coro.id(i32 0, i8* %pv, i8* null, i8* null)
; Allocate memory with malloc only when llvm.coro.alloc asks for it.
%need.dyn.alloc = call i1 @llvm.coro.alloc(token %id)
br i1 %need.dyn.alloc, label %dyn.alloc, label %coro.begin
dyn.alloc:
%size = call i32 @llvm.coro.size.i32()
%alloc = call i8* @malloc(i32 %size)
br label %coro.begin
coro.begin:
; null when coming straight from %entry, the malloc'ed block otherwise.
%phi = phi i8* [ null, %entry ], [ %alloc, %dyn.alloc ]
%hdl = call noalias i8* @llvm.coro.begin(token %id, i8* %phi)
br label %loop
loop:
%n.val = phi i32 [ %n, %coro.begin ], [ %inc, %loop ]
%inc = add nsw i32 %n.val, 1
; Publish the current value through the promise before suspending.
store i32 %n.val, i32* %promise
%0 = call i8 @llvm.coro.suspend(token none, i1 false)
; Result 0 goes round the loop again, 1 runs cleanup, anything else falls
; through to %suspend.
switch i8 %0, label %suspend [i8 0, label %loop
i8 1, label %cleanup]
cleanup:
; Release whatever llvm.coro.free reports for this coroutine.
%mem = call i8* @llvm.coro.free(token %id, i8* %hdl)
call void @free(i8* %mem)
br label %suspend
suspend:
call void @llvm.coro.end(i8* %hdl, i1 false)
ret i8* %hdl
}

; CHECK-LABEL: @main
; Driver: creates the coroutine with n = 4, obtains the promise address from
; the handle via llvm.coro.promise, and prints the promise value three times
; with a resume between consecutive reads, then destroys the coroutine.
define i32 @main() {
entry:
%hdl = call i8* @f(i32 4)
; Arguments: the handle, the promise alignment (4), and the "from" flag
; (false here, so the handle is converted into a promise pointer).
%promise.addr.raw = call i8* @llvm.coro.promise(i8* %hdl, i32 4, i1 false)
%promise.addr = bitcast i8* %promise.addr.raw to i32*
%val0 = load i32, i32* %promise.addr
call void @print(i32 %val0)
call void @llvm.coro.resume(i8* %hdl)
%val1 = load i32, i32* %promise.addr
call void @print(i32 %val1)
call void @llvm.coro.resume(i8* %hdl)
%val2 = load i32, i32* %promise.addr
call void @print(i32 %val2)
call void @llvm.coro.destroy(i8* %hdl)
ret i32 0
; The directives below expect the optimized output to print the constants
; 4, 5 and 6 on consecutive lines.
; CHECK: call void @print(i32 4)
; CHECK-NEXT: call void @print(i32 5)
; CHECK-NEXT: call void @print(i32 6)
; CHECK: ret i32 0
}

declare i8* @llvm.coro.promise(i8*, i32, i1)
declare i8* @malloc(i32)
declare void @free(i8*)
declare void @print(i32)

declare token @llvm.coro.id(i32, i8*, i8*, i8*)
declare i1 @llvm.coro.alloc(token)
declare i32 @llvm.coro.size.i32()
declare i8* @llvm.coro.begin(token, i8*)
declare i8 @llvm.coro.suspend(token, i1)
declare i8* @llvm.coro.free(token, i8*)
declare void @llvm.coro.end(i8*, i1)

declare void @llvm.coro.resume(i8*)
declare void @llvm.coro.destroy(i8*)

0 comments on commit b6a1398

Please sign in to comment.