Skip to content

Commit

Permalink
add long type for jagged op. (pytorch#1214)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1214

as title.

Reviewed By: jiaqizhai, brad-mengchi, mjanderson09

Differential Revision: D37978119

fbshipit-source-id: c2c004dfb2e1483f6fbf6a415c9dd58d95599cb0
  • Loading branch information
YazhiGao authored and facebook-github-bot committed Jul 22, 2022
1 parent 0551586 commit 49061a2
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -667,8 +667,12 @@ class DenseToJaggedGPUOp
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(dense.get_device());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
values.scalar_type(), "jagged_dense_add_forward", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::Long,
values.scalar_type(),
"jagged_dense_add_forward",
[&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
values,
offsets,
Expand Down

0 comments on commit 49061a2

Please sign in to comment.