Skip to content

Commit

Permalink
Allow specifying tags for aten operators in native_functions.yaml
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#72549

Approved by: https://github.com/ezyang
  • Loading branch information
anjali411 authored and pytorchmergebot committed Mar 25, 2022
1 parent 79f91e6 commit 1dab71a
Show file tree
Hide file tree
Showing 30 changed files with 139 additions and 66 deletions.
1 change: 1 addition & 0 deletions .circleci/scripts/cpp_doc_push_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ cp torch/_utils_internal.py tools/shared
# Generate PyTorch files
time python tools/setup_helpers/generate_code.py \
--native-functions-path aten/src/ATen/native/native_functions.yaml \
--tags-path aten/src/ATen/native/tags.yaml \
--nn-path aten/src/

# Build the docs
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ jobs:
set -eux
time python3 -mtools.generate_torch_version --is_debug=false
time python3 -mtools.codegen.gen -s aten/src/ATen -d build/aten/src/ATen
time python3 -mtools.pyi.gen_pyi --native-functions-path aten/src/ATen/native/native_functions.yaml --deprecated-functions-path "tools/autograd/deprecated.yaml"
time python3 -mtools.pyi.gen_pyi --native-functions-path aten/src/ATen/native/native_functions.yaml --tags-path aten/src/ATen/native/tags.yaml --deprecated-functions-path "tools/autograd/deprecated.yaml"
- name: Run mypy
env:
MYPY_FORCE_COLOR: 1
Expand Down
3 changes: 3 additions & 0 deletions .jenkins/pytorch/codegen-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,22 @@ mkdir -p "$OUT"/pyi/torch/_C
mkdir -p "$OUT"/pyi/torch/nn
python -m tools.pyi.gen_pyi \
--native-functions-path aten/src/ATen/native/native_functions.yaml \
--tags-path aten/src/ATen/native/tags.yaml \
--deprecated-functions-path tools/autograd/deprecated.yaml \
--out "$OUT"/pyi

# autograd codegen (called by torch codegen but can run independently)
python -m tools.autograd.gen_autograd \
"$OUT"/torch/share/ATen/Declarations.yaml \
aten/src/ATen/native/native_functions.yaml \
aten/src/ATen/native/tags.yaml \
"$OUT"/autograd \
tools/autograd

# annotated_fn_args codegen (called by torch codegen but can run independently)
mkdir -p "$OUT"/annotated_fn_args
python -m tools.autograd.gen_annotated_fn_args \
aten/src/ATen/native/native_functions.yaml \
aten/src/ATen/native/tags.yaml \
"$OUT"/annotated_fn_args \
tools/autograd
5 changes: 3 additions & 2 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ py_binary(
],
)

aten_generation_srcs = ["aten/src/ATen/native/native_functions.yaml"] + glob(["aten/src/ATen/templates/**"])
aten_generation_srcs = ["aten/src/ATen/native/native_functions.yaml"] + ["aten/src/ATen/native/tags.yaml"] + glob(["aten/src/ATen/templates/**"])

generated_cpu_cpp = [
"aten/src/ATen/RegisterBackendSelect.cpp",
Expand Down Expand Up @@ -185,6 +185,7 @@ genrule(
name = "all_generated_code",
srcs = [
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",
"aten/src/ATen/native/ts_native_functions.yaml",
"torch/csrc/lazy/core/shape_inference.h",
"torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
Expand All @@ -194,7 +195,7 @@ genrule(
"aten/src/ATen/templates/LazyIr.h",
],
outs = libtorch_cpp_generated_sources + libtorch_python_generated_sources,
cmd = "$(location :generate_code) --install_dir `dirname $(location torch/csrc/autograd/generated/variable_factories.h)`/../.. --native-functions-path $(location aten/src/ATen/native/native_functions.yaml) --nn-path aten/src --gen_lazy_ts_backend",
cmd = "$(location :generate_code) --install_dir `dirname $(location torch/csrc/autograd/generated/variable_factories.h)`/../.. --native-functions-path $(location aten/src/ATen/native/native_functions.yaml) --tags-path $(location aten/src/ATen/native/tags.yaml) --nn-path aten/src --gen_lazy_ts_backend",
tools = [":generate_code"],
)

Expand Down
5 changes: 5 additions & 0 deletions aten/src/ATen/native/tags.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# This yaml file contains all the possible tags that can be defined in `tags` in `native_functions.yaml`

- tag: inplace_view
desc: |
This tag indicates if an operator *only* modifies the tensor metadata
2 changes: 2 additions & 0 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
COMMAND
"${PYTHON_EXECUTABLE}" tools/setup_helpers/generate_code.py
--native-functions-path "aten/src/ATen/native/native_functions.yaml"
--tags-path "aten/src/ATen/native/tags.yaml"
--nn-path "aten/src"
$<$<BOOL:${INTERN_DISABLE_AUTOGRAD}>:--disable-autograd>
$<$<BOOL:${SELECTED_OP_LIST}>:--selected-op-list-path="${SELECTED_OP_LIST}">
Expand All @@ -425,6 +426,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${GEN_PER_OPERATOR_FLAG}
DEPENDS
"${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml"
"${TORCH_ROOT}/aten/src/ATen/native/tags.yaml"
"${TORCH_ROOT}/aten/src/ATen/native/ts_native_functions.yaml"
"${TORCH_ROOT}/torch/csrc/lazy/core/shape_inference.h"
"${TORCH_ROOT}/torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
Expand Down
2 changes: 2 additions & 0 deletions cmake/Codegen.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ if(INTERN_BUILD_ATEN_OPS)
COMMAND ${GEN_UNBOXING_COMMAND_sources}
DEPENDS ${all_unboxing_script} ${sources_templates}
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/tags.yaml
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/..
)
else() # Otherwise do not generate or include sources into build.
Expand Down Expand Up @@ -210,6 +211,7 @@ if(INTERN_BUILD_ATEN_OPS)
COMMAND ${GEN_COMMAND_${gen_type}}
DEPENDS ${all_python} ${${gen_type}_templates}
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/tags.yaml
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/..
)
endforeach()
Expand Down
1 change: 1 addition & 0 deletions docs/cpp/source/check-doxygen.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ python -m tools.codegen.gen

