Commit

Fix 'Code below assumes there is at least one tensor arg' assumption (pytorch#76917)

Previously, codegen for ops like `zeros_` or `ones_` would fail with a `Code below assumes there is at least one tensor arg` error. That check is too strict, which is why the error was thrown: ops like the ones mentioned take no tensor arguments, but they do pass in a `device` parameter that can be used in place of the "first tensor" to determine the common device.
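To illustrate the effect on the generated code, here is a minimal sketch of the device-resolution snippet the codegen can now emit for a tensor-less op such as `zeros_` (assuming the default `get_device_fn` of `torch::lazy::GetBackendDevice`; the surrounding kernel boilerplate is omitted):

    // Previously, codegen asserted there was at least one tensor arg and failed.
    // Now the optional device argument participates in device resolution:
    auto common_device = torch::lazy::GetBackendDevice(device);
    TORCH_INTERNAL_ASSERT(common_device);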

CC: @wconstab @desertfire @henrytwo @ke1337
Pull Request resolved: pytorch#76917
Approved by: https://github.com/desertfire
antoniojkim authored and pytorchmergebot committed May 18, 2022
1 parent a7cf95a commit 55be35a
Showing 4 changed files with 33 additions and 11 deletions.
8 changes: 8 additions & 0 deletions torch/csrc/lazy/backend/backend_device.cpp
@@ -5,6 +5,7 @@
 #include <c10/util/StringUtil.h>
 #include <torch/csrc/lazy/core/tensor.h>
 #include <torch/csrc/lazy/backend/backend_interface.h>
+#include <c10/util/Optional.h>

 namespace torch {
 namespace lazy {
@@ -67,6 +68,13 @@ c10::optional<BackendDevice> GetBackendDevice(const at::Tensor& tensor) {
   return c10::nullopt;
 }

+c10::optional<BackendDevice> GetBackendDevice(const c10::optional<c10::Device> device) {
+  if (device) {
+    return c10::make_optional(atenDeviceToBackendDevice(*device));
+  }
+  return c10::nullopt;
+}
+
 c10::optional<BackendDevice> GetBackendDevice() {
   return c10::nullopt;
 }
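A quick sketch of the new overload at a hypothetical call site (the device values here are made up for illustration):

    // A populated optional device resolves to a BackendDevice...
    c10::optional<c10::Device> device = c10::Device(c10::kLazy, 0);
    auto backend_device = torch::lazy::GetBackendDevice(device);  // has_value() == true
    // ...while an empty optional propagates c10::nullopt.
    auto none = torch::lazy::GetBackendDevice(c10::optional<c10::Device>());  // !none.has_value()
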
1 change: 1 addition & 0 deletions torch/csrc/lazy/backend/backend_device.h
@@ -62,6 +62,7 @@ TORCH_API c10::Device backendDeviceToAtenDevice(const BackendDevice& device);
 // input is not a lazy tensor.
 TORCH_API c10::optional<BackendDevice> GetBackendDevice(const at::TensorList tensors);
 TORCH_API c10::optional<BackendDevice> GetBackendDevice(const at::Tensor& tensor);
+TORCH_API c10::optional<BackendDevice> GetBackendDevice(const c10::optional<c10::Device> device);

 // For variadic template.
 TORCH_API c10::optional<BackendDevice> GetBackendDevice();
5 changes: 3 additions & 2 deletions torchgen/api/lazy.py
@@ -170,10 +170,11 @@ class LazyArgument:
     def __init__(self, arg: Argument):
         self.name = arg.name
         self.orig_type = arg.type
+        self.is_optional = isinstance(arg.type, OptionalType)
         self.is_generator = isGeneratorType(arg.type)
         if self.is_generator:
-            assert isinstance(
-                arg.type, OptionalType
+            assert (
+                self.is_optional
             ), "We expect all generators are optional since currently they are"
             # there is no handling for generators in TorchScript IR (or XLA)
             # so we fall back to eager if the (optional)generator has value, and otherwise
30 changes: 21 additions & 9 deletions torchgen/dest/lazy_ir.py
@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import List, Union
+from typing import List, Optional, Union
 from dataclasses import dataclass
 from torchgen.context import method_with_native_function
 from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
@@ -8,6 +8,7 @@
     OptionalCType,
     VectorCType,
     kernel_signature,
+    deviceT,
 )
 import torchgen.api.dispatcher as dispatcher
 from torchgen.api.lazy import (
@@ -350,11 +351,19 @@ def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:

     def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
         value_args = schema.filtered_args(values=True, scalars=False)
+        scalar_args = schema.filtered_args(values=False, scalars=True)
         value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
+        optional_device = OptionalCType(BaseCType(deviceT))
+        optional_devices = [
+            a.name for a in scalar_args if a.lazy_type == optional_device
+        ]
         assert (
-            len(value_types_names) > 0
-        ), "Code below assumes there is at least one tensor arg"
-        return f"""auto common_device = {self.get_device_fn}({', '.join(value_types_names)});
+            len(value_types_names) > 0 or len(optional_devices) > 0
+        ), "Expected at least one Value or Device type"
+        get_device_str = (
+            f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
+        )
+        return f"""auto common_device = {get_device_str};
         TORCH_INTERNAL_ASSERT(common_device);
         """

@@ -406,25 +415,28 @@ def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
         }}
         """

-    def create_lazy_tensor(self, first_tensor_name: str) -> str:
+    def create_lazy_tensor(self, first_tensor_name: Optional[str] = None) -> str:
         # xla uses an instance method for tensor creation, for the time being
         if self.create_from_first_tensor:
             # TODO(whc) remove this if XLA switches to using static method for creation
+            assert (
+                first_tensor_name is not None
+            ), "Requires first tensor to create lazy tensor"
             return f"{first_tensor_name}.{self.create_tensor}"
         return f"{self.backend_namespace}::{self.create_tensor}"

     def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
         returns_length = len(schema.returns)
         value_args = schema.filtered_args(values=True, scalars=False)
         value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
-        assert (
-            len(value_types_names) > 0
-        ), "Code below assumes there is at least one tensor arg"
-        first_tensor_name = value_types_names[0]
+        first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
         bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
             {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""

         if returns_length > 1:
+            assert (
+                len(value_types_names) > 0
+            ), "Code below assumes there is at least one tensor arg"
             bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
             for (int i = 0; i < {returns_length}; i++) {{
                 lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
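With `first_tensor_name` now optional, a single-return op with no tensor args falls through to the static creation path in `create_lazy_tensor`. As a hedged sketch, the generated bridge for such an op would look roughly like the following (the `LazyTensor::Create` and `CreateAtenFromLtcTensor` spellings follow the TorchScript lazy backend and are assumptions here):

    // Sketch of the generated bridge when there is no first tensor:
    auto result = torch::lazy::CreateAtenFromLtcTensor(
        torch::lazy::LazyTensor::Create(std::move(node), *common_device));
    return result;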
