Skip to content

Commit

Permalink
add pt2_compliant tag to some ops (pytorch#2119)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#113201

Pull Request resolved: pytorch#2119

Logs show these ops are being used with PT2, so we are grandfathering in these
ops to the pt2_compliant tag. Most of these ops are tested, some aren't.

bypass-github-export-checks

Reviewed By: ezyang

Differential Revision: D51076460

fbshipit-source-id: b08efb10fef0a0437a6c09cf0ac7f374f3b308ab
  • Loading branch information
zou3519 authored and facebook-github-bot committed Nov 10, 2023
1 parent 293e500 commit 2117dd3
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 30 deletions.
7 changes: 7 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,10 @@
#define FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16_CASE(__VA_ARGS__))

// We can cleanup the following once fbgemm uses PyTorch 2.2 in January 2024.
#ifdef HAS_PT2_COMPLIANT_TAG
#define PT2_COMPLIANT_TAG at::Tag::pt2_compliant_tag
#else
#define PT2_COMPLIANT_TAG
#endif
60 changes: 40 additions & 20 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1574,13 +1574,17 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
// details at https://pytorch.org/get-started/pytorch-2.0/#dynamic-shapes. If
// you find it doesn't compile, please pull the new PyTorch 2.0 code
m.def(
"dense_to_jagged(Tensor dense, Tensor[] x_offsets, SymInt? total_L=None) -> (Tensor, Tensor[])");
"dense_to_jagged(Tensor dense, Tensor[] x_offsets, SymInt? total_L=None) -> (Tensor, Tensor[])",
{PT2_COMPLIANT_TAG});
m.def(
"dense_to_jagged_forward(Tensor dense, Tensor[] x_offsets, SymInt? total_L=None) -> Tensor");
"dense_to_jagged_forward(Tensor dense, Tensor[] x_offsets, SymInt? total_L=None) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_2d_to_dense(Tensor values, Tensor offsets, SymInt max_sequence_length) -> Tensor");
"jagged_2d_to_dense(Tensor values, Tensor offsets, SymInt max_sequence_length) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_1d_to_dense(Tensor values, Tensor offsets, SymInt max_sequence_length, int padding_value) -> Tensor");
"jagged_1d_to_dense(Tensor values, Tensor offsets, SymInt max_sequence_length, int padding_value) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"stacked_jagged_2d_to_dense_forward(Tensor values, Tensor lengths, int[] offset_per_key, int[] max_lengths_per_key, int padding_value = 0) -> (Tensor[], Tensor[])");
m.def(
Expand All @@ -1590,35 +1594,48 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"stacked_jagged_2d_to_dense(Tensor values, Tensor lengths, int[] offset_per_key, int[] max_lengths_per_key, int padding_value = 0) -> Tensor[]");
m.def(
"jagged_to_padded_dense(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor");
"jagged_to_padded_dense(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor");
"jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_to_padded_dense_backward(Tensor grad_output, Tensor[] offsets, SymInt total_L) -> Tensor");
"jagged_to_padded_dense_backward(Tensor grad_output, Tensor[] offsets, SymInt total_L) -> Tensor",
{PT2_COMPLIANT_TAG});
// jagged + dense -> dense
m.def(
"jagged_dense_elementwise_add(Tensor x_values, Tensor[] x_offsets, Tensor y) -> Tensor");
"jagged_dense_elementwise_add(Tensor x_values, Tensor[] x_offsets, Tensor y) -> Tensor",
{PT2_COMPLIANT_TAG});
// jagged + dense -> jagged (treat "zeros" in the jagged tensor as unknowns.
// output offsets is same as x_offsets)
m.def(
"jagged_dense_elementwise_add_jagged_output(Tensor x_values, Tensor[] x_offsets, Tensor y) -> (Tensor, Tensor[])");
"jagged_dense_elementwise_add_jagged_output(Tensor x_values, Tensor[] x_offsets, Tensor y) -> (Tensor, Tensor[])",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_dense_dense_elementwise_add_jagged_output_forward(Tensor x_values, Tensor[] x_offsets, Tensor y_0, Tensor y_1) -> Tensor");
"jagged_dense_dense_elementwise_add_jagged_output_forward(Tensor x_values, Tensor[] x_offsets, Tensor y_0, Tensor y_1) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_dense_dense_elementwise_add_jagged_output(Tensor x_values, Tensor[] x_offsets, Tensor y_0, Tensor y_1) -> (Tensor, Tensor[])");
"jagged_dense_dense_elementwise_add_jagged_output(Tensor x_values, Tensor[] x_offsets, Tensor y_0, Tensor y_1) -> (Tensor, Tensor[])",
{PT2_COMPLIANT_TAG});
// jagged * dense -> jagged (its offsets is same as x_offsets)
m.def(
"jagged_dense_elementwise_mul(Tensor x_values, Tensor[] x_offsets, Tensor y) -> (Tensor, Tensor[])");
"jagged_dense_elementwise_mul(Tensor x_values, Tensor[] x_offsets, Tensor y) -> (Tensor, Tensor[])",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_dense_elementwise_mul_forward(Tensor x_values, Tensor[] x_offsets, Tensor y) -> Tensor");
"jagged_dense_elementwise_mul_forward(Tensor x_values, Tensor[] x_offsets, Tensor y) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_dense_elementwise_mul_backward(Tensor grad_output, Tensor[] x_offsets, Tensor y, Tensor x_values) -> (Tensor, Tensor)");
"jagged_dense_elementwise_mul_backward(Tensor grad_output, Tensor[] x_offsets, Tensor y, Tensor x_values) -> (Tensor, Tensor)",
{PT2_COMPLIANT_TAG});
m.def(
"batched_dense_vec_jagged_2d_mul(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor");
"batched_dense_vec_jagged_2d_mul(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"batched_dense_vec_jagged_2d_mul_forward(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor");
"batched_dense_vec_jagged_2d_mul_forward(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"batched_dense_vec_jagged_2d_mul_backward(Tensor grad_output, Tensor v, Tensor a_values, Tensor a_offsets) -> (Tensor, Tensor)");
"batched_dense_vec_jagged_2d_mul_backward(Tensor grad_output, Tensor v, Tensor a_values, Tensor a_offsets) -> (Tensor, Tensor)",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_index_select(Tensor values, Tensor lengths, Tensor indices) -> Tensor[]");
m.def(
Expand All @@ -1630,17 +1647,20 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"masked_select_jagged_1d(Tensor values, Tensor lengths, Tensor mask) -> (Tensor, Tensor)");
m.def(
"jagged_softmax(Tensor values, Tensor x_offsets, int max_L) -> (Tensor, Tensor)");
"jagged_softmax(Tensor values, Tensor x_offsets, int max_L) -> (Tensor, Tensor)",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_softmax_forward(Tensor values, Tensor x_offsets, int max_L) -> Tensor");
m.def(
"jagged_softmax_backward(Tensor grad_output, Tensor output, Tensor x_offsets, int max_L) -> Tensor");
m.def(
"jagged_jagged_bmm(Tensor x_values, Tensor y_values, Tensor x_offsets, int max_L) -> Tensor");
"jagged_jagged_bmm(Tensor x_values, Tensor y_values, Tensor x_offsets, int max_L) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_jagged_bmm_forward(Tensor x_values, Tensor y_values, Tensor x_offsets, int max_L) -> Tensor");
m.def(
"jagged_dense_bmm(Tensor x_values, Tensor x_offsets, Tensor y, int max_L) -> (Tensor, Tensor)");
"jagged_dense_bmm(Tensor x_values, Tensor x_offsets, Tensor y, int max_L) -> (Tensor, Tensor)",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_dense_bmm_forward(Tensor x_values, Tensor x_offsets, Tensor y, int max_L) -> Tensor");
// jagged -> jagged
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/TensorOptions.h>
#include <torch/library.h>
#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/ops_utils.h"

