Skip to content

Commit

Permalink
[AArch64][SME] Allow memory operations lowering to custom SME functio…
Browse files Browse the repository at this point in the history
…ns. (llvm#79263)

This change allows memcpy, memset, and memmove to be lowered to the custom
SME versions of these routines provided by LibRT.
  • Loading branch information
dtemirbulatov authored Apr 9, 2024
1 parent 5601e35 commit 528943f
Show file tree
Hide file tree
Showing 4 changed files with 380 additions and 4 deletions.
87 changes: 83 additions & 4 deletions llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ using namespace llvm;

#define DEBUG_TYPE "aarch64-selectiondag-info"

// Hidden command-line switch (on by default). The memcpy/memset/memmove
// lowering hooks in this file consult it before emitting calls to the
// __arm_sc_* routines instead of declining the lowering.
static cl::opt<bool> LowerToSMERoutines(
    "aarch64-lower-to-sme-routines", cl::Hidden,
    cl::desc("Enable AArch64 SME memory operations to lower "
             "to librt functions"),
    cl::init(true));

SDValue AArch64SelectionDAGInfo::EmitMOPS(AArch64ISD::NodeType SDOpcode,
SelectionDAG &DAG, const SDLoc &DL,
SDValue Chain, SDValue Dst,
Expand Down Expand Up @@ -76,15 +82,79 @@ SDValue AArch64SelectionDAGInfo::EmitMOPS(AArch64ISD::NodeType SDOpcode,
}
}

/// Build a call to one of the __arm_sc_* memory routines.
///
/// The call takes three arguments: \p Dst as a pointer, \p Src, and \p Size as
/// a pointer-sized integer. For RTLIB::MEMCPY and RTLIB::MEMMOVE, \p Src is
/// passed as a pointer; for RTLIB::MEMSET it is zero-extended or truncated to
/// i32 and passed as the fill value.
///
/// \returns the chain produced by the lowered call, or an empty SDValue when
/// \p LC is not MEMCPY, MEMMOVE or MEMSET.
SDValue AArch64SelectionDAGInfo::EmitStreamingCompatibleMemLibCall(
    SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
    SDValue Size, RTLIB::Libcall LC) const {
  auto &Ctx = *DAG.getContext();
  const AArch64TargetLowering *TLI =
      DAG.getMachineFunction().getSubtarget<AArch64Subtarget>().getTargetLowering();
  PointerType *const OpaquePtrTy = PointerType::getUnqual(Ctx);
  const EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());

  TargetLowering::ArgListTy CallArgs;
  // Append one (value, IR type) pair to the outgoing argument list.
  const auto AppendArg = [&CallArgs](SDValue Node, Type *Ty) {
    TargetLowering::ArgListEntry Arg;
    Arg.Ty = Ty;
    Arg.Node = Node;
    CallArgs.push_back(Arg);
  };

  // The destination pointer always comes first.
  AppendArg(Dst, OpaquePtrTy);

  // Select the callee and append the routine-specific second argument.
  SDValue Callee;
  switch (LC) {
  case RTLIB::MEMCPY:
    Callee = DAG.getExternalSymbol("__arm_sc_memcpy", PtrVT);
    AppendArg(Src, OpaquePtrTy);
    break;
  case RTLIB::MEMMOVE:
    Callee = DAG.getExternalSymbol("__arm_sc_memmove", PtrVT);
    AppendArg(Src, OpaquePtrTy);
    break;
  case RTLIB::MEMSET: {
    Callee = DAG.getExternalSymbol("__arm_sc_memset", PtrVT);
    // The fill value is passed as a 32-bit integer.
    SDValue FillValue = DAG.getZExtOrTrunc(Src, DL, MVT::i32);
    AppendArg(FillValue, Type::getInt32Ty(Ctx));
    break;
  }
  default:
    // No streaming-compatible routine is known for this libcall.
    return SDValue();
  }

  // The byte count is always last.
  AppendArg(Size, DAG.getDataLayout().getIntPtrType(Ctx));
  assert(Callee->getOpcode() == ISD::ExternalSymbol &&
         "Function name is not set");

  TargetLowering::CallLoweringInfo CLI(DAG);
  CLI.setDebugLoc(DL);
  CLI.setChain(Chain);
  CLI.setLibCallee(TLI->getLibcallCallingConv(LC), OpaquePtrTy, Callee,
                   std::move(CallArgs));
  // Only the output chain of the lowered call is handed back to the caller.
  return TLI->LowerCallTo(CLI).second;
}

