From 55be35ae395edc5e880b0d8902736ee580b313bd Mon Sep 17 00:00:00 2001 From: Antonio Kim Date: Wed, 18 May 2022 17:58:47 +0000 Subject: [PATCH] Fix 'Code below assumes there is at least one tensor arg' assumption (#76917) Previously, when codegening ops like `zeros_` or `ones_`, we'd hit a `Code below assumes there is at least one tensor arg` error. This check is not entirely correct, which is what causes the error to be thrown. There are ops like the ones mentioned that pass in a `device` parameter that can be used in place of the "first tensor". CC: @wconstab @desertfire @henrytwo @ke1337 Pull Request resolved: https://github.com/pytorch/pytorch/pull/76917 Approved by: https://github.com/desertfire --- torch/csrc/lazy/backend/backend_device.cpp | 8 ++++++ torch/csrc/lazy/backend/backend_device.h | 1 + torchgen/api/lazy.py | 5 ++-- torchgen/dest/lazy_ir.py | 30 +++++++++++++++------- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/torch/csrc/lazy/backend/backend_device.cpp b/torch/csrc/lazy/backend/backend_device.cpp index 043ee486b1c4fa..8445fd74c9740d 100644 --- a/torch/csrc/lazy/backend/backend_device.cpp +++ b/torch/csrc/lazy/backend/backend_device.cpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace torch { namespace lazy { @@ -67,6 +68,13 @@ c10::optional GetBackendDevice(const at::Tensor& tensor) { return c10::nullopt; } +c10::optional GetBackendDevice(const c10::optional device) { + if (device) { + return c10::make_optional(atenDeviceToBackendDevice(*device)); + } + return c10::nullopt; +} + c10::optional GetBackendDevice() { return c10::nullopt; } diff --git a/torch/csrc/lazy/backend/backend_device.h b/torch/csrc/lazy/backend/backend_device.h index fbc1c48a6e7e75..818d0f2c10db27 100644 --- a/torch/csrc/lazy/backend/backend_device.h +++ b/torch/csrc/lazy/backend/backend_device.h @@ -62,6 +62,7 @@ TORCH_API c10::Device backendDeviceToAtenDevice(const BackendDevice& device); // input is not a lazy tensor. 
TORCH_API c10::optional GetBackendDevice(const at::TensorList tensors); TORCH_API c10::optional GetBackendDevice(const at::Tensor& tensor); +TORCH_API c10::optional GetBackendDevice(const c10::optional device); // For variadic template. TORCH_API c10::optional GetBackendDevice(); diff --git a/torchgen/api/lazy.py b/torchgen/api/lazy.py index ce8a3cdb6b14a5..ff74f4ab34bd86 100644 --- a/torchgen/api/lazy.py +++ b/torchgen/api/lazy.py @@ -170,10 +170,11 @@ class LazyArgument: def __init__(self, arg: Argument): self.name = arg.name self.orig_type = arg.type + self.is_optional = isinstance(arg.type, OptionalType) self.is_generator = isGeneratorType(arg.type) if self.is_generator: - assert isinstance( - arg.type, OptionalType + assert ( + self.is_optional ), "We expect all generators are optional since currently they are" # there is no handling for generators in TorchScript IR (or XLA) # so we fall back to eager if the (optional)generator has value, and otherwise diff --git a/torchgen/dest/lazy_ir.py b/torchgen/dest/lazy_ir.py index ab6ac4b80ea0f2..66c9e3d749eb54 100644 --- a/torchgen/dest/lazy_ir.py +++ b/torchgen/dest/lazy_ir.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import List, Union +from typing import List, Optional, Union from dataclasses import dataclass from torchgen.context import method_with_native_function from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup @@ -8,6 +8,7 @@ OptionalCType, VectorCType, kernel_signature, + deviceT, ) import torchgen.api.dispatcher as dispatcher from torchgen.api.lazy import ( @@ -350,11 +351,19 @@ def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str: def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str: value_args = schema.filtered_args(values=True, scalars=False) + scalar_args = schema.filtered_args(values=False, scalars=True) value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar] + optional_device = 
OptionalCType(BaseCType(deviceT)) + optional_devices = [ + a.name for a in scalar_args if a.lazy_type == optional_device + ] assert ( - len(value_types_names) > 0 - ), "Code below assumes there is at least one tensor arg" - return f"""auto common_device = {self.get_device_fn}({', '.join(value_types_names)}); + len(value_types_names) > 0 or len(optional_devices) > 0 + ), "Expected at least one Value or Device type" + get_device_str = ( + f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})" + ) + return f"""auto common_device = {get_device_str}; TORCH_INTERNAL_ASSERT(common_device); """ @@ -406,10 +415,13 @@ def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str: }} """ - def create_lazy_tensor(self, first_tensor_name: str) -> str: + def create_lazy_tensor(self, first_tensor_name: Optional[str] = None) -> str: # xla uses an instance method for tensor creation, for the time being if self.create_from_first_tensor: # TODO(whc) remove this if XLA switches to using static method for creation + assert ( + first_tensor_name is not None + ), "Requires first tensor to create lazy tensor" return f"{first_tensor_name}.{self.create_tensor}" return f"{self.backend_namespace}::{self.create_tensor}" @@ -417,14 +429,14 @@ def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str: returns_length = len(schema.returns) value_args = schema.filtered_args(values=True, scalars=False) value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar] - assert ( - len(value_types_names) > 0 - ), "Code below assumes there is at least one tensor arg" - first_tensor_name = value_types_names[0] + first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}( {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));""" if returns_length > 1: + assert ( + len(value_types_names) > 0 + ), "Code below assumes there is 
at least one tensor arg" bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors; for (int i = 0; i < {returns_length}; i++) {{ lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));