diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_operators.py b/fbgemm_gpu/fbgemm_gpu/sparse_operators.py index c561857335..091792ac19 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_operators.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_operators.py @@ -45,3 +45,29 @@ def permute_2D_sparse_data_meta( # pyre-fixme permuted_weights = weights.new_empty(permuted_indices_size) return permuted_lengths, permuted_indices, permuted_weights + + +@torch.library.impl_abstract("fbgemm::permute_1D_sparse_data") +def permute_1D_sparse_data_meta( + permute: Tensor, + lengths: Tensor, + values: Tensor, + weights: Optional[Tensor] = None, + permuted_lengths_sum: Optional[int] = None, +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + indices = values + permuted_lengths_size = permute.numel() + permuted_lengths = lengths.new_empty([permuted_lengths_size]) + permuted_indices_size = 0 + if permuted_lengths_sum is not None: + permuted_indices_size = permuted_lengths_sum + else: + ctx = torch._custom_op.impl.get_ctx() + permuted_indices_size = ctx.new_dynamic_size() + # pyre-fixme + permuted_indices = indices.new_empty(permuted_indices_size) + permuted_weights = None + if weights is not None: + # pyre-fixme + permuted_weights = weights.new_empty(permuted_indices_size) + return permuted_lengths, permuted_indices, permuted_weights diff --git a/fbgemm_gpu/test/failures_dict.json b/fbgemm_gpu/test/failures_dict.json index de0ae7d4c7..daca769e1e 100644 --- a/fbgemm_gpu/test/failures_dict.json +++ b/fbgemm_gpu/test/failures_dict.json @@ -516,18 +516,10 @@ } }, "fbgemm::permute_1D_sparse_data": { - "SparseOpsTest.test_aot_dispatch_dynamic__test_permute_indices": { - "comment": "", - "status": "xfail" - }, "SparseOpsTest.test_aot_dispatch_static__test_permute_indices": { "comment": "", "status": "xfail" }, - "SparseOpsTest.test_faketensor__test_permute_indices": { - "comment": "", - "status": "xfail" - }, "SparseOpsTest.test_schema__test_permute_indices": { "comment": "flaky", "status": "skip"