Skip to content

Commit

Permalink
[DirectX][NFC] Leverage LLVM and DirectX intrinsic description in DXI…
Browse files Browse the repository at this point in the history
…L Op records (#83193)

* Leverage TableGen record descriptions of LLVM or DirectX intrinsics
that can be directly mapped in DXIL Ops TableGen description. As a
result, such DXIL Ops can be succinctly described without duplication.
DXILEmitter backend can derive the properties of DXIL Ops accordingly.
* Ensured that corresponding lit tests pass.
  • Loading branch information
bharadwajy authored Feb 29, 2024
1 parent 6f7d824 commit b1c8b9f
Show file tree
Hide file tree
Showing 5 changed files with 413 additions and 373 deletions.
347 changes: 216 additions & 131 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,139 +12,224 @@
//===----------------------------------------------------------------------===//

include "llvm/IR/Intrinsics.td"
include "llvm/IR/Attributes.td"

// Abstract representation of the class a DXIL Operation belongs to.
class DXILOpClass<string name> {
string Name = name;
class DXILOpClass;

// Following is a set of DXIL Operation classes whose names appear to be
// arbitrary, yet need to be a substring of the function name used during
// lowering to DXIL Operation calls. These class name strings are specified
// as the third argument of add_dixil_op in utils/hct/hctdb.py and case converted
// in utils/hct/hctdb_instrhelp.py of DirectXShaderCompiler repo. The function
// name has the format "dx.op.<class-name>.<return-type>".

defset list<DXILOpClass> OpClasses = {
def acceptHitAndEndSearch : DXILOpClass;
def allocateNodeOutputRecords : DXILOpClass;
def allocateRayQuery : DXILOpClass;
def annotateHandle : DXILOpClass;
def annotateNodeHandle : DXILOpClass;
def annotateNodeRecordHandle : DXILOpClass;
def atomicBinOp : DXILOpClass;
def atomicCompareExchange : DXILOpClass;
def attributeAtVertex : DXILOpClass;
def barrier : DXILOpClass;
def barrierByMemoryHandle : DXILOpClass;
def barrierByMemoryType : DXILOpClass;
def barrierByNodeRecordHandle : DXILOpClass;
def binary : DXILOpClass;
def binaryWithCarryOrBorrow : DXILOpClass;
def binaryWithTwoOuts : DXILOpClass;
def bitcastF16toI16 : DXILOpClass;
def bitcastF32toI32 : DXILOpClass;
def bitcastF64toI64 : DXILOpClass;
def bitcastI16toF16 : DXILOpClass;
def bitcastI32toF32 : DXILOpClass;
def bitcastI64toF64 : DXILOpClass;
def bufferLoad : DXILOpClass;
def bufferStore : DXILOpClass;
def bufferUpdateCounter : DXILOpClass;
def calculateLOD : DXILOpClass;
def callShader : DXILOpClass;
def cbufferLoad : DXILOpClass;
def cbufferLoadLegacy : DXILOpClass;
def checkAccessFullyMapped : DXILOpClass;
def coverage : DXILOpClass;
def createHandle : DXILOpClass;
def createHandleForLib : DXILOpClass;
def createHandleFromBinding : DXILOpClass;
def createHandleFromHeap : DXILOpClass;
def createNodeInputRecordHandle : DXILOpClass;
def createNodeOutputHandle : DXILOpClass;
def cutStream : DXILOpClass;
def cycleCounterLegacy : DXILOpClass;
def discard : DXILOpClass;
def dispatchMesh : DXILOpClass;
def dispatchRaysDimensions : DXILOpClass;
def dispatchRaysIndex : DXILOpClass;
def domainLocation : DXILOpClass;
def dot2 : DXILOpClass;
def dot2AddHalf : DXILOpClass;
def dot3 : DXILOpClass;
def dot4 : DXILOpClass;
def dot4AddPacked : DXILOpClass;
def emitIndices : DXILOpClass;
def emitStream : DXILOpClass;
def emitThenCutStream : DXILOpClass;
def evalCentroid : DXILOpClass;
def evalSampleIndex : DXILOpClass;
def evalSnapped : DXILOpClass;
def finishedCrossGroupSharing : DXILOpClass;
def flattenedThreadIdInGroup : DXILOpClass;
def geometryIndex : DXILOpClass;
def getDimensions : DXILOpClass;
def getInputRecordCount : DXILOpClass;
def getMeshPayload : DXILOpClass;
def getNodeRecordPtr : DXILOpClass;
def getRemainingRecursionLevels : DXILOpClass;
def groupId : DXILOpClass;
def gsInstanceID : DXILOpClass;
def hitKind : DXILOpClass;
def ignoreHit : DXILOpClass;
def incrementOutputCount : DXILOpClass;
def indexNodeHandle : DXILOpClass;
def innerCoverage : DXILOpClass;
def instanceID : DXILOpClass;
def instanceIndex : DXILOpClass;
def isHelperLane : DXILOpClass;
def isSpecialFloat : DXILOpClass;
def legacyDoubleToFloat : DXILOpClass;
def legacyDoubleToSInt32 : DXILOpClass;
def legacyDoubleToUInt32 : DXILOpClass;
def legacyF16ToF32 : DXILOpClass;
def legacyF32ToF16 : DXILOpClass;
def loadInput : DXILOpClass;
def loadOutputControlPoint : DXILOpClass;
def loadPatchConstant : DXILOpClass;
def makeDouble : DXILOpClass;
def minPrecXRegLoad : DXILOpClass;
def minPrecXRegStore : DXILOpClass;
def nodeOutputIsValid : DXILOpClass;
def objectRayDirection : DXILOpClass;
def objectRayOrigin : DXILOpClass;
def objectToWorld : DXILOpClass;
def outputComplete : DXILOpClass;
def outputControlPointID : DXILOpClass;
def pack4x8 : DXILOpClass;
def primitiveID : DXILOpClass;
def primitiveIndex : DXILOpClass;
def quadOp : DXILOpClass;
def quadReadLaneAt : DXILOpClass;
def quadVote : DXILOpClass;
def quaternary : DXILOpClass;
def rawBufferLoad : DXILOpClass;
def rawBufferStore : DXILOpClass;
def rayFlags : DXILOpClass;
def rayQuery_Abort : DXILOpClass;
def rayQuery_CommitNonOpaqueTriangleHit : DXILOpClass;
def rayQuery_CommitProceduralPrimitiveHit : DXILOpClass;
def rayQuery_Proceed : DXILOpClass;
def rayQuery_StateMatrix : DXILOpClass;
def rayQuery_StateScalar : DXILOpClass;
def rayQuery_StateVector : DXILOpClass;
def rayQuery_TraceRayInline : DXILOpClass;
def rayTCurrent : DXILOpClass;
def rayTMin : DXILOpClass;
def renderTargetGetSampleCount : DXILOpClass;
def renderTargetGetSamplePosition : DXILOpClass;
def reportHit : DXILOpClass;
def sample : DXILOpClass;
def sampleBias : DXILOpClass;
def sampleCmp : DXILOpClass;
def sampleCmpBias : DXILOpClass;
def sampleCmpGrad : DXILOpClass;
def sampleCmpLevel : DXILOpClass;
def sampleCmpLevelZero : DXILOpClass;
def sampleGrad : DXILOpClass;
def sampleIndex : DXILOpClass;
def sampleLevel : DXILOpClass;
def setMeshOutputCounts : DXILOpClass;
def splitDouble : DXILOpClass;
def startInstanceLocation : DXILOpClass;
def startVertexLocation : DXILOpClass;
def storeOutput : DXILOpClass;
def storePatchConstant : DXILOpClass;
def storePrimitiveOutput : DXILOpClass;
def storeVertexOutput : DXILOpClass;
def tempRegLoad : DXILOpClass;
def tempRegStore : DXILOpClass;
def tertiary : DXILOpClass;
def texture2DMSGetSamplePosition : DXILOpClass;
def textureGather : DXILOpClass;
def textureGatherCmp : DXILOpClass;
def textureGatherRaw : DXILOpClass;
def textureLoad : DXILOpClass;
def textureStore : DXILOpClass;
def textureStoreSample : DXILOpClass;
def threadId : DXILOpClass;
def threadIdInGroup : DXILOpClass;
def traceRay : DXILOpClass;
def unary : DXILOpClass;
def unaryBits : DXILOpClass;
def unpack4x8 : DXILOpClass;
def viewID : DXILOpClass;
def waveActiveAllEqual : DXILOpClass;
def waveActiveBallot : DXILOpClass;
def waveActiveBit : DXILOpClass;
def waveActiveOp : DXILOpClass;
def waveAllOp : DXILOpClass;
def waveAllTrue : DXILOpClass;
def waveAnyTrue : DXILOpClass;
def waveGetLaneCount : DXILOpClass;
def waveGetLaneIndex : DXILOpClass;
def waveIsFirstLane : DXILOpClass;
def waveMatch : DXILOpClass;
def waveMatrix_Accumulate : DXILOpClass;
def waveMatrix_Annotate : DXILOpClass;
def waveMatrix_Depth : DXILOpClass;
def waveMatrix_Fill : DXILOpClass;
def waveMatrix_LoadGroupShared : DXILOpClass;
def waveMatrix_LoadRawBuf : DXILOpClass;
def waveMatrix_Multiply : DXILOpClass;
def waveMatrix_ScalarOp : DXILOpClass;
def waveMatrix_StoreGroupShared : DXILOpClass;
def waveMatrix_StoreRawBuf : DXILOpClass;
def waveMultiPrefixBitCount : DXILOpClass;
def waveMultiPrefixOp : DXILOpClass;
def wavePrefixOp : DXILOpClass;
def waveReadLaneAt : DXILOpClass;
def waveReadLaneFirst : DXILOpClass;
def worldRayDirection : DXILOpClass;
def worldRayOrigin : DXILOpClass;
def worldToObject : DXILOpClass;
def writeSamplerFeedback : DXILOpClass;
def writeSamplerFeedbackBias : DXILOpClass;
def writeSamplerFeedbackGrad : DXILOpClass;
def writeSamplerFeedbackLevel: DXILOpClass;
}

// Abstract representation of the category a DXIL Operation belongs to
class DXILOpCategory<string name> {
string Name = name;
// Abstraction DXIL Operation to LLVM intrinsic
class DXILOpMapping<int opCode, DXILOpClass opClass, Intrinsic intrinsic, string doc> {
int OpCode = opCode; // Opcode corresponding to DXIL Operation
DXILOpClass OpClass = opClass; // Class of DXIL Operation.
Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
string Doc = doc; // to a short description of the operation
}

def UnaryClass : DXILOpClass<"Unary">;
def BinaryClass : DXILOpClass<"Binary">;
def FlattenedThreadIdInGroupClass : DXILOpClass<"FlattenedThreadIdInGroup">;
def ThreadIdInGroupClass : DXILOpClass<"ThreadIdInGroup">;
def ThreadIdClass : DXILOpClass<"ThreadId">;
def GroupIdClass : DXILOpClass<"GroupId">;

def BinaryUintCategory : DXILOpCategory<"Binary uint">;
def UnaryFloatCategory : DXILOpCategory<"Unary float">;
def ComputeIDCategory : DXILOpCategory<"Compute/Mesh/Amplification shader">;

// Represent as any pointer type with an option to change to a qualified pointer
// type with address space specified.
def dxil_handle_ty : LLVMAnyPointerType;
def dxil_cbuffer_ty : LLVMAnyPointerType;
def dxil_resource_ty : LLVMAnyPointerType;

// The parameter description for a DXIL operation
class DXILOpParameter<int pos, LLVMType type, string name, string doc,
bit isConstant = 0, string enumName = "",
int maxValue = 0> {
int Pos = pos; // Position in parameter list
LLVMType ParamType = type; // Parameter type
string Name = name; // Short, unique parameter name
string Doc = doc; // Description of this parameter
bit IsConstant = isConstant; // Whether this parameter requires a constant value in the IR
string EnumName = enumName; // Name of the enum type, if applicable
int MaxValue = maxValue; // Maximum value for this parameter, if applicable
}

// A representation for a DXIL operation
class DXILOperationDesc {
string OpName = ""; // Name of DXIL operation
int OpCode = 0; // Unique non-negative integer associated with the operation
DXILOpClass OpClass; // Class of the operation
DXILOpCategory OpCategory; // Category of the operation
string Doc = ""; // Description of the operation
list<DXILOpParameter> Params = []; // Parameter list of the operation
list<LLVMType> OverloadTypes = []; // Overload types, if applicable
EnumAttr Attribute; // Operation Attribute. Leverage attributes defined in Attributes.td
// ReadNone - operation does not access memory.
// ReadOnly - only reads from memory.
// "ReadMemory" - reads memory
bit IsDerivative = 0; // Whether this is some kind of derivative
bit IsGradient = 0; // Whether this requires a gradient calculation
bit IsFeedback = 0; // Whether this is a sampler feedback operation
bit IsWave = 0; // Whether this requires in-wave, cross-lane functionality
bit NeedsUniformInputs = 0; // Whether this operation requires that all
// of its inputs are uniform across the wave
// Group DXIL operation for stats - e.g., to accumulate the number of atomic/float/uint/int/...
// operations used in the program.
list<string> StatsGroup = [];
}

class DXILOperation<string name, int opCode, DXILOpClass opClass, DXILOpCategory opCategory, string doc,
list<LLVMType> oloadTypes, EnumAttr attrs, list<DXILOpParameter> params,
list<string> statsGroup = []> : DXILOperationDesc {
let OpName = name;
let OpCode = opCode;
let Doc = doc;
let Params = params;
let OpClass = opClass;
let OpCategory = opCategory;
let OverloadTypes = oloadTypes;
let Attribute = attrs;
let StatsGroup = statsGroup;
}

// LLVM intrinsic that DXIL operation maps to.
class LLVMIntrinsic<Intrinsic llvm_intrinsic_> { Intrinsic llvm_intrinsic = llvm_intrinsic_; }

def Sin : DXILOperation<"Sin", 13, UnaryClass, UnaryFloatCategory, "returns sine(theta) for theta in radians.",
[llvm_half_ty, llvm_float_ty], ReadNone,
[
DXILOpParameter<0, llvm_anyfloat_ty, "", "operation result">,
DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
DXILOpParameter<2, llvm_anyfloat_ty, "value", "input value">
],
["floats"]>,
LLVMIntrinsic<int_sin>;

def UMax : DXILOperation< "UMax", 39, BinaryClass, BinaryUintCategory, "unsigned integer maximum. UMax(a,b) = a > b ? a : b",
[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty], ReadNone,
[
DXILOpParameter<0, llvm_anyint_ty, "", "operation result">,
DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
DXILOpParameter<2, llvm_anyint_ty, "a", "input value">,
DXILOpParameter<3, llvm_anyint_ty, "b", "input value">
],
["uints"]>,
LLVMIntrinsic<int_umax>;

def ThreadId : DXILOperation< "ThreadId", 93, ThreadIdClass, ComputeIDCategory, "reads the thread ID", [llvm_i32_ty], ReadNone,
[
DXILOpParameter<0, llvm_i32_ty, "", "thread ID component">,
DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
DXILOpParameter<2, llvm_i32_ty, "component", "component to read (x,y,z)">
]>,
LLVMIntrinsic<int_dx_thread_id>;

def GroupId : DXILOperation< "GroupId", 94, GroupIdClass, ComputeIDCategory, "reads the group ID (SV_GroupID)", [llvm_i32_ty], ReadNone,
[
DXILOpParameter<0, llvm_i32_ty, "", "group ID component">,
DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
DXILOpParameter<2, llvm_i32_ty, "component", "component to read">
]>,
LLVMIntrinsic<int_dx_group_id>;

def ThreadIdInGroup : DXILOperation< "ThreadIdInGroup", 95, ThreadIdInGroupClass, ComputeIDCategory,
"reads the thread ID within the group (SV_GroupThreadID)", [llvm_i32_ty], ReadNone,
[
DXILOpParameter<0, llvm_i32_ty, "", "thread ID in group component">,
DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
DXILOpParameter<2, llvm_i32_ty, "component", "component to read (x,y,z)">
]>,
LLVMIntrinsic<int_dx_thread_id_in_group>;

def FlattenedThreadIdInGroup : DXILOperation< "FlattenedThreadIdInGroup", 96, FlattenedThreadIdInGroupClass, ComputeIDCategory,
"provides a flattened index for a given thread within a given group (SV_GroupIndex)", [llvm_i32_ty], ReadNone,
[
DXILOpParameter<0, llvm_i32_ty, "", "result">,
DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">
]>,
LLVMIntrinsic<int_dx_flattened_thread_id_in_group>;
// Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
def Sin : DXILOpMapping<13, unary, int_sin,
"Returns sine(theta) for theta in radians.">;
def UMax : DXILOpMapping<39, binary, int_umax,
"Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
"Reads the thread ID">;
def GroupId : DXILOpMapping<94, groupId, int_dx_group_id,
"Reads the group ID (SV_GroupID)">;
def ThreadIdInGroup : DXILOpMapping<95, threadIdInGroup,
int_dx_thread_id_in_group,
"Reads the thread ID within the group "
"(SV_GroupThreadID)">;
def FlattenedThreadIdInGroup : DXILOpMapping<96, flattenedThreadIdInGroup,
int_dx_flattened_thread_id_in_group,
"Provides a flattened index for a "
"given thread within a given "
"group (SV_GroupIndex)">;
Loading

0 comments on commit b1c8b9f

Please sign in to comment.