Simplify device guard code generation (pytorch#55112)
Summary:
Pull Request resolved: pytorch#55112

Based on pytorch#47765
ghstack-source-id: 126114775

Test Plan: buck build //caffe2/aten/...

Reviewed By: ezyang

Differential Revision: D27487085

fbshipit-source-id: 157fcd19f538ce0c1e053e3e974b48bdb93a0226
wenleix authored and facebook-github-bot committed Apr 9, 2021
1 parent 43ede4c commit c0379ac
Showing 1 changed file with 26 additions and 37 deletions.
tools/codegen/dest/register_dispatch_key.py
@@ -148,52 +148,41 @@ def generate_defn(cpp_sig: CppSignature) -> str:

         args_exprs_str = ', '.join(a.name for a in args)

-        return_kw = " return "
-
-        cuda_guard = ""
-        if is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key):
-            self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else []
-
-            # There is precedence for which argument we use to do
-            # device guard.  This describes the precedence order.
-            candidate_args = itertools.chain(
-                self_arg,
-                f.func.arguments.out,
-                f.func.arguments.flat_positional
-            )
-
-            # Only tensor like arguments are eligible
-            device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None)
-
-            has_tensor_options = any(isinstance(a.argument, TensorOptionsArguments) for a in args)
-
-            cuda_guard_from_tensor_options = """\
-    const DeviceGuard device_guard(device_or_default(device));
-"""
-
-            # TODO: There is probably a simpler version of this that
-            # works just as well.
-            if f.device_guard and is_generic_dispatch_key(self.dispatch_key) and has_tensor_options:
-                cuda_guard = cuda_guard_from_tensor_options
-            elif f.device_guard and is_cuda_dispatch_key(self.dispatch_key) and has_tensor_options:
-                cuda_guard = f"""\
-    globalContext().lazyInitCUDA();
-    {cuda_guard_from_tensor_options}
-"""
-            elif f.device_guard and device_of is not None:
-                cuda_guard = f"""\
-    const OptionalDeviceGuard device_guard(device_of({device_of}));
-"""
-            else:
-                cuda_guard = """\
-    // DeviceGuard omitted
-"""
+        init_cuda = ""
+        device_guard = "// DeviceGuard omitted"  # default
+
+        if (is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key)) and f.device_guard:
+            has_tensor_options = any(isinstance(a.argument, TensorOptionsArguments) for a in args)
+            if has_tensor_options:
+                # kernel is creating a tensor
+                device_guard = "const DeviceGuard device_guard(device_or_default(device));"
+                if is_cuda_dispatch_key(self.dispatch_key):
+                    # initialize CUDA on construction of CUDA tensors
+                    init_cuda = "globalContext().lazyInitCUDA();\n"
+            else:
+                # kernel is operating on existing tensors
+
+                # There is precedence for which argument we use to do
+                # device guard.  This describes the precedence order.
+                self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else []
+                candidate_args = itertools.chain(
+                    self_arg,
+                    f.func.arguments.out,
+                    f.func.arguments.flat_positional
+                )
+
+                # Only tensor like arguments are eligible
+                device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None)
+                if device_of is not None:
+                    device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

         return f"""\
 namespace {{

 {returns_type} {name}({args_str}) {{
-{cuda_guard}{return_kw}{impl_name}({args_exprs_str});
+  {init_cuda}{device_guard}
+  return {impl_name}({args_exprs_str});
 }}

 }} // anonymous namespace
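For readers skimming the diff: the new code collapses the old four-way cuda_guard selection into one gated block with two cases. The sketch below is a standalone distillation of that logic, not the actual tools/codegen API — the function name select_guard and its parameters are hypothetical, and it assumes the outer f.device_guard / dispatch-key gate has already passed:

    # Hypothetical distillation of the new guard selection -- not the
    # actual tools/codegen API.
    def select_guard(has_tensor_options: bool, is_cuda_key: bool, device_of) -> str:
        init_cuda = ""
        device_guard = "// DeviceGuard omitted"  # default
        if has_tensor_options:
            # kernel is creating a tensor: guard on the requested device
            device_guard = "const DeviceGuard device_guard(device_or_default(device));"
            if is_cuda_key:
                # initialize CUDA on construction of CUDA tensors
                init_cuda = "globalContext().lazyInitCUDA();\n"
        elif device_of is not None:
            # kernel is operating on existing tensors: guard on the
            # device of the first tensor-like argument
            device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"
        return init_cuda + device_guard

    # A CUDA factory kernel (TensorOptions present, no tensor inputs):
    print(select_guard(True, True, None))
    # A CUDA kernel over existing tensors, guarding on `self`:
    print(select_guard(False, True, "self"))

Running it prints the lazyInitCUDA() line followed by the DeviceGuard line for the factory case, and the single OptionalDeviceGuard line for the existing-tensor case — the two shapes of guard code the generated wrappers now receive.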
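The precedence rule in the else branch also rewards a closer look: itertools.chain lazily concatenates the candidate lists in priority order (self, then out, then the remaining flat positionals), and next with a None default returns the first tensor-like hit, or None if there is none. A minimal sketch with a stand-in argument type — Arg is hypothetical, where the real code walks the codegen's argument model:

    import itertools
    from typing import NamedTuple, Optional

    class Arg(NamedTuple):
        # Stand-in for the codegen's argument objects.
        name: str
        is_tensor_like: bool

    self_arg = []                                  # e.g. a non-method op has no self
    out_args = [Arg("out", True)]
    positional = [Arg("other", True), Arg("alpha", False)]

    candidates = itertools.chain(self_arg, out_args, positional)
    device_of: Optional[str] = next(
        (a.name for a in candidates if a.is_tensor_like), None)
    print(device_of)  # "out" -- out takes precedence over later positionals

Because chain is lazy and next stops at the first match, the whole scan is a single short-circuiting pass over the arguments.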
