Skip to content

Commit

Permalink
[pytorch][codegen] remove dead code in gen_variable_type.py (pytorch#…
Browse files Browse the repository at this point in the history
…47975)

Summary: Pull Request resolved: pytorch#47975

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D24976274

Pulled By: ljk53

fbshipit-source-id: 8542471ee30f26592aad949fc17eef87a47df024
  • Loading branch information
ljk53 authored and facebook-github-bot committed Nov 18, 2020
1 parent 07657b6 commit 5243456
Showing 1 changed file with 0 additions and 33 deletions.
33 changes: 0 additions & 33 deletions tools/autograd/gen_variable_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,36 +263,6 @@
#endif
""")

FACTORY_FUNCTION_NAMES = None

# TODO The maybe_unwrap_optional_tensors is only needed because our at::native::xxx functions
# still take "Tensor" instead of "optional<Tensor>", so we need CPUType, TypeDefault, ...
# to do the same. Once at::native::xxx are converted, we can remove use_optional_tensor
# and use the use_optional_tensor=True behavior always.
def maybe_unwrap_optional_tensors(option, formals, args):
assert len(formals) == len(args), \
"Assert we didn't screw up with method_args removing self but forgetting to remove it from formals"
if option['use_c10_dispatcher'] in ['full', 'hacky_wrapper_for_legacy_signatures']:
def maybe_unwrap_optional_tensor(formal, arg):
if formal['dynamic_type'] == 'Tensor' and formal['is_nullable']:
return "{}.has_value() ? *{} : at::Tensor()".format(arg, arg)
else:
return arg
return [maybe_unwrap_optional_tensor(formal, arg) for (formal, arg) in zip(formals, args)]
else:
assert option['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
return args


def find_factory_functions(declarations):
global FACTORY_FUNCTION_NAMES
FACTORY_FUNCTION_NAMES = set()

for declaration in declarations:
if declaration['is_factory_method']:
FACTORY_FUNCTION_NAMES.add(declaration['api_name'])


# Methods shared by TraceType and VariableType to handle return variable declaration, tie and tuple.
def format_return_variables(declaration):
name = declaration['name']
Expand Down Expand Up @@ -344,9 +314,6 @@ def gen_variable_type(out, aten_declarations, template_path):
compute the output. The grad_fn is attached to differentiable functions.
"""

# WARNING: this function call modifies global mutable state
find_factory_functions(aten_declarations)

aten_declarations = list(sorted(aten_declarations, key=lambda decl: decl['name']))

gen_variable_type_shard(out, aten_declarations, template_path, None, True)
Expand Down

0 comments on commit 5243456

Please sign in to comment.