Skip to content

Commit

Permalink
[pytorch] split out trace type generator and migrate to new codegen m…
Browse files Browse the repository at this point in the history
…odel (pytorch#47438)

Summary: Pull Request resolved: pytorch#47438

Test Plan: Imported from OSS

Reviewed By: bhosmer

Differential Revision: D24808211

Pulled By: ljk53

fbshipit-source-id: 44dfadf550a255c05aa201e54b48101aaf722885
  • Loading branch information
ljk53 authored and facebook-github-bot committed Nov 9, 2020
1 parent 499d2fa commit 4159191
Show file tree
Hide file tree
Showing 7 changed files with 430 additions and 400 deletions.
1 change: 1 addition & 0 deletions .jenkins/pytorch/codegen-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ python -m tools.pyi.gen_pyi \
# autograd codegen (called by torch codegen but can run independently)
python -m tools.autograd.gen_autograd \
"$OUT"/torch/share/ATen/Declarations.yaml \
aten/src/ATen/native/native_functions.yaml \
"$OUT"/autograd \
tools/autograd

Expand Down
11 changes: 9 additions & 2 deletions tools/autograd/gen_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def get_signature(name, params, call_args):
return declarations


def gen_autograd(aten_path, out, autograd_dir, operator_selector: SelectiveBuilder, disable_autograd=False):
def gen_autograd(aten_path, native_functions_path, out, autograd_dir, operator_selector: SelectiveBuilder, disable_autograd=False):
full_aten_decls = load_aten_declarations(aten_path)

def filter_decls(aten_decls, operator_selector):
Expand All @@ -243,6 +243,10 @@ def is_operator_selected_for_training(decl):
from .gen_variable_type import gen_variable_type
gen_variable_type(out, aten_decls, template_path)

from . import gen_trace_type
# operator filter not applied as tracing sources are excluded in selective build
gen_trace_type.gen_trace_type(out, native_functions_path, template_path)

# Generate Functions.h/cpp
from .gen_autograd_functions import gen_autograd_functions_lib
gen_autograd_functions_lib(
Expand Down Expand Up @@ -297,12 +301,15 @@ def main():
description='Generate autograd C++ files script')
parser.add_argument('declarations', metavar='DECL',
help='path to Declarations.yaml')
parser.add_argument('native_functions', metavar='NATIVE',
help='path to native_functions.yaml')
parser.add_argument('out', metavar='OUT',
help='path to output directory')
parser.add_argument('autograd', metavar='AUTOGRAD',
help='path to autograd directory')
args = parser.parse_args()
gen_autograd(args.declarations, args.out, args.autograd,
gen_autograd(args.declarations, args.native_functions,
args.out, args.autograd,
SelectiveBuilder.get_nop_selector())


Expand Down
21 changes: 2 additions & 19 deletions tools/autograd/gen_python_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@
from collections import defaultdict
import itertools
import re
from .gen_variable_type import DONT_RECORD_TRACE
from .gen_trace_type import should_trace
from .utils import write, is_tensor_method

from tools.codegen.code_template import CodeTemplate
from tools.codegen.api.python import *
from tools.codegen.gen import cpp_string, with_native_function
from tools.codegen.model import *

from typing import Dict, Optional, List, Any, Tuple, Set
from typing import Dict, Optional, List, Any, Tuple, Set, Sequence

#
# declarations blocklist
Expand Down Expand Up @@ -575,23 +575,6 @@ def emit_dispatch_case(
return emit_single_dispatch(
overload.signature, overload.base, namedtuple_typenames)

# Copied from 'gen_variable_type.should_trace()'.
# TODO: consolidate after migrating autograd codegen.
@with_native_function
def should_trace(f: NativeFunction) -> bool:
# Operations involving Storage or Type are not traceable at the moment
if any(str(arg.type) in {'Storage', 'Type', 'ConstQuantizerPtr'}
for arg in f.func.schema_order_arguments()):
return False
# We can't trace functions which don't have any Tensor or TensorList returns
if not any(r.type.is_tensor_like() for r in f.func.returns):
return False
name = cpp.name(f.func)
base_name = f.func.name.name.base
if base_name in DONT_RECORD_TRACE or name in DONT_RECORD_TRACE:
return False
return True

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Forward Declarations Codegen
Expand Down
Loading

0 comments on commit 4159191

Please sign in to comment.