Skip to content

Commit

Permalink
Fix pack_segments backward when grad is non-contig (pytorch#3222)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3222

X-link: facebookresearch/FBGEMM#320

Original commit changeset: c1fe80d75fb4

Original Phabricator Diff: D61694017

from D61694017

Reviewed By: q10, brad-mengchi

Differential Revision: D63424805

fbshipit-source-id: 42e44383b48a577610f00ad6b8c2cd48bf734a2b
  • Loading branch information
spcyppt authored and facebook-github-bot committed Oct 4, 2024
1 parent 342e8d2 commit 88ef5f9
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 8 deletions.
16 changes: 10 additions & 6 deletions fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,21 @@ DLL_PUBLIC Tensor pack_segments_backward_cuda(

CUDA_DEVICE_GUARD(data);

const auto data_contig = data.expect_contiguous();

Tensor unpacked_tensor; // The output tensor

AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "unpack_segments_cuda", [&] {
const auto* const lengths_data = lengths.data_ptr<index_t>();

// Create output tensor of appropriate dimensions
auto shape = data.sizes().vec();
auto shape = data_contig->sizes().vec();
shape.erase(shape.begin());
shape[0] = total_length;
unpacked_tensor = at::empty(shape, data.options());
unpacked_tensor = at::empty(shape, data_contig->options());

if (!(data.size(0) && data.size(1))) { // TODO: What does this mean?
if (!(data_contig->size(0) &&
data_contig->size(1))) { // TODO: What does this mean?
return;
}

Expand All @@ -82,10 +85,11 @@ DLL_PUBLIC Tensor pack_segments_backward_cuda(
auto lps_data = lengths_prefix_sum.data_ptr<index_t>();

FBGEMM_DISPATCH_ALL_TYPES(
data.scalar_type(), "unpack_segments_cuda-unpacking", [&] {
data_contig->scalar_type(), "unpack_segments_cuda-unpacking", [&] {
const auto num_seq = lengths.size(0);
const auto cell_size = data.numel() / (data.size(0) * data.size(1));
const auto* const data_ptr = data.data_ptr<scalar_t>();
const auto cell_size = data_contig->numel() /
(data_contig->size(0) * data_contig->size(1));
const auto* const data_ptr = data_contig->data_ptr<scalar_t>();
auto* const out_data = unpacked_tensor.data_ptr<scalar_t>();

unpack_segments_cuda_kernel<index_t, scalar_t>
Expand Down
10 changes: 10 additions & 0 deletions fbgemm_gpu/test/sparse/failures_dict.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
"_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit",
"_version": 1,
"data": {
"fb::pack_segments": {
"PackedSegmentsTest.test_aot_dispatch_dynamic__test_pack_segments_noncontig": {
"comment": "",
"status": "xfail"
},
"PackedSegmentsTest.test_faketensor__test_pack_segments_noncontig": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::asynchronous_complete_cumsum": {},
"fbgemm::asynchronous_exclusive_cumsum": {},
"fbgemm::asynchronous_inclusive_cumsum": {},
Expand Down
174 changes: 172 additions & 2 deletions fbgemm_gpu/test/sparse/pack_segments_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@

if open_source:
# pyre-ignore[21]
from test_utils import gpu_available
from test_utils import gpu_available, gpu_unavailable
else:
from fbgemm_gpu.test.test_utils import gpu_available
from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable


def get_n_rand_num_summing_to_k(n: int, k: int) -> npt.NDArray:
Expand All @@ -47,6 +47,15 @@ def get_n_rand_num_summing_to_k(n: int, k: int) -> npt.NDArray:
# pyre-fixme[2]
# pyre-fixme[24]
def torch_compiled(model: Callable, **kwargs) -> Callable:
"""A helper function to apply torch.compile if python < 3.12.
Args:
model: The model to be compiled.
kwargs: The arguments to be passed to torch.compile.
Returns:
The model.
"""
if sys.version_info < (3, 12, 0):
return torch.compile(model, **kwargs)
else:
Expand All @@ -60,6 +69,17 @@ def _pack_segments_ref(
tensor: torch.Tensor,
max_length: Optional[int] = None,
) -> npt.NDArray:
"""
This function is a reference implementation of pack_segments.
Args:
lengths (Tensor): The lengths of tensor.
tensor (Tensor): The tensor to be packed.
max_length (Optional[int]): The maximum length of the packed tensor.
Returns:
The packed tensor.
"""
lengths = lengths.numpy()
sections = np.split(tensor, np.cumsum(lengths))
max_length = np.max(lengths, initial=0) if max_length is None else max_length
Expand Down Expand Up @@ -106,6 +126,22 @@ def test_pack_segments(
dtype: torch.dtype,
torch_compile: bool,
) -> None:
"""
This function tests pack_segments ops compared to the reference implementation.
Both CPU and GPU (if available) are tested.
Args:
n - The number of rows in the input tensor
k - The number of columns in the input tensor
batch_size - The number of batches of the input tensor
divisions - The number of segments to be packed
dtype - The data type
torch_compile - Whether to use torch.compile
Returns:
None
"""

input_raw = np.random.rand(batch_size, n, k)
input_data = torch.tensor(input_raw, dtype=dtype, requires_grad=True)
lengths = torch.tensor(
Expand Down Expand Up @@ -209,6 +245,23 @@ def test_pack_segments_smaller_max_len(
dtype: torch.dtype,
torch_compile: bool,
) -> None:
"""
This function tests pack_segments ops with set max_length
Both CPU and GPU (if available) are tested.
Args:
n - The number of rows in the input tensor
k - The number of columns in the input tensor
batch_size - The number of batches of the input tensor
divisions - The number of segments to be packed
max_length - The maximum length of the packed tensor
dtype - The data type
torch_compile - Whether to use torch.compile
Returns:
None
"""

input_raw = np.random.rand(batch_size, n, k)
input_data = torch.tensor(input_raw, dtype=dtype)
lengths = torch.tensor(
Expand Down Expand Up @@ -264,6 +317,20 @@ def test_pack_segments_meta_backend(
divisions: int,
dtype: torch.dtype,
) -> None:
"""
This function tests pack_segments ops with meta backend.
Args:
n - The number of rows in the input tensor
k - The number of columns in the input tensor
batch_size - The number of batches of the input tensor
divisions - The number of segments to be packed
dtype - The data type
Returns:
None
"""

input_raw = np.random.rand(batch_size, n, k)
input_data = torch.tensor(
input_raw, dtype=torch.float32, requires_grad=True
Expand All @@ -281,6 +348,109 @@ def test_pack_segments_meta_backend(
# verify forward
assert packed_tensor.size() == torch.Tensor(packed_ref).size()

@unittest.skipIf(*gpu_unavailable)
@given(
n=st.integers(2, 10),
k=st.integers(2, 10),
batch_size=st.integers(1, 30),
divisions=st.integers(1, 10),
dtype=st.sampled_from(
[
torch.float,
torch.half,
]
),
torch_compile=st.booleans(),
use_cpu=st.booleans(),
)
@settings(deadline=None)
def test_pack_segments_noncontig(
self,
n: int,
k: int,
batch_size: int,
divisions: int,
dtype: torch.dtype,
torch_compile: bool,
use_cpu: bool,
) -> None:
"""
This function tests pack_segments ops when input gradients to backward are non-contiguous.
Args:
n - The number of rows in the input tensor
k - The number of columns in the input tensor
batch_size - The number of batches of the input tensor
divisions - The number of segments to be packed
dtype - The data type
torch_compile - Whether to use torch.compile
use_cpu - Whether to use CPU or GPU
Returns:
None
"""

input_raw = np.random.rand(batch_size, n, k)
# create input
input_data_ref = torch.tensor(input_raw, dtype=dtype, requires_grad=True)
input_data = torch.tensor(input_raw, dtype=dtype, requires_grad=True).cuda()
# retain grad to compare gradients of the inputs later
input_data.retain_grad()
input_data_ref.retain_grad()

# set lengths
lengths = torch.tensor(
get_n_rand_num_summing_to_k(divisions, batch_size),
dtype=torch.int,
)
max_length = lengths.max().item()

packed_ref = torch.ops.fbgemm.pack_segments(
t_in=input_data_ref, lengths=lengths, max_length=max_length
)
packed_ref.retain_grad()

# pack segments using fbgemm and fb
packed_tensor = torch.ops.fbgemm.pack_segments(
t_in=input_data, lengths=lengths.cuda(), max_length=max_length
)
packed_tensor.retain_grad()

# verify forward
self.assertTrue(torch.equal(packed_tensor.cpu(), packed_ref))

# create non-contiguous grad
shape = tuple(x * 2 for x in packed_ref.shape)
grads = torch.tensor(
np.random.uniform(low=0.01, high=0.5, size=shape).astype(np.float32)
).to(dtype)
grad_noncontig_cpu = grads.as_strided(packed_ref.shape, grads.stride())
grad_noncontig_cuda = grads.cuda().as_strided(packed_ref.shape, grads.stride())

self.assertTrue(
not (
grad_noncontig_cpu.is_contiguous()
and grad_noncontig_cuda.is_contiguous()
),
msg="Expected grads to be non-contiguous but they are contiguous",
)

# verify backward
packed_ref.backward(grad_noncontig_cpu)
packed_tensor.backward(grad_noncontig_cuda)
self.assertTrue(
torch.equal(packed_tensor.cpu(), packed_ref),
msg="Expected packed tensors to be equal but they are not",
)

# verify backward input gradients
self.assertTrue(
# pyre-fixme[16]: Optional type has no attribute `cpu`.
# pyre-fixme[6]: For 2nd param expected `Tensor` but got `Optional[Tensor]`.
torch.equal(input_data.grad.cpu(), input_data_ref.grad.cpu()),
msg="Expected input gradients to be equal but they are not",
)


extend_test_class(PackedSegmentsTest)

Expand Down

0 comments on commit 88ef5f9

Please sign in to comment.