Skip to content

Commit

Permalink
Fixing Prefetch for Gemm and streamK (codeplaysoftware#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
mehdi-goli authored Feb 17, 2025
1 parent 9a068b2 commit d016805
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,6 @@
#include "online_softmax.hpp"
#include "pvc_flash_attn_mma.hpp"

// Declare the Intel SPIR-V barrier arrive/wait entry points through a helper macro:
//  - when __SYCL_DEVICE_ONLY__ is defined, the macro emits a SYCL_EXTERNAL declaration;
//  - otherwise it emits an inline stub definition whose body is assert(false).
#ifdef __SYCL_DEVICE_ONLY__
#define SYCL_DEVICE_SPV_SPLIT_BARRIER(x) SYCL_EXTERNAL x
#else
#define SYCL_DEVICE_SPV_SPLIT_BARRIER(x) \
inline x { assert(false); }
#endif

// "Arrive" entry point: (execution_scope, memory_scope, memory_semantics).
SYCL_DEVICE_SPV_SPLIT_BARRIER(void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope,
int memory_semantics));
// "Wait" entry point: same parameter list as the arrive entry point.
SYCL_DEVICE_SPV_SPLIT_BARRIER(void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope,
int memory_semantics));

// The helper macro is only used for the two declarations above, so drop it here.
#undef SYCL_DEVICE_SPV_SPLIT_BARRIER
namespace cutlass::gemm::kernel {

template <class ProblemShape, class CollectiveMainloop, class CollectiveEpilogue, class TileScheduler_ = void>
Expand All @@ -60,9 +47,6 @@ class GemmUniversalAttention;

template <class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class GemmUniversalAttention {
// 3 is for subgroup, 2 is for workgroup
#define barrier_arrive(scope) __spirv_ControlBarrierArriveINTEL(scope, 0, 0);
#define barrier_wait(scope) __spirv_ControlBarrierWaitINTEL(scope, 0, 0);

public:
//
Expand Down
2 changes: 1 addition & 1 deletion examples/sycl/pvc/pvc_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ int main(int argc, const char** argv)
Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, _32>>;

constexpr int PipelineStages = 3;
constexpr int PipelineStages = 2;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;

Expand Down
2 changes: 1 addition & 1 deletion examples/sycl/pvc/pvc_gemm_streamk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ int main(int argc, const char** argv)
Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, _32>>;

constexpr int PipelineStages = 3;
constexpr int PipelineStages = 2;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;

Expand Down
18 changes: 18 additions & 0 deletions include/cute/arch/copy_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,26 @@ SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ulong8(
enum CacheControl cacheOpt));
#undef SYCL_DEVICE_BUILTIN

// Forward declarations of the Intel SPIR-V barrier wait/arrive entry points.
// They are visible only when __SYCL_DEVICE_ONLY__ is defined; both are declared
// SYCL_EXTERNAL and carry the `convergent` attribute.
#ifdef __SYCL_DEVICE_ONLY__
SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics);
SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope, int memory_semantics);
#endif

namespace cute
{

// scope = 3 is for subgroup, scope = 2 is for workgroup
/// @brief Signal arrival at a barrier by forwarding to __spirv_ControlBarrierArriveINTEL.
///
/// The call is emitted only when __SYCL_DEVICE_ONLY__ is defined; in any other
/// compilation pass the function body is empty and the arguments are ignored.
///
/// @param scope             Execution scope passed as the first argument of the
///                          intrinsic (3 = subgroup, 2 = workgroup).
/// @param memory_scope      Memory scope forwarded unchanged; defaults to 0.
/// @param memory_semantics  Memory semantics forwarded unchanged; defaults to 0.
CUTE_HOST_DEVICE void barrier_arrive(int scope, int memory_scope = 0, int memory_semantics = 0) {
#ifdef __SYCL_DEVICE_ONLY__
  __spirv_ControlBarrierArriveINTEL(scope, memory_scope, memory_semantics);
#endif
}
/// @brief Wait on a barrier by forwarding to __spirv_ControlBarrierWaitINTEL.
///
/// The call is emitted only when __SYCL_DEVICE_ONLY__ is defined; in any other
/// compilation pass the function body is empty and the arguments are ignored.
///
/// @param scope             Execution scope passed as the first argument of the
///                          intrinsic (3 = subgroup, 2 = workgroup).
/// @param memory_scope      Memory scope forwarded unchanged; defaults to 0.
/// @param memory_semantics  Memory semantics forwarded unchanged; defaults to 0.
CUTE_HOST_DEVICE void barrier_wait(int scope, int memory_scope = 0, int memory_semantics = 0) {
#ifdef __SYCL_DEVICE_ONLY__
  __spirv_ControlBarrierWaitINTEL(scope, memory_scope, memory_semantics);
#endif
}

template<class S, class D = S>
struct XE_ATOMIC {
using SRegisters = S[1];
Expand Down
13 changes: 4 additions & 9 deletions include/cutlass/arch/barrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@
#include <cutlass/arch/config.h>

#if defined(SYCL_INTEL_TARGET)
SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics);
SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope, int memory_semantics);

#define EXECUTION_SCOPE_WORK_GROUP 2
#define MEMORY_SCOPE_WORK_GROUP 2
#define MEMORY_SEMANTICS_RELAXED 0
#include <cute/arch/copy_xe.hpp>

#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12)
#define CUDA_BARRIER_ENABLED 1
Expand Down Expand Up @@ -285,8 +280,8 @@ class NamedBarrier {
CUTLASS_DEVICE
static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) {
#if defined(SYCL_INTEL_TARGET)
__spirv_ControlBarrierArriveINTEL(EXECUTION_SCOPE_WORK_GROUP, MEMORY_SCOPE_WORK_GROUP, MEMORY_SEMANTICS_RELAXED);
__spirv_ControlBarrierWaitINTEL(EXECUTION_SCOPE_WORK_GROUP, MEMORY_SCOPE_WORK_GROUP, MEMORY_SEMANTICS_RELAXED);
barrier_arrive(2,2,0);
barrier_wait(2,2,0);
#elif CUDA_BARRIER_ENABLED
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id);
Expand All @@ -308,7 +303,7 @@ class NamedBarrier {
CUTLASS_DEVICE
static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) {
#if defined(SYCL_INTEL_TARGET)
__spirv_ControlBarrierArriveINTEL(EXECUTION_SCOPE_WORK_GROUP, MEMORY_SCOPE_WORK_GROUP, MEMORY_SEMANTICS_RELAXED);
barrier_arrive(2,2,0);
#elif CUDA_BARRIER_ENABLED
cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id);
asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
Expand Down
Loading

0 comments on commit d016805

Please sign in to comment.