Optimize the innermost loop for jagged tensor implementation (pytorch#1041)

Summary:
Pull Request resolved: pytorch#1041

Optimize the innermost loop:

For FP16, we prefer 128-byte accesses per warp (32 threads): with each thread now handling two FP16 elements per iteration, a warp touches 32 x 2 x 2 bytes = 128 bytes, matching the 128-byte cache line size on A100 GPUs.

Reviewed By: jasonjk-park

Differential Revision: D35532377

fbshipit-source-id: bcb7e82cd817b90203d78244f78597c9f8f41b7e
jianyuh authored and facebook-github-bot committed Apr 12, 2022
1 parent 9147ea2 commit eacd342
Showing 2 changed files with 43 additions and 12 deletions.
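
The change in both CUDA kernels below unrolls the innermost loop so that each thread handles two consecutive elements per iteration, with a scalar tail for an odd inner size. A minimal self-contained sketch of the same access pattern follows; the kernel name, the raw half pointers, and the __hadd addition are illustrative placeholders, standing in for the accessor-based kernels and the functor f in jagged_tensor_ops.cu.

#include <cuda_fp16.h>

// Illustrative sketch only. Each thread touches elements 2*i and 2*i + 1 per
// iteration, so one warp (32 threads) moves 32 * 2 * sizeof(half) = 128 bytes,
// one full 128-byte cache line on A100.
__global__ void elementwise_add_unrolled2(
    const half* x,
    const half* y,
    half* out,
    int inner_dense_size) {
  int i;
  for (i = threadIdx.x; i * 2 + 1 < inner_dense_size; i += blockDim.x) {
    out[2 * i] = __hadd(x[2 * i], y[2 * i]);
    out[2 * i + 1] = __hadd(x[2 * i + 1], y[2 * i + 1]);
  }
  // Odd inner_dense_size: exactly one thread lands here and writes the last
  // element.
  if (i * 2 + 1 == inner_dense_size) {
    out[2 * i] = __hadd(x[2 * i], y[2 * i]);
  }
}

With the previous one-element-per-thread loop, a warp of 32 threads touched only 32 x 2 bytes = 64 bytes of FP16 data per iteration, i.e. half of a 128-byte cache line.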
17 changes: 13 additions & 4 deletions fbgemm_gpu/bench/jagged_tensor_benchmark.py
@@ -30,18 +30,27 @@ def cli() -> None:
 
 
 @cli.command()
-@click.option("--batch-size", default=128)
-@click.option("--embedding-dim", default=128)
-@click.option("--max-len", default=128)
+@click.option("--batch-size", type=int, default=128)
+@click.option("--embedding-dim", type=int, default=128)
+@click.option("--max-len", type=int, default=128)
+@click.option("--elem-type", type=str, default="half")
 def device(
     batch_size: int,
     embedding_dim: int,
     max_len: int,
+    elem_type: str,
 ) -> None:
     lengths = torch.randint(max_len, size=(batch_size,))
     total_lengths = lengths.sum().item()
     offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-    values_2d = torch.rand(total_lengths, embedding_dim)
+
+    dtype = (
+        torch.float16
+        if elem_type == "half" or elem_type == "float16"
+        else torch.float32
+    )
+
+    values_2d = torch.rand(total_lengths, embedding_dim, dtype=dtype)
 
     if torch.cuda.is_available():
         offsets = offsets.cuda()
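
With the new --elem-type option the benchmark can exercise the FP16 path directly. Assuming the module's click group is run as a script (the exact invocation may differ in your setup), something like:

    python fbgemm_gpu/bench/jagged_tensor_benchmark.py device --batch-size 128 --embedding-dim 128 --max-len 128 --elem-type half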
38 changes: 30 additions & 8 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -109,15 +109,30 @@ __launch_bounds__(kMaxThreads) void jagged_dense_elementwise_dense_output_kernel
         offset, jidx, jagged_dims, x_offsets);
 
     if (is_zero) {
-      for (int iidx = threadIdx.x; iidx < inner_dense_size;
+      int iidx;
+      for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size;
            iidx += blockDim.x) {
-        output[oidx][jidx][iidx] = f(padding_value, y[oidx][jidx][iidx]);
+        output[oidx][jidx][2 * iidx] =
+            f(padding_value, y[oidx][jidx][2 * iidx]);
+        output[oidx][jidx][2 * iidx + 1] =
+            f(padding_value, y[oidx][jidx][2 * iidx + 1]);
       }
+      if (iidx * 2 + 1 == inner_dense_size) {
+        output[oidx][jidx][2 * iidx] =
+            f(padding_value, y[oidx][jidx][2 * iidx]);
+      }
     } else {
-      for (int iidx = threadIdx.x; iidx < inner_dense_size;
+      int iidx;
+      for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size;
            iidx += blockDim.x) {
-        output[oidx][jidx][iidx] =
-            f(x_values[offset][iidx], y[oidx][jidx][iidx]);
+        output[oidx][jidx][2 * iidx] =
+            f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]);
+        output[oidx][jidx][2 * iidx + 1] =
+            f(x_values[offset][2 * iidx + 1], y[oidx][jidx][2 * iidx + 1]);
       }
+      if (iidx * 2 + 1 == inner_dense_size) {
+        output[oidx][jidx][2 * iidx] =
+            f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]);
+      }
     }
   }
@@ -265,10 +280,17 @@ __launch_bounds__(kMaxThreads) void jagged_dense_elementwise_jagged_output_kernel
         offset, jidx, jagged_dims, x_offsets);
 
     if (!is_zero) {
-      for (int iidx = threadIdx.x; iidx < inner_dense_size;
+      int iidx;
+      for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size;
            iidx += blockDim.x) {
-        output_values[offset][iidx] =
-            f(x_values[offset][iidx], y[oidx][jidx][iidx]);
+        output_values[offset][2 * iidx] =
+            f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]);
+        output_values[offset][2 * iidx + 1] =
+            f(x_values[offset][2 * iidx + 1], y[oidx][jidx][2 * iidx + 1]);
       }
+      if (iidx * 2 + 1 == inner_dense_size) {
+        output_values[offset][2 * iidx] =
+            f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]);
+      }
     }
   }
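
A quick check of the tail handling shared by all three rewritten loops: the for loop covers the pair (2 * iidx, 2 * iidx + 1) while 2 * iidx + 1 < inner_dense_size, and when inner_dense_size is odd exactly one thread exits the loop with 2 * iidx + 1 == inner_dense_size and writes the final element; for even sizes the tail branch never fires, since 2 * iidx + 1 is always odd. For example, with blockDim.x = 4 and inner_dense_size = 7, threads 0, 1, and 2 handle the pairs (0, 1), (2, 3), and (4, 5), while thread 3 skips the loop body entirely (iidx stays 3) and writes element 6 in the tail branch.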