/// Target hook for lowering a memcpy on AArch64.
///
/// Lowering choices, in order:
///   1. If the subtarget reports MOPS support, emit a MOPS_MEMCOPY node via
///      EmitMOPS.
///   2. If \p AlwaysInline is set, decline by returning an empty SDValue, so
///      that no library call is emitted for a copy that must stay inline.
///   3. If the SME-routine lowering option is enabled and the current function
///      is not "non-streaming interface and body", emit a call to
///      __arm_sc_memcpy.
///   4. Otherwise decline by returning an empty SDValue.
///
/// \returns the chain of the emitted operation, or an empty SDValue when this
/// hook does not handle the copy.
SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
    SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
    SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
    MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
  const AArch64Subtarget &STI =
      DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();

  if (STI.hasMOPS())
    return EmitMOPS(AArch64ISD::MOPS_MEMCOPY, DAG, DL, Chain, Dst, Src, Size,
                    Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);

  // A copy that must be expanded inline (e.g. llvm.memcpy.inline) may not be
  // turned into an external call, so never route it to __arm_sc_memcpy.
  // NOTE(review): this relies on the generic SelectionDAG memcpy lowering
  // falling back to an inline load/store expansion when the hook declines --
  // confirm against SelectionDAG::getMemcpy.
  if (AlwaysInline)
    return SDValue();

  SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
  if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
    return EmitStreamingCompatibleMemLibCall(DAG, DL, Chain, Dst, Src, Size,
                                             RTLIB::MEMCPY);
  return SDValue();
}

Expand All @@ -95,10 +165,14 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
const AArch64Subtarget &STI =
DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();

if (STI.hasMOPS()) {
if (STI.hasMOPS())
return EmitMOPS(AArch64ISD::MOPS_MEMSET, DAG, dl, Chain, Dst, Src, Size,
Alignment, isVolatile, DstPtrInfo, MachinePointerInfo{});
}

SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
RTLIB::MEMSET);
return SDValue();
}

Expand All @@ -108,10 +182,15 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
const AArch64Subtarget &STI =
DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
if (STI.hasMOPS()) {

if (STI.hasMOPS())
return EmitMOPS(AArch64ISD::MOPS_MEMMOVE, DAG, dl, Chain, Dst, Src, Size,
Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
}

SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
RTLIB::MEMMOVE);
return SDValue();
}

Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class AArch64SelectionDAGInfo : public SelectionDAGTargetInfo {
SDValue Chain, SDValue Op1, SDValue Op2,
MachinePointerInfo DstPtrInfo,
bool ZeroData) const override;

/// Emit a call to the __arm_sc_memcpy, __arm_sc_memmove or __arm_sc_memset
/// routine selected by \p LC (RTLIB::MEMCPY, RTLIB::MEMMOVE or RTLIB::MEMSET).
/// \p Dst and \p Size are passed as the first and last arguments; \p Src is
/// passed as a pointer for memcpy/memmove and is converted to an i32 value for
/// memset. Returns the chain of the lowered call, or an empty SDValue for any
/// other \p LC.
SDValue EmitStreamingCompatibleMemLibCall(SelectionDAG &DAG, const SDLoc &DL,
SDValue Chain, SDValue Dst,
SDValue Src, SDValue Size,
RTLIB::Libcall LC) const;
};
}

Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
if (FuncName == "__arm_tpidr2_restore")
Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
SMEAttrs::SME_ABI_Routine;
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
Bitmask |= SMEAttrs::SM_Compatible;
}

SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Expand Down
Loading

0 comments on commit 528943f

Please sign in to comment.