Skip to content

Commit

Permalink
[pytorch] simplify tensor options logic in pybinding codegen (pytorch…
Browse files Browse the repository at this point in the history
…#46976)

Summary:
Pull Request resolved: pytorch#46976

Technically, it's not semantic preserving, e.g.: emition of
'requires_grad' is no longer gated by 'has_tensor_return' - there is no
guarantee that is_like_or_new_function should all have tensor return.
But the output is identical so there might be some invariant - could
also add assertion to fail loudly when it's broken.

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D24589211

Pulled By: ljk53

fbshipit-source-id: 47c7e43b080e4e67a526fde1a8a53aae99df4432
  • Loading branch information
ljk53 authored and facebook-github-bot committed Oct 29, 2020
1 parent a86b343 commit 79474a1
Showing 1 changed file with 6 additions and 19 deletions.
25 changes: 6 additions & 19 deletions tools/codegen/api/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,42 +515,31 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature:
has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns)

name: str = cpp.name(f.func)
has_options_arg = has_tensor_options(f)

is_like_function = name.endswith('_like') or f.category_override == 'like'
is_new_function = name.startswith('new_') or f.category_override == 'new'
is_factory_function = has_tensor_return and not has_tensor_input_arg \
or f.category_override == 'factory'
is_like_or_new_function_with_options = \
(is_like_function or is_new_function) and has_options_arg
is_factory_function = f.category_override == 'factory' or (has_tensor_return and not has_tensor_input_arg)
is_like_or_new_function = f.category_override in ('new', 'like') or name.startswith('new_') or name.endswith('_like')

tensor_options_args: List[PythonArgument] = []
if is_factory_function or has_options_arg:
if is_factory_function or is_like_or_new_function:
tensor_options_args.append(PythonArgument(
name='dtype',
cpp_type_str='const ScalarType &',
type=BaseType(BaseTy.ScalarType),
default=_dtype_default_type_hack(name),
default_init='self.scalar_type()'
if is_like_or_new_function_with_options else None,
default_init='self.scalar_type()' if is_like_or_new_function else None,
))

if is_factory_function or is_like_or_new_function_with_options:
tensor_options_args.append(PythonArgument(
name='layout',
cpp_type_str='c10::optional<Layout>',
type=BaseType(BaseTy.Layout),
default='torch.strided',
default_init='layout_from_backend(self.options().backend())'
if is_like_or_new_function_with_options else None,
default_init='layout_from_backend(self.options().backend())' if is_like_or_new_function else None,
))
tensor_options_args.append(PythonArgument(
name='device',
cpp_type_str='const Device &',
type=BaseType(BaseTy.Device),
default='None',
default_init='self.device()'
if is_like_or_new_function_with_options else None,
default_init='self.device()' if is_like_or_new_function else None,
))
tensor_options_args.append(PythonArgument(
name='pin_memory',
Expand All @@ -559,8 +548,6 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature:
default='False',
default_init=None,
))

if has_tensor_return and (is_factory_function or is_like_function or is_new_function):
tensor_options_args.append(PythonArgument(
name='requires_grad',
cpp_type_str='bool',
Expand Down

0 comments on commit 79474a1

Please sign in to comment.