Autogen Tags enum, and allow specifying tags while defining an op
anjali411 authored and pytorchmergebot committed Jun 3, 2022
1 parent 063c936 commit 9476a78
Showing 24 changed files with 200 additions and 27 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -73,6 +73,7 @@ generated_cpu_cpp = [
"aten/src/ATen/NativeMetaFunctions.h",
"aten/src/ATen/RegistrationDeclarations.h",
"aten/src/ATen/core/aten_interned_strings.h",
"aten/src/ATen/core/enum_tag.h",
"aten/src/ATen/core/TensorBody.h",
"aten/src/ATen/core/TensorMethods.cpp",
"aten/src/ATen/core/ATenOpList.cpp",
4 changes: 2 additions & 2 deletions aten/src/ATen/core/dispatch/Dispatcher.cpp
@@ -147,7 +147,7 @@ void Dispatcher::deregisterLibrary_(const std::string& ns) {
libraries_.erase(ns);
}

RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug) {
RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags) {
// we need a lock to avoid concurrent writes
std::lock_guard<std::mutex> lock(mutex_);

@@ -157,7 +157,7 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin
TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
" Each overload's schema should only be registered with a single call to def().",
" Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());
op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug));
op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug), tags);
listeners_->callOnOperatorRegistered(op);

// NB: do not increment the counts until AFTER error checking
16 changes: 15 additions & 1 deletion aten/src/ATen/core/dispatch/Dispatcher.h
@@ -14,6 +14,7 @@
#include <type_traits>

#include <ATen/core/grad_mode.h>
#include <ATen/core/enum_tag.h>

namespace c10 {

@@ -177,7 +178,7 @@ class TORCH_API Dispatcher final {
* If a schema with the same operator name and overload name already exists,
* this function will check that both schemas are exactly identical.
*/
RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug);
RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags = {});

/**
* Register a kernel to the dispatch table for an operator.
@@ -338,6 +339,19 @@ class TORCH_API OperatorHandle {
return operatorDef_->op.checkInvariants();
}

c10::ArrayRef<at::Tag> getTags() const {
return operatorDef_->op.getTags();
}

bool hasTag(const at::Tag& tag) const {
for(const auto& tag_: getTags()) {
if (tag == tag_) {
return true;
}
}
return false;
}

template<class FuncType>
TypedOperatorHandle<FuncType> typed() const {
// NB: This assert is not 100% sound: you can retrieve a typed() operator
8 changes: 7 additions & 1 deletion aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -19,6 +19,7 @@ namespace {
OperatorEntry::OperatorEntry(OperatorName&& operator_name)
: name_(std::move(operator_name))
, schema_()
, tags_()
, dispatchTable_()
, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
, kernels_()
@@ -57,7 +58,7 @@ const AnnotatedKernel& OperatorEntry::ambiguousAutogradOtherKernel() const {
return kernel;
}

void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug) {
void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug, std::vector<at::Tag> tags) {
TORCH_INTERNAL_ASSERT(!schema_.has_value());
for (const auto& kernel : kernels_) {
for (const auto &j : kernel.second) {
@@ -69,6 +70,7 @@ void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug)
// NB: don't register schema until after we've checked everything!
dispatchKeyExtractor_.registerSchema(schema);
schema_ = AnnotatedSchema(std::move(schema), std::move(debug));
tags_ = std::move(tags);
}

void OperatorEntry::deregisterSchema() {
@@ -208,6 +210,10 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat
return nullptr;
}

const std::vector<at::Tag>& OperatorEntry::getTags() const {
return tags_;
}

std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTableEntryWithDebug(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
// [Note] DispatchTable computation
// dispatchTable contains entries for runtime dispatch keys.
7 changes: 5 additions & 2 deletions aten/src/ATen/core/dispatch/OperatorEntry.h
@@ -13,6 +13,7 @@
#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/core/dispatch/CppSignature.h>
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
#include <ATen/core/enum_tag.h>

#include <list>
#include <array>
@@ -98,7 +99,7 @@ class TORCH_API OperatorEntry final {
// attempt to register a schema when one is already present or vice
// versa that is an error. (Refcounting for the registrations is
// handled in the OperatorHandle in Dispatcher)
void registerSchema(FunctionSchema&&, std::string&& debug);
void registerSchema(FunctionSchema&&, std::string&& debug, std::vector<at::Tag> tags = {});
void deregisterSchema();

const OperatorName& operator_name() const {
@@ -205,12 +206,14 @@ class TORCH_API OperatorEntry final {
bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
// Returns true if kernel_ has entry for a particular key.
bool hasKernelForDispatchKey(DispatchKey k) const;
// Returns all the operator tags added at the time of registration
const std::vector<at::Tag>& getTags() const;

private:

OperatorName name_;
c10::optional<AnnotatedSchema> schema_;

std::vector<at::Tag> tags_;
std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
DispatchKeyExtractor dispatchKeyExtractor_;

5 changes: 3 additions & 2 deletions aten/src/ATen/core/library.cpp
@@ -89,7 +89,7 @@ Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, c
// merge everything

#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name) & {
Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name, const std::vector<at::Tag>& tags) & {
TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
DEF_PRELUDE,
"Cannot define an operator inside of a ", toString(kind_), " block. "
@@ -128,7 +128,8 @@ Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name
registrars_.emplace_back(
c10::Dispatcher::singleton().registerDef(
std::move(schema),
debugString(file_, line_)
debugString(file_, line_),
tags
)
);
return *this;
4 changes: 4 additions & 0 deletions aten/src/ATen/native/tags.yaml
@@ -8,3 +8,7 @@
This tag indicates operators that are *_copy* variants
of view/aliasing operators. If an operator has a view_copy tag,
then it should have the name {op}_copy, where {op} is a view operator.
- tag: generated
desc: |
This tag indicates that the operator doesn't have an explicit entry in
native_functions.yaml, and instead was generated automatically by the codegen.
10 changes: 10 additions & 0 deletions aten/src/ATen/templates/enum_tag.h
@@ -0,0 +1,10 @@
#pragma once

// ${generated_comment}

namespace at {
// Enum of valid tags obtained from the entries in tags.yaml
enum class Tag {
${enum_of_valid_tags}
};
}
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -957,6 +957,7 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"):
"torch/csrc/autograd/generated/python_nn_functions.cpp",
"torch/csrc/autograd/generated/python_fft_functions.cpp",
"torch/csrc/autograd/generated/python_linalg_functions.cpp",
"torch/csrc/autograd/generated/python_enum_tag.cpp",
"torch/csrc/autograd/generated/python_return_types.cpp",
"torch/csrc/autograd/generated/python_sparse_functions.cpp",
"torch/csrc/autograd/generated/python_special_functions.cpp",
2 changes: 2 additions & 0 deletions caffe2/CMakeLists.txt
@@ -401,6 +401,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_sparse_functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
)

set(GENERATED_H_PYTHON
@@ -463,6 +464,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TOOLS_PATH}/autograd/templates/python_sparse_functions.cpp"
"${TOOLS_PATH}/autograd/templates/python_special_functions.cpp"
"${TOOLS_PATH}/autograd/templates/python_return_types.cpp"
"${TOOLS_PATH}/autograd/templates/python_enum_tag.cpp"
"${TOOLS_PATH}/autograd/templates/variable_factories.h"
"${TOOLS_PATH}/autograd/templates/annotated_fn_args.py.in"
"${TOOLS_PATH}/autograd/deprecated.yaml"
4 changes: 4 additions & 0 deletions docs/source/torch.rst
@@ -613,6 +613,10 @@ Utilities
vmap
_assert

Operator Tags
------------------------------------
.. autoclass:: Tag
:members:

.. Empty submodules added only for tracking.
.. py:module:: torch.contrib
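The directive above publishes the autogenerated enum as torch.Tag. As a rough illustration of the resulting Python surface (a sketch, assuming a build that includes this commit; the member names simply mirror the entries in tags.yaml above):

import torch

# The enum members come from aten/src/ATen/native/tags.yaml:
# inplace_view, view_copy, and the new "generated" tag.
print(list(torch.Tag.__members__))  # e.g. ['inplace_view', 'view_copy', 'generated']

# Tag values are ordinary enum members, so they can be compared directly.
assert torch.Tag.generated != torch.Tag.inplace_view
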
54 changes: 54 additions & 0 deletions test/test_ops.py
@@ -44,6 +44,7 @@
instantiate_device_type_tests,
ops,
onlyCUDA,
onlyCPU,
onlyNativeDeviceTypes,
OpDTypes,
skipMeta,
@@ -56,6 +57,7 @@
from torch.testing._internal import composite_compliance

from torch.utils._pytree import tree_flatten
from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode

# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
torch.set_default_dtype(torch.float32)
@@ -1355,6 +1357,57 @@ def is_bit_set(x):
torch.is_complex,
)

# input size and strides may have been altered as a result of an inplace op
def test_inplace_view(func, input, rs, input_size, input_strides):
if func is None:
return
# TODO: extend this test to cover ops with multiple outputs and ops like native_batch_norm.out,
# which may mutate inputs other than the first.
if isinstance(rs, torch.Tensor) and rs is input:
unequal_size = rs.size() != input_size
unequal_strides = rs.stride() != input_strides
# resize_ should probably have inplace_view tag. Not adding the tag since it
# breaks some codegen logic
if (unequal_size or unequal_strides):
if isinstance(func, torch._ops.OpOverloadPacket):
func = func.default
# Reference: https://github.com/pytorch/pytorch/issues/78759
if func is not torch.ops.aten.resize_.default:
# TODO: use self.assertIn when we have separate tests for each tag
assert torch.Tag.inplace_view in func.tags

# A mode that, when enabled, runs correctness checks to ensure
# that operators have the expected tags based on their input and
# output tensor properties
class TestTagsMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if isinstance(args[0], torch.Tensor):
old_size = args[0].size()
old_stride = args[0].stride()
rs = func(*args, **kwargs)
test_inplace_view(func, args[0], rs, old_size, old_stride)
else:
rs = func(*args, **kwargs)
return rs

# Test to verify the correctness of the tags in `tags.yaml`, which are also accessible through `torch.Tag`
class TestTags(TestCase):
@onlyCPU
@ops(ops_and_refs, dtypes=OpDTypes.any_one)
def test_tags(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=False)
for sample in samples:
# TODO: Test tags for ops that return a list of tensors
input = sample.input
if isinstance(input, torch.Tensor):
old_size = input.size()
old_stride = input.stride()
with push_torch_dispatch_mode(TestTagsMode):
rs = op(input, *sample.args, **sample.kwargs)
# TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
aten_name = op.aten_name if op.aten_name is not None else op.name
opoverloadpacket = getattr(torch.ops.aten, aten_name, None)
test_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)


class TestRefsOpsInfo(TestCase):
@@ -1394,6 +1447,7 @@ def test_refs_are_in_python_ref_db(self, op):
instantiate_device_type_tests(TestCompositeCompliance, globals())
instantiate_device_type_tests(TestMathBits, globals())
instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
instantiate_device_type_tests(TestTags, globals())

if __name__ == "__main__":
run_tests()
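
The new TestTagsMode above checks the inplace_view tag specifically; the same TorchDispatchMode machinery can also be used simply to observe which tags each dispatched operator carries. A minimal sketch reusing the push_torch_dispatch_mode API exactly as the test does (the mode name and the chosen op are illustrative, not part of this commit):

import torch
from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode

class TagLoggingMode(TorchDispatchMode):
    # Print the tags of every operator dispatched while the mode is active.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # func may arrive as an OpOverloadPacket; resolve it to an overload
        # before reading .tags, the same way test_inplace_view does.
        overload = func.default if isinstance(func, torch._ops.OpOverloadPacket) else func
        print(overload, list(overload.tags))
        return func(*args, **kwargs)

with push_torch_dispatch_mode(TagLoggingMode):
    x = torch.randn(2, 3)
    x.t_()  # an in-place view op, expected to carry torch.Tag.inplace_view
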
4 changes: 4 additions & 0 deletions test/test_public_bindings.py
@@ -241,6 +241,10 @@ def test_no_new_bindings(self):
"vitals_enabled",

"wait",
"Tag",
"inplace_view",
"view_copy",
"generated"
}
torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}

13 changes: 12 additions & 1 deletion tools/autograd/gen_python_functions.py
@@ -57,7 +57,7 @@
namedtuple_fieldnames,
signature,
)
from torchgen.gen import cpp_string, parse_native_yaml
from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
from torchgen.context import with_native_function
from torchgen.model import (
Argument,
@@ -325,6 +325,17 @@ def gen(
fm, functions, lambda fn: True, "python_return_types.cpp"
)

valid_tags = parse_tags_yaml(tags_yaml_path)

def gen_tags_enum() -> Dict[str, str]:
return {
"enum_of_valid_tags": (
"".join([f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags])
)
}

fm.write("python_enum_tag.cpp", gen_tags_enum)


def group_filter_overloads(
pairs: Sequence[PythonSignatureNativeFunctionPair],
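gen_tags_enum above turns the parsed tag names into the ${enum_of_valid_tags} substitution for the python_enum_tag.cpp template. A standalone sketch of that step (assuming torchgen is importable and that parse_tags_yaml returns the tag names declared in tags.yaml):

from torchgen.gen import parse_tags_yaml

# Tag names declared in aten/src/ATen/native/tags.yaml,
# e.g. inplace_view, view_copy, generated.
valid_tags = parse_tags_yaml("aten/src/ATen/native/tags.yaml")

# One pybind11 .value(...) line per tag, exactly as gen_tags_enum builds it.
enum_of_valid_tags = "".join(
    f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags
)
print(enum_of_valid_tags)
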
15 changes: 15 additions & 0 deletions tools/autograd/templates/python_enum_tag.cpp
@@ -0,0 +1,15 @@
#include <torch/csrc/autograd/python_enum_tag.h>
#include <pybind11/pybind11.h>
#include <ATen/core/enum_tag.h>

namespace py = pybind11;
namespace torch {
namespace autograd {
void initEnumTag(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
py::enum_<at::Tag>(m, "Tag")
${enum_of_valid_tags}
.export_values();
m.doc() = "An Enum that contains tags that can be assigned to an operator registered in C++.";
}
}}
2 changes: 1 addition & 1 deletion torch/_C/__init__.pyi.in
@@ -202,7 +202,7 @@ def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
def _jit_get_operation(op_name: str) -> Tuple[Callable, List[str]]: ...
def _get_operation_overload(op_name: str, op_overload_name: str) -> Callable: ...
def _get_operation_overload(op_name: str, op_overload_name: str) -> Tuple[Callable, List[Any]]: ...
def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
optimization_blocklist: Set[MobileOptimizerType],
11 changes: 8 additions & 3 deletions torch/_ops.py
@@ -28,10 +28,11 @@ def dl_open_guard():
# Each OpOverload object contains a pointer to a specific operator overload and a pointer to the parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
class OpOverload:
def __init__(self, overloadpacket, op, schema):
def __init__(self, overloadpacket, op, schema, tags):
self._op = op
self._schema = schema
self._overloadpacket = overloadpacket
self._tags = tags
self._overloadname = 'default' if schema.overload_name == '' else schema.overload_name
self.__name__ = "{}.{}".format(self._schema.name.split("::")[1], self._overloadname)
self.__module__ = overloadpacket.__module__
@@ -65,6 +66,10 @@ def overloadpacket(self):
def op(self):
return self._op

@property
def tags(self):
return self._tags

# TODO: add more methods to expose information about input and output arguments

# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
@@ -123,10 +128,10 @@ def __getattr__(self, key):
# This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
use_key = '' if key == 'default' else key
# TODO: disallow access to overloads registered by JIT
op_ = torch._C._get_operation_overload(
op_, tags = torch._C._get_operation_overload(
self._qualified_op_name, use_key)
schema = torch._C._get_schema(self._qualified_op_name, use_key)
overload = OpOverload(self, op_, schema)
overload = OpOverload(self, op_, schema, tags)
# cache the overload object
setattr(self, key, overload)
return overload
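With _get_operation_overload now returning the registered tags alongside the callable, every resolved OpOverload exposes them through the new tags property. A small usage sketch (the chosen overload and the expected tag are illustrative; membership checks mirror OperatorHandle::hasTag on the C++ side):

import torch

# Resolve a concrete overload; its tags were captured at registration time.
overload = torch.ops.aten.t_.default
print(overload.tags)  # e.g. a list containing torch.Tag.inplace_view, if t_ carries the tag

# Checking for a tag is just membership in the overload's tag list.
print(torch.Tag.inplace_view in overload.tags)
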
(Diffs for the remaining 7 changed files are not shown.)