python tools/setup_helpers/generate_code.py \
--native-functions-path aten/src/ATen/native/native_functions.yaml \
--tags-path aten/src/ATen/native/tags.yaml \
--nn-path aten/src

popd
Expand Down
9 changes: 6 additions & 3 deletions tools/autograd/gen_annotated_fn_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
python -m tools.autograd.gen_annotated_fn_args \
aten/src/ATen/native/native_functions.yaml \
aten/src/ATen/native/tags.yaml \
$OUTPUT_DIR \
tools/autograd
Expand All @@ -29,8 +30,8 @@
is_py_nn_function, is_py_linalg_function, is_py_variable_method, is_py_special_function, \
is_py_fft_function

def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:
native_functions = parse_native_yaml(native_yaml_path).native_functions
def gen_annotated(native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str) -> None:
native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
mappings = (
(is_py_torch_function, 'torch._C._VariableFunctions'),
(is_py_nn_function, 'torch._C._nn'),
Expand Down Expand Up @@ -77,12 +78,14 @@ def main() -> None:
description='Generate annotated_fn_args script')
parser.add_argument('native_functions', metavar='NATIVE',
help='path to native_functions.yaml')
parser.add_argument('tags', metavar='TAGS',
help='path to tags.yaml')
parser.add_argument('out', metavar='OUT',
help='path to output directory')
parser.add_argument('autograd', metavar='AUTOGRAD',
help='path to template directory')
args = parser.parse_args()
gen_annotated(args.native_functions, args.out, args.autograd)
gen_annotated(args.native_functions, args.tags, args.out, args.autograd)

if __name__ == '__main__':
main()
23 changes: 15 additions & 8 deletions tools/autograd/gen_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
python -m tools.autograd.gen_autograd \
build/aten/src/ATen/Declarations.yaml \
aten/src/ATen/native/native_functions.yaml \
aten/src/ATen/native/tags.yaml \
$OUTPUT_DIR \
tools/autograd
Expand Down Expand Up @@ -41,28 +42,30 @@

def gen_autograd(
native_functions_path: str,
tags_path: str,
out: str,
autograd_dir: str,
operator_selector: SelectiveBuilder,
disable_autograd: bool = False,
) -> None:
# Parse and load derivatives.yaml
differentiability_infos = load_derivatives(
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path,
tags_path)

template_path = os.path.join(autograd_dir, 'templates')

native_funcs = parse_native_yaml(native_functions_path).native_functions
native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions
fns = list(sorted(filter(
operator_selector.is_native_function_selected_for_training,
native_funcs), key=lambda f: cpp.name(f.func)))
fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo] = match_differentiability_info(fns, differentiability_infos)

# Generate VariableType.h/cpp
if not disable_autograd:
gen_variable_type(out, native_functions_path, fns_with_diff_infos, template_path)
gen_variable_type(out, native_functions_path, tags_path, fns_with_diff_infos, template_path)

gen_inplace_or_view_type(out, native_functions_path, fns_with_diff_infos, template_path)
gen_inplace_or_view_type(out, native_functions_path, tags_path, fns_with_diff_infos, template_path)

# operator filter not applied as tracing sources are excluded in selective build
gen_trace_type(out, native_funcs, template_path)
Expand All @@ -71,16 +74,18 @@ def gen_autograd(
out, differentiability_infos, template_path)

