Skip to content

Commit

Permalink
Dispatch factory functions on Type (#15093)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch/pytorch#15093

Needed for backend extensions.

Reviewed By: ezyang

Differential Revision: D13427897

fbshipit-source-id: d0b34b0072e597ae599bd3bc25356831d7a18d6a
  • Loading branch information
Roy Li authored and facebook-github-bot committed Feb 1, 2019
1 parent d29912f commit 64186e0
Showing 1 changed file with 14 additions and 25 deletions.
39 changes: 14 additions & 25 deletions aten/src/ATen/function_wrapper.py
Original file line number Diff line number Diff line change
@@ -547,12 +547,12 @@ def __getitem__(self, x):
])


def device_guard(option, formals, is_factory_method, dispatch_options, dispatch_tensor):
def device_guard(option, formals, dispatch_options, dispatch_tensor):
# For factory methods the `DeviceGuard` is already in the template.
if option.get('device_guard', True):
if dispatch_options:
return 'const DeviceGuard device_guard({}.device());'.format(dispatch_options['name'])
if not is_factory_method and dispatch_tensor:
if dispatch_tensor:
return 'const OptionalDeviceGuard device_guard(device_of({}));'.format(dispatch_tensor)
return '// DeviceGuard omitted'

@@ -836,7 +836,7 @@ def process_option(option, output_options):
option['method_prefix_derived'] = '' if broadcast_arg is None else 's_'
if option['mode'] == 'TH':
option['device_guard'] = False
option['device_guard_declaration'] = device_guard(option, formals, False, False, dispatch_tensor)
option['device_guard_declaration'] = device_guard(option, formals, False, dispatch_tensor)

env = nested_dict(option, top_env)

@@ -1057,13 +1057,8 @@ def find_formal(formal_name, formals):
option['name'], ", ".join(option['method_formals_with_defaults']))

type_method_dispatch = option['type_method_definition_dispatch']
backend_dispatch = isinstance(type_method_dispatch, dict)

# We only dispatch via options if there is backend-specific dispatch
# (otherwise it's a factory function that can dispatch directly to the
# native function).
dispatch_options = (find_formal('TensorOptions', formals)
if backend_dispatch else None)
dispatch_options = find_formal('TensorOptions', formals)
# Only dispatch via tensor if there is no Options argument
dispatch_tensor = None if dispatch_options else find_dispatch_tensor(formals)

@@ -1081,8 +1076,7 @@ def find_formal(formal_name, formals):
check_methods_do_not_start_with_underscore(option['name'], is_method)

option['method_prefix_derived'] = ''
option['device_guard_declaration'] = device_guard(option, formals, is_factory_method,
dispatch_options, dispatch_tensor)
option['device_guard_declaration'] = device_guard(option, formals, dispatch_options, dispatch_tensor)

env = nested_dict(option, top_env)

@@ -1091,15 +1085,13 @@ def find_formal(formal_name, formals):
raise Exception("broadcasting is not yet supported for native functions, "
"but specified for function {}", option['name'])

# Factory methods are not dispatched over `Type`.
if not is_factory_method:
if option['extended_method']:
top_env['pure_virtual_extended_type_method_declarations'].append(
PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
else:
top_env['pure_virtual_type_method_declarations'].append(
PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
top_env['type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
if option['extended_method']:
top_env['pure_virtual_extended_type_method_declarations'].append(
PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
else:
top_env['pure_virtual_type_method_declarations'].append(
PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
top_env['type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
option['native_type_method_dispatch'] = type_method_dispatch

# Note [Abstract ATen methods]
@@ -1115,7 +1107,7 @@ def find_formal(formal_name, formals):
abstract = True
top_env['type_method_definitions'].append(
TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
elif not is_factory_method:
else:
body = TYPE_DEFINITION_BODY_NATIVE.substitute(env)
top_env['type_method_definitions'].append(
TYPE_METHOD_DEFINITION_CONCRETE.substitute(
@@ -1153,10 +1145,7 @@ def find_formal(formal_name, formals):
option['inferred_type'] = 'at::getNonVariableType(at::Backend::Undefined, at::ScalarType::Float)'
declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION
top_env['function_declarations'].append(declaration.substitute(env))
if is_factory_method:
top_env['function_definitions'].append(FACTORY_DEFINITION.substitute(env))
else:
top_env['function_definitions'].append(FUNCTION_DEFINITION.substitute(env))
top_env['function_definitions'].append(FUNCTION_DEFINITION.substitute(env))
method_of.append('namespace')

output_options.append(OutputDeclaration(

0 comments on commit 64186e0

Please sign in to comment.