Collection of changes to fix clang build. (NVIDIA#1200)
* Remove unused variables

* Qualify calls to make_fragment_? from templated base class.

Fixes clang build error. (A sketch of the lookup rule follows the change summary below.)

* Add missing `#include <cstdio>`

* Various changes to fix clang compile errors.

* More changes to fix clang build.

Remaining issues:

- `params` initializer of `CollectiveEpilogue`.
- `ops` initializer of `Sm90VisitorImplBase`.
- `__usAtomicCAS` needs to be added to clang upstream.

* Fix remaining clang build issues.

* Qualify `cute::rank()` calls.

* Qualify some more calls that are otherwise ambiguous between the `cute` and `std` namespaces.

* Double-escape special registers in inline asm.

* small change

---------

Co-authored-by: Haicheng Wu <[email protected]>
chsigg and hwu36 authored Dec 8, 2023
1 parent f4a0216 commit e1483d5
Showing 46 changed files with 308 additions and 273 deletions.
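Before the per-file diffs, a note on the "templated base class" bullet above. In standard C++, an unqualified name is not looked up in a base class that depends on a template parameter; nvcc's EDG frontend accepts the bare call anyway, while clang follows the standard and rejects it. The sketch below uses hypothetical names (Base, Derived, run) — only the lookup rule matches the real change, which qualifies calls such as make_fragment_A with `this->` or the base-class name.

    // Minimal sketch, hypothetical names; illustrates the dependent-base lookup rule.
    template <class T>
    struct Base {
      T make_fragment_A() const { return T{}; }
    };

    template <class T>
    struct Derived : Base<T> {
      T run() const {
        // return make_fragment_A();     // clang: use of undeclared identifier
        return this->make_fragment_A();  // dependent call, resolved at instantiation
      }
    };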
examples/51_hopper_gett/51_hopper_gett.cu — 9 additions & 9 deletions

@@ -186,15 +186,15 @@ main(int argc, char const* argv[]) {
   using ElementEpilogue = float;
 
   // The following constexpr values set the max number of modes in each MNKL mode
-  constexpr int MaxRank_M = rank(RowModeStridesA{}); // Max row modes
-  constexpr int MaxRank_N = rank(ColModeStridesB{}); // Max column modes
-  constexpr int MaxRank_K = rank(RedModeStridesA{}); // Max contraction modes
-  constexpr int MaxRank_L = rank(BatModeStridesA{}); // Max batch modes
-  static_assert(rank(RowModeStridesA{}) == rank(RowModeStridesC{}));
-  static_assert(rank(ColModeStridesB{}) == rank(RowModeStridesC{}));
-  static_assert(rank(RedModeStridesA{}) == rank(RedModeStridesB{}));
-  static_assert(rank(BatModeStridesA{}) == rank(BatModeStridesC{}));
-  static_assert(rank(BatModeStridesB{}) == rank(BatModeStridesC{}));
+  constexpr int MaxRank_M = cute::rank(RowModeStridesA{}); // Max row modes
+  constexpr int MaxRank_N = cute::rank(ColModeStridesB{}); // Max column modes
+  constexpr int MaxRank_K = cute::rank(RedModeStridesA{}); // Max contraction modes
+  constexpr int MaxRank_L = cute::rank(BatModeStridesA{}); // Max batch modes
+  static_assert(cute::rank(RowModeStridesA{}) == cute::rank(RowModeStridesC{}));
+  static_assert(cute::rank(ColModeStridesB{}) == cute::rank(RowModeStridesC{}));
+  static_assert(cute::rank(RedModeStridesA{}) == cute::rank(RedModeStridesB{}));
+  static_assert(cute::rank(BatModeStridesA{}) == cute::rank(BatModeStridesC{}));
+  static_assert(cute::rank(BatModeStridesB{}) == cute::rank(BatModeStridesC{}));
 
   // Parse command line to get modes, extents, and strides
   cutlass::GettCommandLine cmd;
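The same qualification is applied throughout the commit. As a minimal sketch (stand-in declarations, not the real ones) of why clang rejects the unqualified calls: `<type_traits>` declares a class template named std::rank, so in a scope where both namespaces are visible, unqualified lookup of `rank` finds two different kinds of entity and is ambiguous. nvcc happens to accept it; clang does not.

    #include <type_traits>  // declares the class template std::rank

    namespace cute {
      template <class Tuple>
      constexpr int rank(Tuple const&) { return 1; }  // stand-in for cute::rank
    }

    using namespace cute;
    using namespace std;

    // constexpr int r = rank(0);     // clang: reference to 'rank' is ambiguous
    constexpr int r = cute::rank(0);  // the qualified call is unambiguous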
examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp — 6 additions & 6 deletions

@@ -58,7 +58,7 @@ class GemmGather
   // Type Aliases
   //
   using ProblemShape = ProblemShape_;
-  static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
+  static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
                 "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
 
   // Mainloop derived types

@@ -180,7 +180,7 @@ class GemmGather
   bool
   can_implement(Arguments const& args) {
     bool implementable = (args.mode == GemmUniversalMode::kGemm) or
-        (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
+        (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
     if (!implementable) {
       CUTLASS_TRACE_HOST("  CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
       return implementable;

@@ -288,10 +288,10 @@ class GemmGather
     PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
 
     // Preconditions
-    static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
-    static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
-    static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
-    static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
+    static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
+    static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
+    static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
+    static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
 
     // Separate out problem shape for convenience
     // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp — 5 additions & 5 deletions

@@ -86,8 +86,8 @@ class EpilogueGatherScatter {
   static const int kOutputAlignment = ThreadEpilogueOp::kCount;
   using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
 
-  static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-  static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+  static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+  static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
 
   struct SharedStorage { };
 
@@ -151,10 +151,10 @@ class EpilogueGatherScatter {
     using namespace cute;
     using X = Underscore;
 
-    static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+    static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
     static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
-    static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
-    static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
+    static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
+    static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
 
     (void) smem_buf;
     ThreadEpilogueOp epilogue_op{params.thread_params};
examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu — 3 additions & 3 deletions

@@ -197,14 +197,14 @@ template<class ... Shapes>
 auto
 select_mode_shape(Shapes const & ... shapes) {
   auto permuted_shapes = filter_tuple(cute::make_tuple(shapes...), [](auto shape) {
-    if constexpr (rank(shape) > 1) {
+    if constexpr (cute::rank(shape) > 1) {
       return cute::make_tuple(shape);
     }
     else {
       return cute::make_tuple();
     }
   });
-  if constexpr (rank(permuted_shapes) == 0) {
+  if constexpr (cute::rank(permuted_shapes) == 0) {
     return get<0>(cute::make_tuple(shapes...));
   }
   else {

@@ -251,7 +251,7 @@ auto
 select_tile_shape(TileSize size, Shape const& shape)
 {
   static_assert(is_static<TileSize>::value, "Tile size must be static");
-  if constexpr (rank(Shape{}) == 0) {
+  if constexpr (cute::rank(Shape{}) == 0) {
     return cute::make_tuple(size);
   }
   else {
examples/53_hopper_gemm_permute/permute_traits.hpp — 2 additions & 2 deletions

@@ -78,7 +78,7 @@ reshape(Shape const& shape, TargetShape const& target_shape)
 template<class Permute, bool Transpose, class Shape, class Stride>
 constexpr auto
 make_permute_layout(Layout<Shape,Stride> const& layout) {
-  static_assert(rank(Shape{}) == 3, "Only rank-3 layouts are supported");
+  static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
   if constexpr (Transpose) {
     // Deal with tensor B by transposing appropriately before and after computing the permute layout.
     // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].

@@ -135,7 +135,7 @@ using inverse_t = decltype(inverse(T{}));
 template<class Permute, bool Transpose, class Shape, class Stride>
 constexpr auto
 make_original_layout(Layout<Shape,Stride> const& layout) {
-  static_assert(rank(Shape{}) == 3, "Only rank-3 layouts are supported");
+  static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
   if constexpr (Transpose) {
     // Deal with tensor B by transposing appropriately before and after computing the permute layout.
     // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
include/cute/arch/cluster_sm90.hpp — 13 additions & 13 deletions

@@ -86,9 +86,9 @@ CUTE_DEVICE dim3 cluster_grid_dims()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
   uint32_t x, y, z;
-  asm volatile("mov.u32 %0, %nclusterid.x;\n" : "=r"(x) : );
-  asm volatile("mov.u32 %0, %nclusterid.y;\n" : "=r"(y) : );
-  asm volatile("mov.u32 %0, %nclusterid.z;\n" : "=r"(z) : );
+  asm volatile("mov.u32 %0, %%nclusterid.x;\n" : "=r"(x) : );
+  asm volatile("mov.u32 %0, %%nclusterid.y;\n" : "=r"(y) : );
+  asm volatile("mov.u32 %0, %%nclusterid.z;\n" : "=r"(z) : );
   return {x, y, z};
 #elif defined(__CUDA_ARCH__)
   // MSVC requires protecting use of gridDim with __CUDA_ARCH__.

@@ -105,9 +105,9 @@ CUTE_DEVICE dim3 cluster_id_in_grid()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
   uint32_t x, y, z;
-  asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(x) : );
-  asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(y) : );
-  asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(z) : );
+  asm volatile("mov.u32 %0, %%clusterid.x;\n" : "=r"(x) : );
+  asm volatile("mov.u32 %0, %%clusterid.y;\n" : "=r"(y) : );
+  asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(z) : );
   return {x, y, z};
 #elif defined(__CUDA_ARCH__)
   // MSVC requires protecting use of blockIdx with __CUDA_ARCH__.

@@ -124,9 +124,9 @@ CUTE_DEVICE dim3 block_id_in_cluster()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
   uint32_t x, y, z;
-  asm volatile("mov.u32 %0, %cluster_ctaid.x;\n" : "=r"(x) : );
-  asm volatile("mov.u32 %0, %cluster_ctaid.y;\n" : "=r"(y) : );
-  asm volatile("mov.u32 %0, %cluster_ctaid.z;\n" : "=r"(z) : );
+  asm volatile("mov.u32 %0, %%cluster_ctaid.x;\n" : "=r"(x) : );
+  asm volatile("mov.u32 %0, %%cluster_ctaid.y;\n" : "=r"(y) : );
+  asm volatile("mov.u32 %0, %%cluster_ctaid.z;\n" : "=r"(z) : );
   return {x, y, z};
 #else
   return {0,0,0};

@@ -138,9 +138,9 @@ CUTE_DEVICE dim3 cluster_shape()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
   uint32_t x, y, z;
-  asm volatile("mov.u32 %0, %cluster_nctaid.x;\n" : "=r"(x) : );
-  asm volatile("mov.u32 %0, %cluster_nctaid.y;\n" : "=r"(y) : );
-  asm volatile("mov.u32 %0, %cluster_nctaid.z;\n" : "=r"(z) : );
+  asm volatile("mov.u32 %0, %%cluster_nctaid.x;\n" : "=r"(x) : );
+  asm volatile("mov.u32 %0, %%cluster_nctaid.y;\n" : "=r"(y) : );
+  asm volatile("mov.u32 %0, %%cluster_nctaid.z;\n" : "=r"(z) : );
   return {x, y, z};
 #else
   return {1,1,1};

@@ -152,7 +152,7 @@ CUTLASS_DEVICE uint32_t block_rank_in_cluster()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
   uint32_t rank;
-  asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(rank) :);
+  asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :);
   return rank;
 #else
   return 0;
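The rule behind these diffs: in GCC/Clang-style extended inline assembly, `%` introduces an operand reference such as `%0`, so a literal `%` in a PTX special-register name must be written `%%`. nvcc tolerated the single `%`; clang rejects the string. A minimal sketch using the `%smid` special register (not part of this diff):

    __device__ unsigned int my_smid() {
      unsigned int id;
      // "%0" is the output operand; "%%smid" emits the literal PTX
      // special register "%smid".
      asm volatile("mov.u32 %0, %%smid;" : "=r"(id));
      return id;
    }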
include/cute/config.hpp — 4 additions & 3 deletions

@@ -30,7 +30,7 @@
  **************************************************************************************************/
 #pragma once
 
-#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)
+#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
 #  define CUTE_HOST_DEVICE __forceinline__ __host__ __device__
 #  define CUTE_DEVICE __forceinline__ __device__
 #  define CUTE_HOST __forceinline__ __host__

@@ -46,10 +46,11 @@
 #  define CUTE_HOST_RTC CUTE_HOST
 #endif
 
-#if !defined(__CUDACC_RTC__) && (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA))
+#if !defined(__CUDACC_RTC__) && !defined(__clang__) && \
+    (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA))
 #  define CUTE_UNROLL #pragma unroll
 #  define CUTE_NO_UNROLL #pragma unroll 1
-#elif defined(__CUDACC_RTC__)
+#elif defined(__CUDACC_RTC__) || defined(__clang__)
 #  define CUTE_UNROLL _Pragma("unroll")
 #  define CUTE_NO_UNROLL _Pragma("unroll 1")
 #else
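Two changes here. First, the function annotations are now gated on `__CUDACC__`, which is defined during both the host and device passes of a CUDA compilation, instead of `__CUDA_ARCH__`, which is defined only during device passes — presumably so that clang, which compiles CUDA sources in separate host and device passes, still sees the `__host__ __device__` annotations on the host side. Second, clang is routed to the `_Pragma` form of the unroll macros: a macro body cannot contain a `#` directive in standard C++, but the `_Pragma("...")` operator may appear in a macro expansion. A minimal sketch of that form (hypothetical macro name):

    // _Pragma is usable inside a macro; a literal `#pragma` is not.
    #define MY_UNROLL _Pragma("unroll")

    __device__ void scale8(float* v, float s) {
      MY_UNROLL
      for (int i = 0; i < 8; ++i) {
        v[i] *= s;
      }
    }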
include/cutlass/cluster_launch.hpp — 1 addition & 0 deletions

@@ -35,6 +35,7 @@
 
 #pragma once
 
+#include <cstdio>
 #include <cuda_runtime_api.h>
 #include "cutlass/cutlass.h"
 #include "cutlass/trace.h"
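A one-line but representative fix: a header that reaches printf — here presumably through the tracing macros in cutlass/trace.h — must include `<cstdio>` itself. nvcc's header set happened to provide it transitively; clang's does not. The self-contained pattern, with a hypothetical function:

    #include <cstdio>  // do not rely on another header to pull this in

    inline void log_launch_failure(int error_code) {
      std::printf("cluster launch failed with error %d\n", error_code);
    }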
include/cutlass/epilogue/collective/builders/sm90_builder.inl — 3 additions & 3 deletions

@@ -595,7 +595,7 @@ CollectiveBuilder<
     cute::is_base_of_v<TmaWarpSpecializedCooperativeElementwiseBase, Schedule> >> {
 private:
   using FusionOp =
-    fusion::LinCombEltAct<Schedule::ActivationFunctor, ElementD, ElementCompute, ElementCompute, Schedule::Round>;
+    fusion::LinCombEltAct<Schedule::template ActivationFunctor, ElementD, ElementCompute, ElementCompute, Schedule::Round>;
   using ImplSchedule =
     cute::conditional_t<cute::is_base_of_v<TmaWarpSpecializedElementwiseBase, Schedule>,
       TmaWarpSpecialized, TmaWarpSpecializedCooperative>;

@@ -676,15 +676,15 @@ private:
   using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator<
     GmemStrideTypeAux, typename Schedule::ElementT>());
   using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux<
-    GmemLayoutTagD, Schedule::ActivationFunctor, ElementD, ElementCompute,
+    GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute,
     typename Schedule::ElementT, typename Schedule::ElementBias, ElementCompute
   >;
   using FusionCallbacksAux = fusion::FusionCallbacks<
     DispatchPolicy, FusionOperationAux, TileShape_MNK, EpilogueTile_MN, SmemLayoutAtomAux, SmemCopyOpAux
   >;
 
   using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct<
-    Schedule::ActivationFunctor, ElementD, ElementCompute,
+    Schedule::template ActivationFunctor, ElementD, ElementCompute,
     typename Schedule::ElementBias, ElementCompute
   >;
   using FusionCallbacksNoAux = fusion::FusionCallbacks<
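The added `template` keyword is the standard disambiguator for dependent names: `Schedule` is a template parameter, so the compiler cannot know that `Schedule::ActivationFunctor` names a member template unless the code says so. nvcc tolerates the omission; clang enforces the rule. A minimal sketch with hypothetical names:

    template <class T> struct Identity { using type = T; };

    struct MySchedule {
      template <class T> using ActivationFunctor = Identity<T>;
    };

    template <template <class> class Act, class T>
    struct Apply { using type = Act<T>; };

    template <class Schedule, class T>
    struct UseActivation {
      // using A = typename Apply<Schedule::ActivationFunctor, T>::type;  // clang: missing 'template'
      using A = typename Apply<Schedule::template ActivationFunctor, T>::type;
    };

    // UseActivation<MySchedule, float>::A is Identity<float>.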
include/cutlass/epilogue/collective/default_epilogue.hpp — 5 additions & 5 deletions

@@ -81,8 +81,8 @@ class DefaultEpilogue {
   static const int kOutputAlignment = ThreadEpilogueOp::kCount;
   using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
 
-  static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-  static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+  static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+  static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
 
   struct SharedStorage { };
 
@@ -163,10 +163,10 @@ class DefaultEpilogue {
     using namespace cute;
     using X = Underscore;
 
-    static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+    static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
     static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
-    static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
-    static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
+    static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
+    static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
 
     // Separate out problem shape for convenience
     auto M = get<0>(problem_shape_mnkl);
include/cutlass/epilogue/collective/detail.hpp — 2 additions & 2 deletions

@@ -204,12 +204,12 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp {
       int thread_idx,
       TensorStorage& shared_tensors)
   {
-    constexpr int BLK_M_RANK = rank<0>(tile_shape_MNK);
+    constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK);
     auto m_max_coord = unwrap(cute::transform(make_seq<BLK_M_RANK>{}, [&](auto i) {
       return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl);
     }));
 
-    constexpr int BLK_N_RANK = rank<1>(tile_shape_MNK);
+    constexpr int BLK_N_RANK = cute::rank<1>(tile_shape_MNK);
     auto n_max_coord = unwrap(cute::transform(make_seq<BLK_N_RANK>{}, [&](auto i) {
       return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl);
     }));
(file name not captured in this view)

@@ -91,8 +91,8 @@ class EpilogueTensorBroadcast {
   using StrideD = StrideD_;
   using ActivationFunctor = typename ThreadEpilogueOp::ActivationFunctor;
 
-  static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-  static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+  static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+  static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
 
   static constexpr int kOutputAlignment = ThreadEpilogueOp::kCount;
   using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;

@@ -182,10 +182,10 @@ class EpilogueTensorBroadcast {
     using namespace cute;
     using X = Underscore;
 
-    static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+    static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
     static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
-    static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
-    static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
+    static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
+    static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
 
     // Separate out problem shape for convenience
     auto M = get<0>(problem_shape_mnkl);
include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp — 5 additions & 5 deletions

@@ -87,8 +87,8 @@ class Epilogue {
   static const int kOutputAlignment = ThreadEpilogueOp::kCount;
   using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
 
-  static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-  static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+  static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+  static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
 
   struct SharedStorage
   {

@@ -172,10 +172,10 @@ class Epilogue {
     using namespace cute;
     using X = Underscore;
 
-    static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+    static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
     static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
-    static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
-    static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
+    static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
+    static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
 
     // synchronizing function for smem reads/writes
 #if CUDA_BARRIER_ENABLED
(Diffs for the remaining changed files are not shown.)