# Generate variable_factories.h
gen_variable_factories(out, native_functions_path, template_path)
gen_variable_factories(out, native_functions_path, tags_path, template_path)


def gen_autograd_python(
native_functions_path: str,
tags_path: str,
out: str,
autograd_dir: str,
) -> None:
differentiability_infos = load_derivatives(
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path,
tags_path)

template_path = os.path.join(autograd_dir, 'templates')

Expand All @@ -91,20 +96,22 @@ def gen_autograd_python(
# Generate Python bindings
deprecated_path = os.path.join(autograd_dir, 'deprecated.yaml')
gen_python_functions.gen(
out, native_functions_path, deprecated_path, template_path)
out, native_functions_path, tags_path, deprecated_path, template_path)


def main() -> None:
parser = argparse.ArgumentParser(
description='Generate autograd C++ files script')
parser.add_argument('native_functions', metavar='NATIVE',
help='path to native_functions.yaml')
parser.add_argument('tags', metavar='TAGS',
help='path to tags.yaml')
parser.add_argument('out', metavar='OUT',
help='path to output directory')
parser.add_argument('autograd', metavar='AUTOGRAD',
help='path to autograd directory')
args = parser.parse_args()
gen_autograd(args.native_functions,
gen_autograd(args.native_functions, args.tags,
args.out, args.autograd,
SelectiveBuilder.get_nop_selector())

Expand Down
1 change: 1 addition & 0 deletions tools/autograd/gen_inplace_or_view_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def gen_inplace_or_view_type_env(fn: NativeFunctionWithDifferentiabilityInfo) ->
def gen_inplace_or_view_type(
out: str,
native_yaml_path: str,
tags_yaml_path: str,
fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
template_path: str
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tools/autograd/gen_python_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ def is_py_special_function(f: NativeFunction) -> bool:
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
def gen(out: str, native_yaml_path: str, tags_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
native_functions = parse_native_yaml(native_yaml_path).native_functions
native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
native_functions = list(filter(should_generate_py_binding, native_functions))

methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
Expand Down
4 changes: 2 additions & 2 deletions tools/autograd/gen_variable_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def maybe_optional_type(type: str, is_opt: bool) -> str:
qualified_type = f'{argument_type[:index]}at::{argument_type[index:]}'
return maybe_optional_type(qualified_type, is_opt)

def gen_variable_factories(out: str, native_yaml_path: str, template_path: str) -> None:
native_functions = parse_native_yaml(native_yaml_path).native_functions
def gen_variable_factories(out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str) -> None:
native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
factory_functions = [fn for fn in native_functions if is_factory_function(fn)]
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
fm.write_with_template('variable_factories.h', 'variable_factories.h', lambda: {
Expand Down
1 change: 1 addition & 0 deletions tools/autograd/gen_variable_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@
def gen_variable_type(
out: str,
native_yaml_path: str,
tags_yaml_path: str,
fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo],
template_path: str,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tools/autograd/load_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

_GLOBAL_LOAD_DERIVATIVE_CACHE = {}

def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Sequence[DifferentiabilityInfo]:
def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str) -> Sequence[DifferentiabilityInfo]:
# Do some caching as this is a deterministic function
global _GLOBAL_LOAD_DERIVATIVE_CACHE
key = (derivatives_yaml_path, native_yaml_path)
Expand All @@ -30,7 +30,7 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque
with open(derivatives_yaml_path, 'r') as f:
definitions = yaml.load(f, Loader=YamlLoader)

functions = parse_native_yaml(native_yaml_path).native_functions
functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions

# What's the difference between function schema v.s. signature?
# function schema is the complete declaration including mutability annotation / default value and etc.
Expand Down
4 changes: 2 additions & 2 deletions tools/codegen/api/functionalization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from tools.codegen.model import (
FunctionSchema, BaseTy, BaseType, NativeFunction, Argument, Tag,
FunctionSchema, BaseTy, BaseType, NativeFunction, Argument,
)
from tools.codegen.api.types import (
Binding, NamedCType, ConstRefCType, BaseCType, CType, tensorT, longT
Expand Down Expand Up @@ -44,7 +44,7 @@
# The name returned here corresponds to the name of the inner function called by the lambda.
def name(f: NativeFunction, *, functional_op: NativeFunction, is_reverse: bool, include_namespace: bool) -> str:
# For inplace_view ops, the lambda calls out to the corresponding functional view op
fn = functional_op if f.tag is Tag.inplace_view else f
fn = functional_op if 'inplace_view' in f.tags else f
name = fn.func.name.unambiguous_name()
if is_reverse:
# in the reverse case, we codegen both the call-sites (which need the full namespace) and the declarations (which don't)
Expand Down
Loading

0 comments on commit 1dab71a

Please sign in to comment.