streamk fix (#830)
Co-authored-by: Haicheng Wu <[email protected]>
hwu36 authored Feb 20, 2023
1 parent d8359c8 commit 91b8de8
Showing 2 changed files with 65 additions and 141 deletions.
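In brief, the diff removes the quick_dp fast path (the quick_dp flag in Params, the gemm_dp() kernel body, and the __CUDACC_VER_MAJOR__-gated dispatch in operator()()), so every launch now takes the generic stream-K gemm() path. That path is restructured to classify each threadblock up front as a data-parallel (DP), stream-K (SK), or reduction block before entering the tile-processing loop, and the uniform() warp-shuffle helper is dropped from the threadblock swizzle.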
195 changes: 63 additions & 132 deletions include/cutlass/gemm/kernel/gemm_universal_streamk.h
@@ -270,8 +270,6 @@ struct GemmUniversalStreamk {

ThreadblockSwizzle block_mapping;

bool quick_dp;

void *barrier_workspace;
void *partials_workspace;

Expand Down Expand Up @@ -367,13 +365,6 @@ struct GemmUniversalStreamk {
sm_occupancy,
device_sms,
avail_sms);

quick_dp =
(block_mapping.sk_waves == 0) &&
(mode == GemmUniversalMode::kGemm) &&
!block_mapping.cohort_raster &&
!EpilogueOutputOp(output_op).is_source_needed();

}


Expand Down Expand Up @@ -874,7 +865,7 @@ struct GemmUniversalStreamk {
threadblock_item_begin);

// Execute the epilogue operator to update the destination tensor.
epilogue.unified(
epilogue(
EpilogueOutputOp(params.output_op),
iterator_D,
accumulator_tile,
Expand Down Expand Up @@ -961,13 +952,14 @@ struct GemmUniversalStreamk {
AccumulatorTile accumulator_tile;
accumulator_tile.clear();

// Perform this tile's range of multiply-accumulate (MAC) iterations
// Initialize MMA abstraction
Mma mma(
shared_storage.main_loop,
thread_idx,
warp_idx,
lane_idx);

// Perform this tile's range of multiply-accumulate (MAC) iterations
mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);

if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) ||
@@ -1020,29 +1012,27 @@ struct GemmUniversalStreamk {
void gemm()
{
// Initialize block's iteration range
int tile_idx, block_iter_begin, block_iters_remaining;
int tile_idx = 0;
int block_iter_begin = 0;
int block_iters_remaining = 0;

int block_idx = params.block_mapping.get_block_idx();

int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region();
int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms;
int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks;
int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks;

int block_idx = params.block_mapping.get_block_idx();
if (block_idx < sk_padding_start_block_idx)
{
// This is a SK block
int block_iter_end;
params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end);
block_iters_remaining = block_iter_end - block_iter_begin;
// Initialize tile work descriptor
TileWorkDesc tile_work;

tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
}
else if (block_idx < dp_start_block_idx)
{
// This is a filler block
return;
}
else if (block_idx < reduce_start_block_idx)
bool dp_block = (block_idx >= dp_start_block_idx) && (block_idx < reduce_start_block_idx);
bool sk_block = (block_idx < sk_padding_start_block_idx);
bool reduce_block = (block_idx >= reduce_start_block_idx) &&
(block_idx < grid_padding_start_block_idx) &&
(ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed);

if (dp_block)
{
// This is a DP block
int dp_block_idx = block_idx - dp_start_block_idx;
@@ -1058,132 +1048,83 @@ struct GemmUniversalStreamk {
tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms;
}

block_iter_begin = 0;
block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment;
}

else if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed) &&
(block_idx < grid_padding_start_block_idx))
init_dp_tile_work(tile_work, tile_idx);

// DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
if ((tile_idx < params.block_mapping.sk_tiles) ||
(tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
(tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
{
return;
}
}
else if (sk_block)
{
// This is a reduction threadblock
int reduce_block_idx = block_idx - reduce_start_block_idx;
separate_reduction(reduce_block_idx);
return;
// This is a SK block
int block_iter_end;
params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end);
block_iters_remaining = block_iter_end - block_iter_begin;

tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);

init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
}
else
{
// This is a filler block
if (reduce_block)
{
// This is a reduction threadblock
int reduce_block_idx = block_idx - reduce_start_block_idx;
separate_reduction(reduce_block_idx);
}

return;
}

// Perform this block's share of work for this tile
process_tile(
tile_work,
block_idx,
dp_start_block_idx,
block_iter_begin);

block_iters_remaining -= tile_work.k_iters_remaining;

// Iteration-processing loop body
CUTLASS_PRAGMA_NO_UNROLL
while (true)
while (block_iters_remaining != 0)
{
// Initialize tile work descriptor
TileWorkDesc tile_work;
if (block_idx >= dp_start_block_idx)
{
init_dp_tile_work(tile_work, tile_idx);

// DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
if ((tile_idx < params.block_mapping.sk_tiles) ||
(tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
(tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
{
break;
}
}
else
{
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
}

// Perform this block's share of work for this tile
process_tile(tile_work, block_idx, dp_start_block_idx, block_iter_begin);

// Update remaining work for this block
block_iters_remaining -= tile_work.k_iters_remaining;
if (block_iters_remaining == 0) {
// Done
break;
}

// Continue to next tile
__syncthreads();

if (block_idx >= dp_start_block_idx)
{
// DP blocks consume their tiles at stride
tile_idx += params.block_mapping.avail_sms;
init_dp_tile_work(tile_work, tile_idx);
}
else
{
// SK blocks consume their tiles in backwards order
tile_idx--;
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
}
}

}


/// Executes one DP-only GEMM
CUTLASS_DEVICE
void gemm_dp()
{
int block_idx = blockIdx.x;
int tile_idx = block_idx;

TileWorkDesc tile_work;
tile_work.tile_idx = tile_idx;
tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile();
tile_work.k_iters_remaining = params.block_mapping.iters_per_tile();
tile_work.k_begin = 0;
tile_work.k_end = params.block_mapping.problem_size.k();
tile_work.tiled_coord = params.block_mapping.get_tile_offset_row_major(tile_work.tile_idx);

// Initialize input iterators
typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode);
typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode);

// Initialize accumulators
AccumulatorTile accumulator_tile;
accumulator_tile.clear();

// Perform this tile's range of multiply-accumulate (MAC) iterations
Mma mma(
shared_storage.main_loop,
thread_idx,
warp_idx,
lane_idx);

mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);

ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);

// Location of this tile in item-coords
MatrixCoord threadblock_item_begin(
tile_work.tiled_coord.m() * Mma::Shape::kM,
tile_work.tiled_coord.n() * Mma::Shape::kN
);
// Perform this block's share of work for this tile
process_tile(
tile_work,
block_idx,
dp_start_block_idx,
block_iter_begin);

// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
ptr_D,
params.block_mapping.problem_size.mn(),
thread_idx,
threadblock_item_begin);
block_iters_remaining -= tile_work.k_iters_remaining;
}

// Execute the epilogue operator to update the destination tensor.
epilogue(
EpilogueOutputOp(params.output_op),
iterator_D,
accumulator_tile);
}



public:

//
@@ -1224,16 +1165,6 @@ struct GemmUniversalStreamk {
CUTLASS_DEVICE
void operator()()
{
#if (__CUDACC_VER_MAJOR__ > 10)
if (params.quick_dp)
{
// Simple (low-bootstrap latency) GEMM code path for data-parallel only. (kBatched and kArray
// modes will only be launched using a data-parallel configuration)
gemm_dp();
return;
}
#endif

// Generic SK code path
gemm();

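For orientation, the restructured gemm() above replaces the old if/else-if chain with three up-front flags (dp_block, sk_block, reduce_block) over contiguous block-index ranges. A minimal standalone sketch of that classification follows; the names BlockRole and classify_block are invented here for illustration and are not part of the commit:

    // Sketch only: mirrors the range checks in the new gemm().
    // The launch grid is partitioned into contiguous block-index ranges:
    //   [0, sk_padding_start)                stream-K (SK) blocks
    //   [sk_padding_start, dp_start)         filler blocks padding out the SK waves
    //   [dp_start, reduce_start)             data-parallel (DP) blocks
    //   [reduce_start, grid_padding_start)   reduction blocks (kMixed strategy only)
    //   [grid_padding_start, grid size)      filler blocks padding out the grid
    enum class BlockRole { SK, DP, Reduce, Filler };

    BlockRole classify_block(
        int block_idx,
        int sk_padding_start_block_idx,   // sk_regions() * sk_blocks_per_region()
        int dp_start_block_idx,           // sk_waves * avail_sms
        int reduce_start_block_idx,       // dp_start_block_idx + dp_blocks
        int grid_padding_start_block_idx, // reduce_start_block_idx + reduction_blocks
        bool mixed_reduction)             // kReductionStrategy == kMixed
    {
      if (block_idx < sk_padding_start_block_idx) return BlockRole::SK;
      if (block_idx < dp_start_block_idx)         return BlockRole::Filler;
      if (block_idx < reduce_start_block_idx)     return BlockRole::DP;
      if (mixed_reduction && block_idx < grid_padding_start_block_idx)
                                                  return BlockRole::Reduce;
      return BlockRole::Filler; // grid padding (or the reduction range under a non-mixed strategy)
    }

SK, filler, and reduction blocks behave as before; the structural change is that DP blocks now initialize their first tile and perform the out-of-bounds/SK-overlap early exit once, before the shared while (block_iters_remaining != 0) loop.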
11 changes: 2 additions & 9 deletions include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h
@@ -637,13 +637,6 @@ struct ThreadblockSwizzleStreamK {
// Device-side interface
//

/// Proves to the compiler that val is warp-uniform
CUTLASS_DEVICE
int uniform(int val) const
{
return __shfl_sync(0xffffffff, val, 0);
}

/// Obtains number of threadblocks per GEMM
CUTLASS_DEVICE
int device_num_blocks() const
@@ -656,7 +649,7 @@
int get_sk_tile_idx(int iter) const
{
int tile_idx = div_mod_iters_per_tile.div(iter);
return uniform(tile_idx);
return tile_idx;
}

/// Obtains the batch index
@@ -734,7 +727,7 @@ struct ThreadblockSwizzleStreamK {
block_idx = (region * sk_blocks_per_region()) + block_in_region;
}

return uniform(block_idx);
return block_idx;
}


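For context on the second file: the removed uniform() helper broadcast lane 0's copy of a value to the whole warp via __shfl_sync, a way of proving to the compiler that the result is warp-uniform; get_sk_tile_idx() and get_block_idx() now return their computed values directly. The underlying idiom, sketched generically (illustrative only, not code from this commit):

    // Generic warp-uniform broadcast idiom: all 32 lanes read lane 0's
    // copy of val, so every lane provably holds the same value.
    __device__ int broadcast_lane0(int val) {
      return __shfl_sync(0xffffffff, val, 0); // full-warp mask, source lane 0
    }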
