diff --git a/BUILD.bazel b/BUILD.bazel
index 5468632ff119a3..eb77859fce7720 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -73,6 +73,7 @@ generated_cpu_cpp = [
     "aten/src/ATen/NativeMetaFunctions.h",
     "aten/src/ATen/RegistrationDeclarations.h",
     "aten/src/ATen/core/aten_interned_strings.h",
+    "aten/src/ATen/core/enum_tag.h",
     "aten/src/ATen/core/TensorBody.h",
     "aten/src/ATen/core/TensorMethods.cpp",
     "aten/src/ATen/core/ATenOpList.cpp",
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp
index 8d6447bffdfaca..3385bff63a11ae 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.cpp
+++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp
@@ -147,7 +147,7 @@ void Dispatcher::deregisterLibrary_(const std::string& ns) {
   libraries_.erase(ns);
 }
 
-RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug) {
+RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags) {
   // we need a lock to avoid concurrent writes
   std::lock_guard<std::mutex> lock(mutex_);
 
@@ -157,7 +157,7 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin
   TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
               " Each overload's schema should only be registered with a single call to def().",
               " Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());
-  op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug));
+  op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug), tags);
   listeners_->callOnOperatorRegistered(op);
 
   // NB: do not increment the counts until AFTER error checking
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index 7f8e1532ae0b81..3ab0619c05ae90 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -14,6 +14,7 @@
 #include
 #include
+#include <ATen/core/enum_tag.h>
 
 namespace c10 {
 
@@ -177,7 +178,7 @@ class TORCH_API Dispatcher final {
    * If a schema with the same operator name and overload name already exists,
    * this function will check that both schemas are exactly identical.
    */
-  RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug);
+  RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags = {});
 
   /**
    * Register a kernel to the dispatch table for an operator.
@@ -338,6 +339,19 @@ class TORCH_API OperatorHandle {
     return operatorDef_->op.checkInvariants();
   }
 
+  c10::ArrayRef<at::Tag> getTags() const {
+    return operatorDef_->op.getTags();
+  }
+
+  bool hasTag(const at::Tag& tag) const {
+    for(const auto& tag_: getTags()) {
+      if (tag == tag_) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   template<class FuncType>
   TypedOperatorHandle<FuncType> typed() const {
     // NB: This assert is not 100% sound: you can retrieve a typed() operator
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
index d5cc6d45933fa2..5bbb391b2fcfe4 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -19,6 +19,7 @@ namespace {
 OperatorEntry::OperatorEntry(OperatorName&& operator_name)
 : name_(std::move(operator_name))
 , schema_()
+, tags_()
 , dispatchTable_()
 , dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
 , kernels_()
@@ -57,7 +58,7 @@ const AnnotatedKernel& OperatorEntry::ambiguousAutogradOtherKernel() const {
   return kernel;
 }
 
-void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug) {
+void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug, std::vector<at::Tag> tags) {
   TORCH_INTERNAL_ASSERT(!schema_.has_value());
   for (const auto& kernel : kernels_) {
     for (const auto &j : kernel.second) {
@@ -69,6 +70,7 @@ void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug)
   // NB: don't register schema until after we've checked everything!
   dispatchKeyExtractor_.registerSchema(schema);
   schema_ = AnnotatedSchema(std::move(schema), std::move(debug));
+  tags_ = std::move(tags);
 }
 
 void OperatorEntry::deregisterSchema() {
@@ -208,6 +210,10 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat
   return nullptr;
 }
 
+const std::vector<at::Tag>& OperatorEntry::getTags() const {
+  return tags_;
+}
+
 std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTableEntryWithDebug(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
   // [Note] DispatchTable computation
   // dispatchTable contains entries for runtime dispatch keys.
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h
index c0f90808280a8e..dc24d72f32ce97 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.h
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.h
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include <ATen/core/enum_tag.h>
 
 #include
 #include
 
@@ -98,7 +99,7 @@ class TORCH_API OperatorEntry final {
   // attempt to register a schema when one is already present or vice
   // versa that is an error.  (Refcounting for the registrations is
   // handled in the OperatorHandle in Dispatcher)
-  void registerSchema(FunctionSchema&&, std::string&& debug);
+  void registerSchema(FunctionSchema&&, std::string&& debug, std::vector<at::Tag> tags = {});
   void deregisterSchema();
 
   const OperatorName& operator_name() const {
@@ -205,12 +206,14 @@ class TORCH_API OperatorEntry final {
   bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
   // Returns true if kernel_ has entry for a particular key.
   bool hasKernelForDispatchKey(DispatchKey k) const;
+  // Returns all the operator tags added at the time of registration
+  const std::vector<at::Tag>& getTags() const;
 
 private:
 
   OperatorName name_;
   c10::optional<AnnotatedSchema> schema_;
-
+  std::vector<at::Tag> tags_;
   std::array dispatchTable_;
   DispatchKeyExtractor dispatchKeyExtractor_;
 
diff --git a/aten/src/ATen/core/library.cpp b/aten/src/ATen/core/library.cpp
index ba608e98ad53a8..5c9cea05ea76b8 100644
--- a/aten/src/ATen/core/library.cpp
+++ b/aten/src/ATen/core/library.cpp
@@ -89,7 +89,7 @@ Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, c
   // merge everything
 
 #define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
-Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name) & {
+Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name, const std::vector<at::Tag>& tags) & {
   TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
     DEF_PRELUDE,
     "Cannot define an operator inside of a ", toString(kind_), " block. "
@@ -128,7 +128,8 @@ Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name
   registrars_.emplace_back(
     c10::Dispatcher::singleton().registerDef(
       std::move(schema),
-      debugString(file_, line_)
+      debugString(file_, line_),
+      tags
     )
   );
   return *this;
diff --git a/aten/src/ATen/native/tags.yaml b/aten/src/ATen/native/tags.yaml
index d79b13adae84d4..403dfd301c83eb 100644
--- a/aten/src/ATen/native/tags.yaml
+++ b/aten/src/ATen/native/tags.yaml
@@ -8,3 +8,7 @@
     This tag indicates operators that are *_copy* variants
     of view/aliasing operators. If an operator has a view_copy tag,
     then it should have the name {op}_copy, where {op} is a view operator.
+- tag: generated
+  desc: |
+    This tag indicates that the operator doesn't have an explicit entry in
+    native_functions.yaml, and instead was generated automatically by the codegen.
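
For reference, here is a minimal, hypothetical sketch (not part of this patch) of how the new tags argument threads through the registration API. The `myops::mul_out` schema is made up for illustration; only `Library::def()` with a tag list, `OperatorHandle::hasTag()`, and the `generated` tag from the changes above are assumed.

// Hypothetical custom-op registration using the tags argument added above.
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>

TORCH_LIBRARY(myops, m) {
  // Library::def() forwards the tag list to Dispatcher::registerDef(),
  // which stores it on the OperatorEntry for later queries.
  m.def("mul_out(Tensor self, Tensor other, Tensor(a!) out) -> Tensor(a!)",
        {at::Tag::generated});
}

bool mul_out_is_generated() {
  // Look up the registered operator and check the tag recorded at def() time.
  auto handle = c10::Dispatcher::singleton()
                    .findSchemaOrThrow("myops::mul_out", /*overload_name=*/"");
  return handle.hasTag(at::Tag::generated);
}
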
diff --git a/aten/src/ATen/templates/enum_tag.h b/aten/src/ATen/templates/enum_tag.h
new file mode 100644
index 00000000000000..1320fbc28ab8f7
--- /dev/null
+++ b/aten/src/ATen/templates/enum_tag.h
@@ -0,0 +1,10 @@
+#pragma once
+
+// ${generated_comment}
+
+namespace at {
+    // Enum of valid tags obtained from the entries in tags.yaml
+    enum class Tag {
+        ${enum_of_valid_tags}
+    };
+}
diff --git a/build_variables.bzl b/build_variables.bzl
index b78ed5512e04b4..5422fb0837e21c 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -957,6 +957,7 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"):
         "torch/csrc/autograd/generated/python_nn_functions.cpp",
         "torch/csrc/autograd/generated/python_fft_functions.cpp",
         "torch/csrc/autograd/generated/python_linalg_functions.cpp",
+        "torch/csrc/autograd/generated/python_enum_tag.cpp",
         "torch/csrc/autograd/generated/python_return_types.cpp",
         "torch/csrc/autograd/generated/python_sparse_functions.cpp",
         "torch/csrc/autograd/generated/python_special_functions.cpp",
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index e21f0f34640c71..d5a8bde9207ffb 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -401,6 +401,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     "${TORCH_SRC_DIR}/csrc/autograd/generated/python_sparse_functions.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
+    "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
     )
 
 set(GENERATED_H_PYTHON
@@ -463,6 +464,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     "${TOOLS_PATH}/autograd/templates/python_sparse_functions.cpp"
     "${TOOLS_PATH}/autograd/templates/python_special_functions.cpp"
     "${TOOLS_PATH}/autograd/templates/python_return_types.cpp"
+    "${TOOLS_PATH}/autograd/templates/python_enum_tag.cpp"
     "${TOOLS_PATH}/autograd/templates/variable_factories.h"
    "${TOOLS_PATH}/autograd/templates/annotated_fn_args.py.in"
    "${TOOLS_PATH}/autograd/deprecated.yaml"
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 78c508802d49aa..e382bb63e245bf 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -613,6 +613,10 @@ Utilities
     vmap
     _assert
 
+Operator Tags
+------------------------------------
+.. autoclass:: Tag
+    :members:
 
 .. Empty submodules added only for tracking.
 .. py:module:: torch.contrib
diff --git a/test/test_ops.py b/test/test_ops.py
index 26e2f436c1d108..0d1220eeb5bbd0 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -44,6 +44,7 @@
     instantiate_device_type_tests,
     ops,
     onlyCUDA,
+    onlyCPU,
     onlyNativeDeviceTypes,
     OpDTypes,
     skipMeta,
@@ -56,6 +57,7 @@
 from torch.testing._internal import composite_compliance
 from torch.utils._pytree import tree_flatten
+from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode
 
 # TODO: fixme https://github.com/pytorch/pytorch/issues/68972
 torch.set_default_dtype(torch.float32)
@@ -1355,6 +1357,57 @@ def is_bit_set(x):
         torch.is_complex,
     )
 
+# input strides and size may have been altered due to the result of an inplace op
+def test_inplace_view(func, input, rs, input_size, input_strides):
+    if func is None:
+        return
+    # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm.out
+    # which mutate not necessarily the first input.
+    if isinstance(rs, torch.Tensor) and rs is input:
+        unequal_size = rs.size() != input_size
+        unequal_strides = rs.stride() != input_strides
+        # resize_ should probably have inplace_view tag. Not adding the tag since it
+        # breaks some codegen logic
+        if (unequal_size or unequal_strides):
+            if isinstance(func, torch._ops.OpOverloadPacket):
+                func = func.default
+            # Reference: https://github.com/pytorch/pytorch/issues/78759
+            if func is not torch.ops.aten.resize_.default:
+                # TODO: use self.assertIn when we have separate tests for each tag
+                assert torch.Tag.inplace_view in func.tags
+
+# A mode that when enabled runs correctness checks to ensure
+# that operators have expected tags based on their input and
+# output tensor properties
+class TestTagsMode(TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if isinstance(args[0], torch.Tensor):
+            old_size = args[0].size()
+            old_stride = args[0].stride()
+            rs = func(*args, **kwargs)
+            test_inplace_view(func, args[0], rs, old_size, old_stride)
+        else:
+            rs = func(*args, **kwargs)
+        return rs
+
+# Test to verify the correctness of tags in `tags.yaml`, also available for access through `torch.Tag`
+class TestTags(TestCase):
+    @onlyCPU
+    @ops(ops_and_refs, dtypes=OpDTypes.any_one)
+    def test_tags(self, device, dtype, op):
+        samples = op.sample_inputs(device, dtype, requires_grad=False)
+        for sample in samples:
+            # TODO: Test tags for ops that return a list of tensors
+            input = sample.input
+            if isinstance(input, torch.Tensor):
+                old_size = input.size()
+                old_stride = input.stride()
+                with push_torch_dispatch_mode(TestTagsMode):
+                    rs = op(input, *sample.args, **sample.kwargs)
+                # TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
+                aten_name = op.aten_name if op.aten_name is not None else op.name
+                opoverloadpacket = getattr(torch.ops.aten, aten_name, None)
+                test_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)
+
 
 class TestRefsOpsInfo(TestCase):
 
@@ -1394,6 +1447,7 @@ def test_refs_are_in_python_ref_db(self, op):
 instantiate_device_type_tests(TestCompositeCompliance, globals())
 instantiate_device_type_tests(TestMathBits, globals())
 instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
+instantiate_device_type_tests(TestTags, globals())
 
 if __name__ == "__main__":
     run_tests()
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index 88c38a864662ec..91a9491ee8d84d 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -241,6 +241,10 @@ def test_no_new_bindings(self):
             "vitals_enabled",
 
             "wait",
+            "Tag",
+            "inplace_view",
+            "view_copy",
+            "generated"
         }
         torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 1b6d69f0d57197..11db980702ac1d 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -57,7 +57,7 @@
     namedtuple_fieldnames,
     signature,
 )
-from torchgen.gen import cpp_string, parse_native_yaml
+from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
 from torchgen.context import with_native_function
 from torchgen.model import (
     Argument,
@@ -325,6 +325,17 @@ def gen(
         fm, functions, lambda fn: True, "python_return_types.cpp"
     )
 
+    valid_tags = parse_tags_yaml(tags_yaml_path)
+
+    def gen_tags_enum() -> Dict[str, str]:
+        return {
+            "enum_of_valid_tags": (
+                "".join([f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags])
+            )
+        }
+
fm.write("python_enum_tag.cpp", gen_tags_enum) + def group_filter_overloads( pairs: Sequence[PythonSignatureNativeFunctionPair], diff --git a/tools/autograd/templates/python_enum_tag.cpp b/tools/autograd/templates/python_enum_tag.cpp new file mode 100644 index 00000000000000..cec5ffabd1c7a7 --- /dev/null +++ b/tools/autograd/templates/python_enum_tag.cpp @@ -0,0 +1,15 @@ +#include +#include +#include + +namespace py = pybind11; +namespace torch { + namespace autograd { + void initEnumTag(PyObject* module) { + auto m = py::handle(module).cast(); + py::enum_(m, "Tag") + ${enum_of_valid_tags} + .export_values(); + m.doc() = "An Enum that contains tags that can be assigned to an operator registered in C++."; + } +}} diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index cf92530e447088..f0d6ddc7c8637d 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -202,7 +202,7 @@ def _jit_init() -> _bool: ... def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ... def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ... def _jit_get_operation(op_name: str) -> Tuple[Callable, List[str]]: ... -def _get_operation_overload(op_name: str, op_overload_name: str) -> Callable: ... +def _get_operation_overload(op_name: str, op_overload_name: str) -> Tuple[Callable, List[Any]]: ... def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ... def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule', optimization_blocklist: Set[MobileOptimizerType], diff --git a/torch/_ops.py b/torch/_ops.py index 536a7b4f8e2c56..a325d93e517824 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -28,10 +28,11 @@ def dl_open_guard(): # Each OpOverload object contains pointer to a a specific operator overload, a pointer to the parent `OpOverloadPacket` object. # You can obtain an OpOverload object through attribute query on OpOverloadPacket. 
 class OpOverload:
-    def __init__(self, overloadpacket, op, schema):
+    def __init__(self, overloadpacket, op, schema, tags):
         self._op = op
         self._schema = schema
         self._overloadpacket = overloadpacket
+        self._tags = tags
         self._overloadname = 'default' if schema.overload_name == '' else schema.overload_name
         self.__name__ = "{}.{}".format(self._schema.name.split("::")[1], self._overloadname)
         self.__module__ = overloadpacket.__module__
@@ -65,6 +66,10 @@ def overloadpacket(self):
     def op(self):
         return self._op
 
+    @property
+    def tags(self):
+        return self._tags
+
     # TODO: add more methods to expose information about input and output arguments
 
 # OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
@@ -123,10 +128,10 @@ def __getattr__(self, key):
         # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
         use_key = '' if key == 'default' else key
         # TODO: disallow access to overloads registered by JIT
-        op_ = torch._C._get_operation_overload(
+        op_, tags = torch._C._get_operation_overload(
             self._qualified_op_name, use_key)
         schema = torch._C._get_schema(self._qualified_op_name, use_key)
-        overload = OpOverload(self, op_, schema)
+        overload = OpOverload(self, op_, schema, tags)
         # cache the overload object
         setattr(self, key, overload)
         return overload
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index ff2def5b3d9e60..c7e9d93601f509 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -41,6 +41,7 @@
 #include
 #include
 #include
+#include <torch/csrc/autograd/python_enum_tag.h>
 #include
 #include
 #include
@@ -828,6 +829,7 @@ PyObject* initModule() {
   // the export side of JIT, so this ONNX init needs to appear before the JIT
   // init.
   torch::onnx::initONNXBindings(module);
+  torch::autograd::initEnumTag(module);
   torch::jit::initJITBindings(module);
   torch::monitor::initMonitorBindings(module);
   torch::impl::dispatch::initDispatchBindings(module);
diff --git a/torch/csrc/autograd/python_enum_tag.h b/torch/csrc/autograd/python_enum_tag.h
new file mode 100644
index 00000000000000..7a95cd98229922
--- /dev/null
+++ b/torch/csrc/autograd/python_enum_tag.h
@@ -0,0 +1,8 @@
+#pragma once
+
+#include <torch/csrc/python_headers.h>
+
+namespace torch {
+    namespace autograd {
+        void initEnumTag(PyObject* module);
+}} // namespace torch::autograd
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 1807dfddcab586..a135d79377f222 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -1346,7 +1346,7 @@ void initJITBindings(PyObject* module) {
               return _get_operation_for_overload_or_packet(
                   {op}, symbol, args, kwargs, true);
             });
-        return func;
+        return py::make_tuple(func, py::cast(op->getTags().vec()));
       }
     }
     throw std::runtime_error("Found no matching operator overload");
diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h
index 438fcc5411bb80..c70261db52b280 100644
--- a/torch/csrc/jit/runtime/operator.h
+++ b/torch/csrc/jit/runtime/operator.h
@@ -160,6 +160,16 @@ struct TORCH_API Operator {
     });
   }
 
+  c10::ArrayRef<at::Tag> getTags() const {
+    return op_.fold<c10::ArrayRef<at::Tag>>(
+        [](const C10Operator& op) { return op.handle_.getTags(); },
+        [](const JitOnlyOperator& op) {
+          // Returns empty list of tags for JitOnlyOperators since it
+          // doesn't save c10::OperatorHandle
+          return c10::ArrayRef<at::Tag>();
+        });
+  }
+
   bool isC10Op() const {
     return op_.is_left();
   }
diff --git a/torch/library.h b/torch/library.h
index 38887740ecdf66..b4e4fa7ffa53d7 100644
--- a/torch/library.h
+++ b/torch/library.h
@@ -65,6 +65,8 @@
 // Just for inferFunctionSchemaFromFunctor
 #include
+#include <ATen/core/enum_tag.h>
+#include <vector>
 
 namespace torch {
 
@@ -594,12 +596,12 @@ class TORCH_API Library final {
   /// m.def("add(Tensor self, Tensor other) -> Tensor");
   /// }
   /// ```
+
   template <typename Schema>
-  Library& def(Schema&& raw_schema) & {
+  Library& def(Schema&& raw_schema, const std::vector<at::Tag>& tags = {}) & {
     c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
-    return _def(std::move(s));
+    return _def(std::move(s), nullptr, std::move(tags));
   }
-
   /// Define an operator for a schema and then register an implementation for
   /// it.  This is typically what you would use if you aren't planning
   /// on making use of the dispatcher to structure your operator
@@ -813,7 +815,8 @@ class TORCH_API Library final {
   // public because we only implement & qualifier and not && qualifier
   Library& _def(
       c10::FunctionSchema&& schema,
-      c10::OperatorName* out_name = nullptr) &;
+      c10::OperatorName* out_name = nullptr,
+      const std::vector<at::Tag>& tags = {}) &;
   Library& _def(
       c10::either<c10::OperatorName, c10::FunctionSchema>&&,
       CppFunction&& f) &;
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1872c86827ef63..fbaf1059d1ad9e 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -11763,6 +11763,7 @@ def error_inputs_mean(op_info, device, **kwargs):
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
                # RuntimeError: Sparse CSR tensors do not have strides.
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
                # RuntimeError: sampled_addmm: Expected result to have sparse csr layout, but got Strided
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),
                # RuntimeError: Sparse CSR tensors do not have strides
@@ -13876,8 +13877,10 @@ def error_inputs_mean(op_info, device, **kwargs):
                    DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance',
                                 'test_operator', device_type='cpu'),
                    DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo',
                                 'test_nnc_correctness', dtypes=(torch.bfloat16,)),
-                   DecorateInfo(unittest.skip("Works on some conifgs"), 'TestCudaFuserOpInfo',
-                                'test_nvfuser_correctness', dtypes=(torch.bfloat16,)),
+                   # RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet.
+                   # Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data()
+                   # to actually allocate memory
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
                ),
                sample_inputs_func=sample_inputs_max_pool),
     OpInfo('nn.functional.max_pool2d',
@@ -17340,6 +17343,7 @@ def error_inputs_mean(op_info, device, **kwargs):
                # Allowed exception: sparse tensors don't have strides
                DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'),
                DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_backward'),
+               DecorateInfo(unittest.skip("Allowed exception"), 'TestTags', 'test_tags'),
                # TODO: implement csr.to_sparse(sample_dim) where sampled_dim is 1.
               DecorateInfo(unittest.skip("csr.to_sparse(1) not implemented. Skipped!"), 'TestSparseCSR', 'test_sparse_csr_consistency'),
Skipped!"), 'TestSparseCSR', 'test_sparse_csr_consistency'), diff --git a/torchgen/gen.py b/torchgen/gen.py index 5d18c07b96bcca..4e2e974160ecdd 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -152,6 +152,7 @@ def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def] _GLOBAL_PARSE_NATIVE_YAML_CACHE = {} +_GLOBAL_PARSE_TAGS_YAML_CACHE = {} # Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices. ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"]) @@ -220,11 +221,13 @@ def parse_tags_yaml_struct(es: object, path: str = "") -> Set[str]: @functools.lru_cache(maxsize=None) def parse_tags_yaml(path: str) -> Set[str]: - # TODO: parse tags.yaml and create a tags database (a dict of tag name mapping to a Tag object) - with open(path, "r") as f: - es = yaml.load(f, Loader=LineLoader) - valid_tags = parse_tags_yaml_struct(es, path=path) - return valid_tags + global _GLOBAL_PARSE_TAGS_YAML_CACHE + if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE: + with open(path, "r") as f: + es = yaml.load(f, Loader=LineLoader) + _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path) + + return _GLOBAL_PARSE_TAGS_YAML_CACHE[path] def parse_native_yaml( @@ -234,7 +237,6 @@ def parse_native_yaml( *, skip_native_fns_gen: bool = False, ) -> ParsedYaml: - # TODO: parse tags.yaml and create a tags database (a dict of tag name mapping to a Tag object) global _GLOBAL_PARSE_NATIVE_YAML_CACHE if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE: valid_tags = parse_tags_yaml(tags_yaml_path) @@ -500,7 +502,8 @@ class RegisterSchema: def __call__(self, f: NativeFunction) -> Optional[str]: if not self.selector.is_native_function_selected(f): return None - return f"m.def({cpp_string(str(f.func))});\n" + tags = "{" + ", ".join([f"at::Tag::{tag}" for tag in f.tags]) + "}" + return f"m.def({cpp_string(str(f.func))}, {tags});\n" # Generates Operators.h and Operators.cpp. @@ -1713,6 +1716,7 @@ def gen_per_operator_headers( def gen_headers( *, native_functions: Sequence[NativeFunction], + valid_tags: Set[str], grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], structured_native_functions: Sequence[NativeFunctionsGroup], static_dispatch_idx: List[BackendIndex], @@ -1840,6 +1844,11 @@ def gen_aten_interned_strings() -> Dict[str, str]: core_fm.write("aten_interned_strings.h", gen_aten_interned_strings) + def gen_tags_enum() -> Dict[str, str]: + return {"enum_of_valid_tags": (",\n".join([f"{tag}" for tag in valid_tags]))} + + core_fm.write("enum_tag.h", gen_tags_enum) + def gen_source_files( *, @@ -2396,6 +2405,7 @@ def main() -> None: del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)] parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys) + valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path] native_functions, backend_indices = ( parsed_yaml.native_functions, parsed_yaml.backend_indices, @@ -2501,6 +2511,7 @@ def main() -> None: if "headers" in options.generate: gen_headers( native_functions=native_functions, + valid_tags=valid_tags, grouped_native_functions=grouped_native_functions, structured_native_functions=structured_native_functions, static_dispatch_idx=static_dispatch_idx,