[pytorch][codegen] add autograd data model (pytorch#48249)
Summary:
Pull Request resolved: pytorch#48249

Introduced autograd-related data models in tools.codegen.api.autograd.

Migrated load_derivatives.py to produce the new data models from derivatives.yaml.
It now has a clean mypy-strict result.
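
For reference, a rough sketch of the shape of the new data model (not the actual definitions in tools/codegen/api/autograd.py; field names are inferred from how the generators in this PR use it, and several fields and types are simplified or omitted):

```python
# Illustrative sketch only -- not the real tools/codegen/api/autograd.py definitions.
# Field names are inferred from how load_derivatives.py / gen_autograd_functions.py
# use the model in this PR; types are simplified.
from dataclasses import dataclass
from typing import Sequence, Tuple

@dataclass(frozen=True)
class SavedAttribute:
    # A value (input or output) that the backward formula needs saved.
    name: str
    type: str  # e.g. 'Tensor', 'TensorList', 'IntArrayRef', 'int64_t'

@dataclass(frozen=True)
class Derivative:
    # One backward formula from derivatives.yaml and the variables it covers.
    formula: str
    var_names: Tuple[str, ...]

@dataclass(frozen=True)
class DifferentiabilityInfo:
    # Everything load_derivatives.py extracts for one derivatives.yaml entry.
    name: str                                  # ATen function name, e.g. 'add'
    op: str                                    # autograd::Node subclass name, e.g. 'AddBackward0'
    args_with_derivatives: Sequence[object]    # CppArgument entries in the real model
    derivatives: Sequence[Derivative]
    all_saved_inputs: Sequence[SavedAttribute]
    all_saved_outputs: Sequence[SavedAttribute]
```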

Changed both gen_autograd_functions.py and gen_variable_type.py to consume
the new data model.
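
The consumption pattern is roughly the following (a sketch reusing the illustrative classes above; `render_node` is a hypothetical stand-in for the real CodeTemplate substitution in process_function):

```python
# Hedged sketch of how a generator consumes the new model; `render_node` is a
# hypothetical stand-in for the real template substitution in process_function.
from typing import List, Sequence

def collect_node_declarations(infos: Sequence[DifferentiabilityInfo]) -> List[str]:
    def render_node(info: DifferentiabilityInfo) -> str:
        return f'struct TORCH_API {info.op} : public Node {{ /* ... */ }};'
    # Only emit an autograd function for entries that actually have a derivative to compute.
    return [render_node(info) for info in infos if info.args_with_derivatives]
```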

Added type annotations to gen_autograd_functions.py. It has a clean mypy-strict
result except for the .gen_autograd import, so it hasn't been added to the strict
config in this PR.
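
As an illustration only (an assumption about the local workflow, not necessarily how CI invokes it), the strict result can be checked programmatically through mypy's Python API:

```python
# Minimal sketch, assuming mypy is installed and this runs from the repo root;
# CI may invoke mypy differently.
from mypy import api

stdout, stderr, exit_status = api.run(['--config-file', 'mypy-strict.ini'])
print(stdout, end='')
assert exit_status == 0, 'mypy-strict reported errors'
```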

To limit the scope of the PR, gen_variable_type.py is not refactored, and the
main structure of load_derivatives.py / gen_autograd_functions.py is kept. Only
the changes necessary to make it work were made.

Confirmed the generated output is byte-for-byte compatible with the old codegen:

```
Run the codegen test script before and after this PR:
  .jenkins/pytorch/codegen-test.sh <baseline_output_dir>
  .jenkins/pytorch/codegen-test.sh <test_output_dir>

Then run diff to compare the generated files:
  diff -Naur <baseline_output_dir> <test_output_dir>
```

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D25086561

Pulled By: ljk53

fbshipit-source-id: 1f43ab0931d9814c24683b9a48ca497c5fc3d729
ljk53 authored and facebook-github-bot committed Nov 20, 2020
1 parent fa41275 commit de284b6
Showing 9 changed files with 583 additions and 456 deletions.
1 change: 1 addition & 0 deletions mypy-strict.ini
@@ -34,6 +34,7 @@ files = tools/codegen/gen.py,
tools/autograd/gen_python_functions.py,
tools/autograd/gen_trace_type.py,
tools/autograd/gen_variable_factories.py,
tools/autograd/load_derivatives.py,
torch/utils/benchmark/utils/common.py,
torch/utils/benchmark/utils/timer.py,
torch/utils/benchmark/utils/valgrind_wrapper/*.py,
20 changes: 7 additions & 13 deletions tools/autograd/gen_autograd.py
@@ -169,15 +169,15 @@ def is_operator_selected_for_training(decl):

# Parse and load derivatives.yaml
from .load_derivatives import load_derivatives
autograd_functions = load_derivatives(
os.path.join(autograd_dir, 'derivatives.yaml'), full_aten_decls)
differentiability_infos = load_derivatives(
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)

template_path = os.path.join(autograd_dir, 'templates')

# Generate VariableType.h/cpp
if not disable_autograd:
from .gen_variable_type import gen_variable_type
gen_variable_type(out, aten_decls, template_path)
gen_variable_type(out, aten_decls, differentiability_infos, template_path)

from . import gen_trace_type
# operator filter not applied as tracing sources are excluded in selective build
@@ -186,30 +186,24 @@ def is_operator_selected_for_training(decl):
# Generate Functions.h/cpp
from .gen_autograd_functions import gen_autograd_functions_lib
gen_autograd_functions_lib(
out, autograd_functions, template_path)
out, differentiability_infos, template_path)

# Generate variable_factories.h
from .gen_variable_factories import gen_variable_factories
# Some non-selectable ops (e.g. prim ops) need factory methods so we pass in `full_aten_decls` here.
gen_variable_factories(out, native_functions_path, template_path)


def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
# TODO Deduplicate these four variable assignments

aten_decls = load_aten_declarations(aten_path)

# Parse and load derivatives.yaml
from .load_derivatives import load_derivatives
autograd_functions = load_derivatives(
os.path.join(autograd_dir, 'derivatives.yaml'), aten_decls)
differentiability_infos = load_derivatives(
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)

template_path = os.path.join(autograd_dir, 'templates')

# Generate Functions.h/cpp
from .gen_autograd_functions import gen_autograd_functions_python
gen_autograd_functions_python(
out, autograd_functions, template_path)
out, differentiability_infos, template_path)

# Generate Python bindings
from . import gen_python_functions
253 changes: 134 additions & 119 deletions tools/autograd/gen_autograd_functions.py
@@ -4,11 +4,17 @@
# Functions.h/cpp: subclasses of autograd::Node
# python_functions.h/cpp: Python bindings for the above classes
#
import os
import re
from .utils import nested_dict, CodeTemplate, write
from .gen_autograd import VIEW_FUNCTIONS
from .utils import IDENT_REGEX

from typing import List, Sequence, Tuple, Optional

from tools.codegen.api.autograd import *
from tools.codegen.api.types import *
from tools.codegen.code_template import CodeTemplate
from tools.codegen.gen import FileManager
from tools.codegen.model import *
from tools.codegen.utils import *

FUNCTION_DECLARATION = CodeTemplate("""\
struct TORCH_API ${op} : public ${superclass} {
Expand Down Expand Up @@ -84,142 +90,143 @@
# TODO: This is probably not exhaustive, but it's a start
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS


def gen_autograd_functions_lib(out, autograd_functions, template_path):
gen_autograd_functions(out, autograd_functions, template_path, "Functions")


def gen_autograd_functions_python(out, autograd_functions, template_path):
gen_autograd_functions(out, autograd_functions, template_path, "python_functions")


def gen_autograd_functions(out, autograd_functions, template_path, file_basename):
def gen_autograd_functions_lib(
out: str,
differentiability_infos: Sequence[DifferentiabilityInfo],
template_path: str,
) -> None:
gen_autograd_functions(out, differentiability_infos, template_path, "Functions")

def gen_autograd_functions_python(
out: str,
differentiability_infos: Sequence[DifferentiabilityInfo],
template_path: str,
) -> None:
gen_autograd_functions(out, differentiability_infos, template_path, "python_functions")

def gen_autograd_functions(
out: str,
differentiability_infos: Sequence[DifferentiabilityInfo],
template_path: str,
file_basename: str,
) -> None:
"""Functions.h and Functions.cpp body
These contain the auto-generated subclasses of torch::autograd::Node
for every differentiable torch function.
"""

function_definitions = []
function_declarations = []
py_function_initializers = []

for func in autograd_functions:
env = process_function(func)

function_declarations.append(FUNCTION_DECLARATION.substitute(env))
function_definitions.append(FUNCTION_DEFINITION.substitute(env))
py_function_initializers.append(PY_FUNCTION_DEFINITION.substitute(env))

top_env = {
'autograd_function_definitions': function_definitions,
'autograd_function_declarations': function_declarations,
'py_function_initializers': py_function_initializers,
}

for suffix in [".h", ".cpp"]:
f = file_basename + suffix
templated_output = CodeTemplate.from_file(os.path.join(template_path, f))
write(out, f, templated_output, top_env)


def process_function(func):
env = {}
saved_variables = []
release_variables = []
saved_list_sizes = []
unpack = []
asserts = []

env['compute_index_ranges'] = []
for arg in func['args_with_derivatives']:
if arg['type'] == 'TensorList':
size = '{}_size_'.format(arg['name'])
saved_list_sizes.append('size_t {}_size_;'.format(arg['name']))
# only create an autograd function if we are actually going to calculate a derivative
infos = list(filter(lambda info: info.args_with_derivatives, differentiability_infos))
declarations = list(map(lambda f: process_function(f, FUNCTION_DECLARATION), infos))
definitions = list(map(lambda f: process_function(f, FUNCTION_DEFINITION), infos))
py_function_initializers = list(map(lambda f: process_function(f, PY_FUNCTION_DEFINITION), infos))

fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
for suffix in ['.h', '.cpp']:
fname = file_basename + suffix
fm.write_with_template(fname, fname, lambda: {
'generated_comment': '@' + f'generated from {fm.template_dir}/' + fname,
'autograd_function_declarations': declarations,
'autograd_function_definitions': definitions,
'py_function_initializers': py_function_initializers,
})

def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
saved_variables: List[str] = []
release_variables: List[str] = []
saved_list_sizes: List[str] = []
unpack: List[str] = []
asserts: List[str] = []
compute_index_ranges: List[str] = []

for arg in info.args_with_derivatives:
if arg.type == 'TensorList':
size = f'{arg.name}_size_'
saved_list_sizes.append(f'size_t {arg.name}_size_;')
else:
size = '1'
env['compute_index_ranges'].append('auto {}_ix = gen.range({});'.format(arg['name'], size))

def save_arg(arg, is_output):
name = arg['name']

if arg['type'] == 'Tensor' or arg['type'] == 'c10::optional<Tensor>' or arg['type'] == 'c10::optional<Tensor>&' or \
(arg['type'] == 'Scalar' and is_output):
saved_variables.append('SavedVariable {}_;'.format(name))
release_variables.append('{}_.reset_data();'.format(name))
release_variables.append('{}_.reset_grad_function();'.format(name))
compute_index_ranges.append(f'auto {arg.name}_ix = gen.range({size});')

def save_var(var: SavedAttribute, is_output: bool) -> None:
name = var.name
if var.type == 'Tensor' or var.type == 'c10::optional<Tensor>' or var.type == 'c10::optional<Tensor>&' or \
(var.type == 'Scalar' and is_output):
saved_variables.append(f'SavedVariable {name}_;')
release_variables.append(f'{name}_.reset_data();')
release_variables.append(f'{name}_.reset_grad_function();')
ptr = 'shared_from_this()' if is_output else ''
unpack.append('auto {} = {}_.unpack({});'.format(name, name, ptr))
elif arg['type'] == 'TensorList':
saved_variables.append('std::vector<SavedVariable> {}_;'.format(name))
saved_variables.append('bool {}_released_ = false;'.format(name))
unpack.append(f'auto {name} = {name}_.unpack({ptr});')
elif var.type == 'TensorList':
saved_variables.append(f'std::vector<SavedVariable> {name}_;')
saved_variables.append(f'bool {name}_released_ = false;')
# Just clear() is sufficient, we don't need to loop and clear each variable.
# Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
release_variables.append('{}_.clear();'.format(name))
release_variables.append('{}_released_ = true;'.format(name))
unpack.append('auto {} = unpack_list({}_);'.format(name, name))
asserts.append('TORCH_CHECK(!{}_released_, ERR_BACKWARD_TWICE);'.format(name))
elif arg['type'] == 'IntArrayRef':
saved_variables.append('std::vector<int64_t> {};'.format(name))
elif arg['type'] == 'c10::optional<IntArrayRef>':
saved_variables.append('c10::OptionalArray<int64_t> {};'.format(name))
elif arg['type'] == 'c10::optional<ArrayRef<double>>':
saved_variables.append('c10::OptionalArray<double> {};'.format(name))
elif arg['type'] == 'int64_t':
saved_variables.append('{} {} = 0;'.format(arg['type'], name))
release_variables.append(f'{name}_.clear();')
release_variables.append(f'{name}_released_ = true;')
unpack.append(f'auto {name} = unpack_list({name}_);')
asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
elif var.type == 'IntArrayRef':
saved_variables.append(f'std::vector<int64_t> {name};')
elif var.type == 'c10::optional<IntArrayRef>':
saved_variables.append(f'c10::OptionalArray<int64_t> {name};')
elif var.type == 'c10::optional<ArrayRef<double>>':
saved_variables.append(f'c10::OptionalArray<double> {name};')
elif var.type == 'int64_t':
saved_variables.append(f'{var.type} {name} = 0;')
else:
saved_variables.append('{} {};'.format(arg['type'], name))
saved_variables.append(f'{var.type} {name};')

for arg in func['saved_inputs']:
save_arg(arg, is_output=False)
for arg in func['saved_outputs']:
save_arg(arg, is_output=True)
env['saved_variables'] = saved_variables
env['release_variables'] = release_variables
env['saved_list_sizes'] = saved_list_sizes
env['asserts'] = asserts
for var in info.all_saved_inputs:
save_var(var, is_output=False)
for var in info.all_saved_outputs:
save_var(var, is_output=True)

# lock the mutex when we release variables and in Node::apply to protect thread safety
# see Note [Thread Safety on Autograd Node]
if len(release_variables) > 0:
env['thread_lock'] = "std::lock_guard<std::mutex> lock(mutex_);"
thread_lock = 'std::lock_guard<std::mutex> lock(mutex_);'
else:
env['thread_lock'] = ''
thread_lock = ''

if uses_retain_variables(func):
env['will_release_variables'] = WILL_RELEASE_VARIABLES.substitute()
if uses_retain_variables(info):
will_release_variables = WILL_RELEASE_VARIABLES.substitute()
else:
env['will_release_variables'] = ''
will_release_variables = ''

body = []
body: List[str] = []

if uses_single_grad(func):
if uses_single_grad(info):
body.append('auto& grad = grads[0];')

def emit_derivative(derivative, args_with_derivatives):
formula = derivative['formula']
var_names = derivative['var_names']
def emit_derivative(
derivative: Derivative,
args_with_derivatives: Sequence[CppArgument],
) -> Tuple[bool, str]:
formula = derivative.formula
var_names = derivative.var_names
if len(var_names) == 1:
checks_any_grad_defined = False
if 'not_implemented' not in formula:
matching_args = [
arg for arg in args_with_derivatives
if ('name' in arg) and (arg['name'] == var_names[0])]
if arg.name == var_names[0]]
if len(matching_args) == 1:
# We can add undefined grad support if the input variable is a Tensor
if ('simple_type' in matching_args[0].keys()) and (matching_args[0]['simple_type'] == 'Tensor'):
arg = matching_args[0]
if isinstance(arg.argument, Argument) and str(arg.argument.type) == 'Tensor':
formula = 'any_grad_defined ? (' + formula + ') : Tensor()'
checks_any_grad_defined = True
return (checks_any_grad_defined,
DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula))
else:
if 'grad_input_mask' in formula:
masks = ['should_compute_output({{ {}_ix }}),'.format(n) for n in var_names]
masks = [f'should_compute_output({{ {n}_ix }}),' for n in var_names]
grad_input_mask = GRAD_INPUT_MASK.substitute(masks=masks, n=len(var_names))
else:
grad_input_mask = ''
idx_ranges = ', '.join("{}_ix".format(n) for n in var_names)
copy_ranges = []
idx_ranges = ', '.join(f'{n}_ix' for n in var_names)
copy_ranges: List[str] = []
for i, n in enumerate(var_names):
copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
return False, DERIVATIVE_MULTI.substitute(
@@ -229,37 +229,45 @@ def emit_derivative(derivative, args_with_derivatives):

body.extend(unpack)
need_any_grad_defined_var = False
for derivative in func['derivatives']:
checks_any_grad_defined, derivative_text = emit_derivative(derivative, func['args_with_derivatives'])
for derivative in info.derivatives:
checks_any_grad_defined, derivative_text = emit_derivative(derivative, info.args_with_derivatives)
body.append(derivative_text)
need_any_grad_defined_var |= checks_any_grad_defined
# Since single-output derivative formulas need to check if grads are
# defined, only perform the check once, before all the formulas
if need_any_grad_defined_var:
body.insert(-len(func['derivatives']),
body.insert(-len(info.derivatives),
'bool any_grad_defined = any_variable_defined(grads);')

env['body'] = body
if func['name'] in UNTRACEABLE_FUNCTIONS:
env['superclass'] = 'Node'
if info.name in UNTRACEABLE_FUNCTIONS:
superclass = 'Node'
else:
env['superclass'] = 'TraceableFunction'
return nested_dict(env, func)


def uses_ident(func, ident):
if func is None:
superclass = 'TraceableFunction'

return template.substitute(
op=info.op,
compute_index_ranges=compute_index_ranges,
saved_variables=saved_variables,
release_variables=release_variables,
saved_list_sizes=saved_list_sizes,
asserts=asserts,
thread_lock=thread_lock,
will_release_variables=will_release_variables,
body=body,
superclass=superclass,
)

def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
if info is None:
return False
for derivative in func['derivatives']:
formula = derivative['formula']
for derivative in info.derivatives:
formula = derivative.formula
if re.search(IDENT_REGEX.format(ident), formula):
return True
return False

def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
return uses_ident(info, 'retain_variables')

def uses_retain_variables(func):
return uses_ident(func, 'retain_variables')


def uses_single_grad(func):
return uses_ident(func, 'grad')
def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
return uses_ident(info, 'grad')
(6 more changed files not shown)