Commit 5d05808

fix gather example (#574)
shangz-ai authored Jul 19, 2022
1 parent 0b8cacd commit 5d05808
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions examples/36_gather_scatter_fusion/gather_scatter_fusion.cu
@@ -40,13 +40,12 @@
 //   for (int j = 0; j < options.index_size; ++j) {
 //     int b_c_d_col = tensor_indices.at({j, 0});
 //
-//     for (int k = 0; k < options.index_size; ++k) {
-//       int a_col = tensor_indices.at({k, 0});
+//     for (int k = 0; k < problem_size.k(); ++k) {
 //       tensor_d_ref.at({i, b_c_d_col}) +=
-//         alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
+//         alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
 //     }
 //   }
 // }
 //
 // Note that the index vector contains unique random integers with max to be N - 1
 //
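In words, the corrected reference computes D[:, idx[j]] = alpha * (A @ B[:, idx[j]]) + beta * C[:, idx[j]]: A is consumed densely along the full K extent, while one index vector selects both the columns of B that are read and the columns of D that are written. A minimal standalone sketch of that computation, using plain std::vector and hypothetical sizes (not the CUTLASS API):

    // Reference gather/scatter GEMM, row-major, hypothetical sizes.
    // A is read densely along K; B's columns are gathered and D's
    // columns scattered through the same unique index vector.
    #include <cstdio>
    #include <vector>

    int main() {
      int M = 4, N = 6, K = 3, index_size = 2;
      float alpha = 1.0f, beta = 1.0f;

      std::vector<float> A(M * K, 1.0f);   // M x K, dense along K
      std::vector<float> B(K * N, 1.0f);   // K x N, columns gathered
      std::vector<float> C(M * N, 0.0f);   // M x N
      std::vector<float> D(M * N, 0.0f);   // M x N, columns scattered
      std::vector<int> indices = {1, 4};   // unique values in [0, N)

      for (int i = 0; i < M; ++i) {
        for (int j = 0; j < index_size; ++j) {
          int b_c_d_col = indices[j];      // one column index for B, C, and D
          float acc = 0.0f;
          for (int k = 0; k < K; ++k)      // full K extent: A is NOT gathered
            acc += A[i * K + k] * B[k * N + b_c_d_col];
          D[i * N + b_c_d_col] = alpha * acc + beta * C[i * N + b_c_d_col];
        }
      }
      std::printf("D[0][%d] = %f\n", indices[0], D[indices[0]]);
      return 0;
    }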
@@ -257,7 +257,7 @@ using Gemm = cutlass::gemm::device::GemmUniversal<ElementInputA,
                                          cutlass::arch::OpMultiplyAdd,
                                          cutlass::ComplexTransform::kNone,
                                          cutlass::ComplexTransform::kNone,
-                                         true,  /*GatherA*/
+                                         false, /*GatherA*/
                                          true,  /*GatherB*/
                                          true   /*ScatterD*/
                                          >;
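The three trailing booleans of the GemmUniversal instantiation select which operands receive the index indirection; after this fix only B is gathered and only D is scattered. A tiny compile-time restatement of the corrected configuration (plain C++ stand-ins, not the actual template parameters):

    // Illustrative stand-ins for the corrected configuration; these are
    // not the cutlass::gemm::device::GemmUniversal template parameters.
    constexpr bool GatherA  = false;  // A is read densely along K
    constexpr bool GatherB  = true;   // columns of B gathered via the index vector
    constexpr bool ScatterD = true;   // columns of D scattered via the same indices

    static_assert(!GatherA && GatherB && ScatterD,
                  "example 36 gathers B and scatters D, but never gathers A");

    int main() { return 0; }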
@@ -273,13 +273,13 @@ int run(Options &options) {
   // Create a tuple of problem size for matrix multiplication
   cutlass::gemm::GemmCoord problem_size_real(problem_size.m(),
                                              options.index_size,
-                                             options.index_size);
+                                             problem_size.k());
 
   // Initialize tensors using CUTLASS helper functions
   cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
       problem_size.mk());  // <- Create matrix A with dimensions M x K
   cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
-      cutlass::make_Coord(options.index_size, problem_size.n()));  // <- Create matrix B with dimensions K x N
+      problem_size.kn());  // <- Create matrix B with dimensions K x N
   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
       problem_size.mn());  // <- Create matrix C with dimensions M x N
   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d_scattered(
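Since A is no longer gathered, the launched problem keeps the full K as its reduction extent and only shrinks N to the number of gathered columns, while every tensor stays densely allocated. A small sanity sketch of that shape bookkeeping (plain C++, hypothetical extents):

    // Shape bookkeeping after the fix (hypothetical extents).
    #include <cassert>

    int main() {
      const int M = 1024, N = 1024, K = 512;  // logical GEMM extents
      const int index_size = 256;             // gathered/scattered column count

      // Launched ("real") problem: M x index_size x K.
      const int real_m = M;
      const int real_n = index_size;  // only the gathered columns are computed
      const int real_k = K;           // full K: this was index_size before the fix

      // Dense allocations: A is M x K, B is K x N (only index_size columns
      // are read), C and D are M x N (only index_size columns of D written).
      assert(real_m == M && real_n <= N && real_k == K);
      return 0;
    }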
@@ -353,7 +353,7 @@ int run(Options &options) {
       tensor_b.layout().stride(),
       tensor_c.layout().stride(),
       tensor_d_scattered.layout().stride(),
-      tensor_indices.device_data(),   // <- pointer to index vector to gather A on device
+      nullptr,                        // <- pointer to index vector to gather A on device
       tensor_indices.device_data(),   // <- pointer to index vector to gather B on device
       tensor_indices.device_data()};  // <- pointer to index vector to scatter D on device
 
@@ -388,10 +388,9 @@ int run(Options &options) {
     for (int j = 0; j < options.index_size; ++j) {
       int b_c_d_col = tensor_indices.at({j, 0});
 
-      for (int k = 0; k < options.index_size; ++k) {
-        int a_col = tensor_indices.at({k, 0});
+      for (int k = 0; k < problem_size.k(); ++k) {
         tensor_d_ref.at({i, b_c_d_col}) +=
-            alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
+            alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
       }
 
       tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));