using Tensor = at::Tensor;
Expand Down Expand Up @@ -56,7 +57,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
#endif
m.def(
"merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor");
"merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"all_to_one_device(Tensor[] input_tensors, Device target_device) -> Tensor[]");
m.def(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <c10/util/irange.h>
#include <vector>
#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/permute_pooled_embedding_ops.h"

using Tensor = at::Tensor;
Expand Down Expand Up @@ -149,7 +150,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"permute_pooled_embs(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
m.def(
"permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
"permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"permute_duplicate_pooled_embs(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
m.def(
Expand Down
3 changes: 2 additions & 1 deletion fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("FloatOrHalfToFused8BitRowwiseQuantized(Tensor t) -> Tensor");
m.def("Fused8BitRowwiseQuantizedToFloat(Tensor input) -> Tensor");
m.def(
"FP8RowwiseQuantizedToFloat(Tensor input, bool forward, int output_dtype=0) -> Tensor");
"FP8RowwiseQuantizedToFloat(Tensor input, bool forward, int output_dtype=0) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def("Fused8BitRowwiseQuantizedToHalf(Tensor input) -> Tensor");
m.def(
"Fused8BitRowwiseQuantizedToFloatOrHalf(Tensor input, int output_dtype=0) -> Tensor");
Expand Down
19 changes: 14 additions & 5 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/autograd/custom_function.h>
#include "c10/util/MaybeOwned.h"
#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

Expand Down Expand Up @@ -2704,9 +2705,15 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
m.def(
"bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)");
m.def("asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor");
m.def("asynchronous_inclusive_cumsum(Tensor t_in) -> Tensor");
m.def("asynchronous_complete_cumsum(Tensor t_in) -> Tensor");
m.def(
"asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"asynchronous_inclusive_cumsum(Tensor t_in) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"asynchronous_complete_cumsum(Tensor t_in) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"reorder_batched_ad_lengths(Tensor cat_ad_lengths, Tensor batch_offsets, SymInt num_ads_in_batch, bool broadcast_lengths=False) -> Tensor");
m.def(
Expand All @@ -2715,7 +2722,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"cat_reorder_batched_ad_indices(Tensor cat_ad_offsets, Tensor[] cat_ad_indices, Tensor reordered_cat_ad_offsets, Tensor batch_offsets, SymInt num_ads_in_batch, bool broadcast_indices, SymInt total_num_indices, bool pinned_memory=False) -> Tensor");
m.def("offsets_range(Tensor offsets, SymInt range_size) -> Tensor");
m.def(
"batched_unary_embeddings(Tensor weight, Tensor table_offsets, Tensor offsets, Tensor indices) -> Tensor");
"batched_unary_embeddings(Tensor weight, Tensor table_offsets, Tensor offsets, Tensor indices) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"histogram_binning_calibration(Tensor logit, Tensor bin_num_examples, Tensor bin_num_positives, float positive_weight, float lower_bound, float upper_bound, SymInt bin_ctr_in_use_after, float bin_ctr_weight_value) -> (Tensor, Tensor)");
m.def(
Expand All @@ -2738,7 +2746,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"permute_sequence_embeddings(Tensor permute, Tensor lengths, Tensor embeddings) -> (Tensor, Tensor)");
m.def(
"pack_segments(Tensor t_in, Tensor lengths, SymInt max_length) -> Tensor");
"pack_segments(Tensor t_in, Tensor lengths, SymInt max_length) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"pack_segments_backward(Tensor data, Tensor lengths, SymInt total_length, SymInt max_length) -> Tensor");
// A specialization of at::index_select for selecting dim 0
Expand Down
24 changes: 22 additions & 2 deletions fbgemm_gpu/test/jagged_tensor_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import itertools
import random
import unittest
from typing import List, Tuple
from typing import Callable, Dict, List, Tuple

import hypothesis.strategies as st
import numpy as np
Expand Down Expand Up @@ -127,7 +127,27 @@ def hash_size_cumsum_to_offsets(hash_size_cum_sum_list: List[int]) -> List[int]:
return hash_size_offsets_list


@optests.generate_opcheck_tests
# e.g. "test_faketensor__test_cumsum": [unittest.expectedFailure]
# Please avoid putting tests here, you should put operator-specific
# skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json
# pyre-ignore[24]: Generic type `Callable` expects 2 type parameters.
additional_decorators: Dict[str, List[Callable]] = {
"test_pt2_compliant_tag_fbgemm_dense_to_jagged": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
],
"test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
],
"test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add_jagged_output": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
],
}


@optests.generate_opcheck_tests(additional_decorators=additional_decorators)
class JaggedTensorOpsTest(unittest.TestCase):
def setUp(self) -> None:
if symint_vector_unsupported()[0]:
Expand Down
12 changes: 12 additions & 0 deletions fbgemm_gpu/test/sparse_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2417,6 +2417,18 @@ def validate(
"test_faketensor__test_index_select_dim0": [unittest.skip("hangs")],
"test_autograd_registration__test_index_select_dim0": [unittest.skip("hangs")],
"test_schema__test_index_select_dim0": [unittest.skip("hangs")],
"test_pt2_compliant_tag_fbgemm_dense_to_jagged": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
],
"test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
],
"test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add_jagged_output": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
],
}

# only generate tests on nightly pytorch (current release version is 2.1)
Expand Down

0 comments on commit 2117dd3

Please sign in to comment.