Migrate array_jagged_bmm_jagged_out SLL op to OSS (pytorch#3456)
Summary:
X-link: facebookresearch/FBGEMM#540

Pull Request resolved: pytorch#3456

- Migrate `array_jagged_bmm_jagged_out` SLL op to OSS

Reviewed By: brad-mengchi

Differential Revision: D66790979

fbshipit-source-id: 51b197322573f82738810177d20ec8aa1cbdd658
q10 authored and facebook-github-bot committed Dec 6, 2024
1 parent 03129ae commit 837b14f
Showing 8 changed files with 1,141 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/scripts/fbgemm_gpu_test.bash
@@ -84,6 +84,7 @@ __configure_fbgemm_gpu_test_cpu () {
./uvm/copy_test.py
./uvm/uvm_test.py
./sll/triton_sll_test.py
./sll/array_jagged_bmm_jagged_out_test.py
)
}

35 changes: 34 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -12,6 +12,7 @@
import torch

from fbgemm_gpu.sll.cpu_sll import ( # noqa F401
cpu_array_jagged_bmm_jagged_out,
cpu_dense_jagged_cat_jagged_out,
cpu_jagged2_softmax,
cpu_jagged2_to_padded_dense,
@@ -24,9 +25,13 @@
meta_jagged_self_substraction_jagged_out,
)

from fbgemm_gpu.sll.meta_sll import meta_jagged2_softmax # noqa F401
from fbgemm_gpu.sll.meta_sll import ( # noqa F401
meta_array_jagged_bmm_jagged_out,
meta_jagged2_softmax,
)

from fbgemm_gpu.sll.triton_sll import ( # noqa F401
array_jagged_bmm_jagged_out,
dense_jagged_cat_jagged_out,
jagged2_softmax,
jagged2_to_padded_dense,
@@ -175,6 +180,23 @@ def register_sll_op(op_name: str, functors: Dict[str, Callable]) -> None:
"""
)

if "fbgemm::array_jagged_bmm_jagged_out" not in torch.library._defs:
lib.define(
"""array_jagged_bmm_jagged_out(
Tensor x,
Tensor y,
Tensor x_lengths,
Tensor x_offsets,
Tensor y_lengths,
Tensor y_offsets,
Tensor z_lengths,
Tensor z_offsets,
int max_seq_len,
bool allow_tf32
) -> Tensor
"""
)

# NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same function.
# This is not ideal: in the inference case the autograd forward does not need to save the
# context, since no backward pass will be run.
@@ -256,3 +278,14 @@ def register_sll_op(op_name: str, functors: Dict[str, Callable]) -> None:
"AutogradMeta": meta_jagged2_softmax,
},
)

register_sll_op(
"array_jagged_bmm_jagged_out",
{
"CUDA": array_jagged_bmm_jagged_out,
"AutogradCUDA": array_jagged_bmm_jagged_out,
"CPU": cpu_array_jagged_bmm_jagged_out,
"AutogradCPU": cpu_array_jagged_bmm_jagged_out,
"AutogradMeta": meta_array_jagged_bmm_jagged_out,
},
)
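
The hunks above define the `fbgemm::array_jagged_bmm_jagged_out` schema and register CPU, CUDA, and Meta kernels for it. Below is a minimal sketch of how the registered op could be invoked through torch.ops, assuming an fbgemm_gpu build that includes this change is importable; the batch lengths, feature dimension, and tensor values are illustrative assumptions, not values taken from this commit.

import torch
import fbgemm_gpu.sll  # noqa: F401  # importing runs the register_sll_op calls above

# Assume a batch of 2 sequences with lengths N = [2, 3].
lengths = torch.tensor([2, 3])
max_seq_len = 3

# x packs one flattened N_i x N_i matrix per sequence: sum_B(N_i * N_i) values.
x_lengths = lengths * lengths
x_offsets = torch.cat([torch.zeros(1, dtype=torch.int64), x_lengths.cumsum(0)])
x = torch.randn(int(x_lengths.sum()))

# y and z are jagged [sum_B(N_i), D] tensors sharing the same lengths/offsets.
D = 4
y_lengths = z_lengths = lengths
y_offsets = z_offsets = torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(0)])
y = torch.randn(int(lengths.sum()), D)

z = torch.ops.fbgemm.array_jagged_bmm_jagged_out(
    x, y,
    x_lengths, x_offsets,
    y_lengths, y_offsets,
    z_lengths, z_offsets,
    max_seq_len,
    False,  # allow_tf32
)
print(z.shape)  # torch.Size([5, 4])
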
216 changes: 216 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py
@@ -548,3 +548,219 @@ def cpu_jagged2_softmax(
max_seq_len,
transpose,
)


# pyre-fixme[3]: Return type must be annotated.
def cpu_jagged_jagged_bmm_jagged_out_kernel(
# pyre-fixme[2]: Parameter must be annotated.
jagged_A,
# pyre-fixme[2]: Parameter must be annotated.
jagged_B,
# pyre-fixme[2]: Parameter must be annotated.
max_seq_len,
# pyre-fixme[2]: Parameter must be annotated.
lengths_m,
# pyre-fixme[2]: Parameter must be annotated.
lengths_n,
# pyre-fixme[2]: Parameter must be annotated.
lengths_mn,
# pyre-fixme[2]: Parameter must be annotated.
offsets_m,
# pyre-fixme[2]: Parameter must be annotated.
offsets_n,
# pyre-fixme[2]: Parameter must be annotated.
offsets_mn,
# pyre-fixme[2]: Parameter must be annotated.
allow_tf32=False,
):
jagged_C = torch.empty((int(lengths_mn.sum().item())), dtype=jagged_A.dtype).to(
jagged_A.device
)
B = lengths_m.size(0)

for i in range(B):
jagged_C[offsets_mn[i] : offsets_mn[i + 1]] = torch.matmul(
jagged_A[offsets_m[i] : offsets_m[i + 1]],
jagged_B[offsets_n[i] : offsets_n[i + 1]].T,
).flatten()
return jagged_C


# pyre-fixme[3]: Return type must be annotated.
def cpu_array_jagged_bmm_jagged_out_kernel(
# pyre-fixme[2]: Parameter must be annotated.
array_A,
# pyre-fixme[2]: Parameter must be annotated.
jagged_B,
# pyre-fixme[2]: Parameter must be annotated.
lengths_am,
# pyre-fixme[2]: Parameter must be annotated.
lengths_bk,
# pyre-fixme[2]: Parameter must be annotated.
lengths_cm,
# pyre-fixme[2]: Parameter must be annotated.
offsets_am,
# pyre-fixme[2]: Parameter must be annotated.
offsets_bk,
# pyre-fixme[2]: Parameter must be annotated.
offsets_cm,
# pyre-fixme[2]: Parameter must be annotated.
max_seq_len,
# pyre-fixme[2]: Parameter must be annotated.
allow_tf32=False,
# pyre-fixme[2]: Parameter must be annotated.
transpose=0,  # 1 if A is transposed, otherwise 0
):
B = lengths_am.size(0)
D = jagged_B.size(1)
jagged_C = torch.zeros(
(int(lengths_cm.sum()), D), device=jagged_B.device, dtype=jagged_B.dtype
)

for i in range(B):
seq_len = int(lengths_bk[i])
capped_seq_len = min(seq_len, max_seq_len)
a = array_A[offsets_am[i] : offsets_am[i + 1]].view(seq_len, seq_len)
a = a[:capped_seq_len, :capped_seq_len]

if transpose:
a = a.T
b = jagged_B[offsets_bk[i] : offsets_bk[i] + capped_seq_len]
jagged_C[offsets_cm[i] : offsets_cm[i] + capped_seq_len] = torch.matmul(a, b)

return jagged_C


class ArrayJaggedBmmNopaddingCPU(torch.autograd.Function):
"""
Compute batch matrix multiplication between an array tensor and a jagged tensor, without padding.
z = X * Y
x: [sum_B(N_i * N_i)], one flattened N_i x N_i matrix per batch
y: [sum_B(N_i), D]
z: [sum_B(N_i), D]
"""

@staticmethod
# pyre-fixme
def forward(
# pyre-fixme[2]: Parameter must be annotated.
ctx,
x: torch.Tensor,
y: torch.Tensor,
x_lengths: torch.Tensor,
x_offsets: torch.Tensor,
y_lengths: torch.Tensor,
y_offsets: torch.Tensor,
z_lengths: torch.Tensor,
z_offsets: torch.Tensor,
max_seq_len: int,
# pyre-fixme[2]: Parameter must be annotated.
allow_tf32,
):
ctx.allow_tf32 = allow_tf32
ctx.max_seq_len = max_seq_len

ctx.save_for_backward(
x,
y,
x_lengths,
y_lengths,
z_lengths,
x_offsets,
y_offsets,
z_offsets,
)

return cpu_array_jagged_bmm_jagged_out_kernel(
x,
y,
x_lengths,
y_lengths,
z_lengths,
x_offsets,
y_offsets,
z_offsets,
max_seq_len,
allow_tf32,
0,
)

@staticmethod
# pyre-fixme
def backward(ctx, grad_output: torch.Tensor):
"""
z = X * Y
dX = dZ * Y^T
dY = X^T * dZ
dZ: [sum_B(N_i), D]
Y^T: [D, sum_B(N_i)], the transpose of Y
X^T: each N_i x N_i block of X transposed
Z: [sum_B(N_i), D]
"""

(
x,
y,
x_lengths,
y_lengths,
z_lengths,
x_offsets,
y_offsets,
z_offsets,
) = ctx.saved_tensors

grad_x = cpu_jagged_jagged_bmm_jagged_out_kernel(
grad_output,
y,
ctx.max_seq_len,
z_lengths,
y_lengths,
x_lengths,
z_offsets,
y_offsets,
x_offsets,
ctx.allow_tf32,
)

grad_y = cpu_array_jagged_bmm_jagged_out_kernel(
x,
grad_output,
x_lengths,
y_lengths,
z_lengths,
x_offsets,
y_offsets,
z_offsets,
ctx.max_seq_len,
ctx.allow_tf32,
1,
)
return grad_x, grad_y, None, None, None, None, None, None, None, None


# pyre-fixme[3]: Return type must be annotated.
def cpu_array_jagged_bmm_jagged_out(
x: torch.Tensor,
y: torch.Tensor,
x_lengths: torch.Tensor,
x_offsets: torch.Tensor,
y_lengths: torch.Tensor,
y_offsets: torch.Tensor,
z_lengths: torch.Tensor,
z_offsets: torch.Tensor,
max_seq_len: int,
allow_tf32: bool = True,
):
return ArrayJaggedBmmNopaddingCPU.apply(
x,
y,
x_lengths,
x_offsets,
y_lengths,
y_offsets,
z_lengths,
z_offsets,
max_seq_len,
allow_tf32,
)
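
Since the CPU reference wraps the kernels in an autograd.Function, the new path can be exercised end to end from Python. The following is a small sketch under assumed tiny shapes: it checks the forward result against a per-batch dense matmul and runs a backward pass through ArrayJaggedBmmNopaddingCPU.

import torch
from fbgemm_gpu.sll.cpu_sll import cpu_array_jagged_bmm_jagged_out

lengths = torch.tensor([2, 3])          # assumed per-batch sequence lengths N_i
D, max_seq_len = 4, 3

x_lengths = lengths * lengths           # one flattened N_i x N_i block per batch
x_offsets = torch.cat([torch.zeros(1, dtype=torch.int64), x_lengths.cumsum(0)])
y_lengths = z_lengths = lengths
y_offsets = z_offsets = torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(0)])

x = torch.randn(int(x_lengths.sum()), requires_grad=True)
y = torch.randn(int(lengths.sum()), D, requires_grad=True)

z = cpu_array_jagged_bmm_jagged_out(
    x, y, x_lengths, x_offsets, y_lengths, y_offsets,
    z_lengths, z_offsets, max_seq_len, allow_tf32=False,
)

# Forward check: each output block equals the dense product A_i @ Y_i.
for i in range(lengths.numel()):
    n = int(lengths[i])
    a_i = x[x_offsets[i] : x_offsets[i + 1]].view(n, n)
    y_i = y[y_offsets[i] : y_offsets[i + 1]]
    torch.testing.assert_close(z[z_offsets[i] : z_offsets[i + 1]], a_i @ y_i)

# Backward pass exercises the dX / dY formulas in the class above.
z.sum().backward()
print(x.grad.shape, y.grad.shape)       # torch.Size([13]) torch.Size([5, 4])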