Remove more uses of DimensionedTensorType
Summary: Pull Request resolved: pytorch#23060

Differential Revision: D16460391

Pulled By: Krovatkin

fbshipit-source-id: b50ee87d22ad18b8cbfff719b199ea876ef172f1
Krovatkin authored and facebook-github-bot committed Aug 2, 2019
1 parent 3314d60 commit 3d15ee1
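
The pattern repeated throughout this commit: call sites that used to hard-cast a JIT type with expect<DimensionedTensorType>() and read its properties directly are rewritten to wrap the type in a ProfiledTensorType view and check its optional accessors (scalarType(), device(), sizes()), which are empty when the property was never recorded. A minimal before/after sketch of that pattern, written against the API as it appears in this diff (describeOld/describeNew are hypothetical helpers, not part of the commit):

#include <ATen/core/jit_type.h>
#include <string>

using namespace c10;

// Before: aborts unless the type really is a DimensionedTensorType.
std::string describeOld(const TypePtr& t) {
  auto tt = t->expect<DimensionedTensorType>(); // hard failure otherwise
  return toString(tt->scalarType());            // scalar type always present
}

// After: accepts any tensor type; every property is optional.
std::string describeNew(const TypePtr& t) {
  auto pt = ProfiledTensorType::create(t);
  if (auto scalar_type = pt->scalarType()) {    // c10::optional<at::ScalarType>
    return toString(*scalar_type);
  }
  return "unknown";                             // nothing known about this value
}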
Showing 11 changed files with 74 additions and 44 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/jit_type.h
@@ -580,6 +580,7 @@ struct CAFFE2_API ProfiledTensorType : public TensorType {
     return requires_grad_ ? *requires_grad_ : false;
   }
 
+
   bool operator==(const Type& rhs) const override {
     if (rhs.kind() != kind()) {
       return false;
22 changes: 10 additions & 12 deletions torch/csrc/jit/fuser/codegen.cpp
@@ -91,9 +91,8 @@ static std::string variableType(const std::shared_ptr<c10::Type>& t) {
     return "double";
   } else if (t->kind() == TypeKind::BoolType) {
     return "bool";
-  } else if (t->kind() == TypeKind::DimensionedTensorType) {
-    auto const tt = t->cast<DimensionedTensorType>();
-    return calcScalarTypeName(tt->scalarType());
+  } else if (auto scalar_type = ProfiledTensorType::create(t)->scalarType()) {
+    return calcScalarTypeName(*scalar_type);
   }
   // something went wrong with the type analysis during shape propagation
   throw std::runtime_error(
@@ -116,9 +115,8 @@ static std::string typeCastedValueName(
     // cast here, which may end up being a no-op if the tensor's scalar type
     // is `double`.
     return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
-  } else if (t->kind() == TypeKind::DimensionedTensorType) {
-    auto const tt = t->cast<DimensionedTensorType>();
-    if (tt->scalarType() != outtype) {
+  } else if (auto scalar_type = ProfiledTensorType::create(t)->scalarType()) {
+    if (*scalar_type != outtype) {
       return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
     }
     return vn;
@@ -261,25 +259,25 @@ static std::string encodeRHS(const Node* n) {
     return encodeSpecialRHS(n, env);
   } else {
     size_t i = 0;
-    auto outtype = n->output()
-                       ->type()
-                       ->expect<c10::DimensionedTensorType const>()
-                       ->scalarType();
+
+    auto outtype =
+        ProfiledTensorType::create(n->output()->type())->scalarType();
+    TORCH_INTERNAL_ASSERT(outtype);
 
     for (auto in : n->inputs()) {
       // PyTorch converts (scalar) argument types to result before applying the
       // operator e.g. 1.4-torch.tensor(3) = -2
       env.s(
           std::to_string(i),
-          typeCastedValueName(in->type(), outtype, valueName(in)));
+          typeCastedValueName(in->type(), *outtype, valueName(in)));
       // Uncasted operands only used for comparison operators
       env.s(std::to_string(i) + "_nocast", valueName(in));
       i++;
     }
 
     const auto& templ = simple_map_ops.at(n->kind());
     const char* str = nullptr;
-    if (outtype == at::kFloat) {
+    if (*outtype == at::kFloat) {
       str = templ.for_float;
     } else {
       str = templ.for_double;
10 changes: 8 additions & 2 deletions torch/csrc/jit/fuser/compiler.cpp
@@ -203,6 +203,10 @@ std::shared_ptr<FusedKernel> compileKernel(
 
   for (size_t i = 0; i < input_desc.size(); i++) {
     const auto& desc = input_desc[i];
+
+    // TODO: can't get rid of this use of DimensionedTensorType yet
+    // until we switch to ProfilingGraphExecutor, so we don't have to
+    // run PropagateInputShapes below
     graph->inputs()[i]->setType(DimensionedTensorType::create(
         desc.scalar_type,
         device,
@@ -247,8 +251,10 @@ std::shared_ptr<FusedKernel> compileKernel(
     if (o->node()->kind() == prim::FusedConcat) {
       sizes.at(o->node()->i(attr::dim)) *= o->node()->inputs().size();
     }
-    auto scalar_type = o->type()->expect<c10::DimensionedTensorType const>()->scalarType();
-    auto type = CompleteTensorType::create(scalar_type, device, sizes);
+
+    auto scalar_type = ProfiledTensorType::create(o->type())->scalarType();
+    TORCH_INTERNAL_ASSERT(scalar_type);
+    auto type = CompleteTensorType::create(*scalar_type, device, sizes);
     output_desc.emplace_back(type);
     const auto& desc = output_desc.back();
 
10 changes: 7 additions & 3 deletions torch/csrc/jit/passes/decompose_ops.cpp
@@ -37,9 +37,13 @@ bool isDecomposableNorm(Node* normalize_op) {
       "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor",
   };
   Value* input = normalize_op->namedInput(attr::input);
-  auto tensor_type = input->type()->cast<DimensionedTensorType>();
-  // As of now, we do the decomposition for batchnorm/layernorm on GPU device only
-  if (!tensor_type || tensor_type->device().is_cpu()) {
+  if (!input->type()->isSubtypeOf(TensorType::get())) {
+    return false;
+  }
+  auto device = ProfiledTensorType::create(input->type())->device();
+  // As of now, we do the decomposition for batchnorm/layernorm on GPU device
+  // only
+  if (!device || (*device).is_cpu()) {
     return false;
   }
 
11 changes: 7 additions & 4 deletions torch/csrc/jit/passes/graph_fuser.cpp
@@ -175,13 +175,16 @@ struct GraphFuser {
   }
 
   bool isFusableDevice(Value *v) {
-    auto tensor_type = v->type()->cast<DimensionedTensorType>();
-    if (!tensor_type) {
+    if (!v->type()->isSubtypeOf(TensorType::get())) {
       return true;
     }
-    if (tensor_type->device().is_cpu()) {
+    auto device = ProfiledTensorType::create(v->type())->device();
+    if (!device) {
+      return true;
+    }
+    if ((*device).is_cpu()) {
       return canFuseOnCPU();
-    } else if (tensor_type->device().is_cuda()) {
+    } else if ((*device).is_cuda()) {
       return canFuseOnGPU();
     }
     throw std::runtime_error("Unknown device");
6 changes: 4 additions & 2 deletions torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp
@@ -29,8 +29,10 @@ Node* InsertCastForCond(Value* cond_val, Graph* graph, Node* consumer_node) {
 
 bool IsCondCastRequired(Value* cond_val) {
   const auto& type = cond_val->type();
-  if (type->isSubclass(TypeKind::DimensionedTensorType)) {
-    return type->expect<DimensionedTensorType>()->scalarType() != c10::kBool;
+  if (type->isSubtypeOf(TensorType::get())) {
+    if (auto scalar_type = ProfiledTensorType::create(type)->scalarType()) {
+      return *scalar_type != c10::kBool;
+    }
   }
   return !type->isSubclass(TypeKind::BoolType);
 }
1 change: 1 addition & 0 deletions torch/csrc/jit/pybind_utils.h
@@ -397,6 +397,7 @@ inline IValue toIValue(
       return repeated;
     }
     case TypeKind::DimensionedTensorType:
+    case TypeKind::ProfiledTensorType:
     case TypeKind::TensorType:
       return c10::impl::toList(py::cast<std::vector<at::Tensor>>(obj));
     default:
11 changes: 8 additions & 3 deletions torch/csrc/jit/python_ir.cpp
@@ -624,8 +624,11 @@ void initPythonIRBindings(PyObject* module_) {
       .def("kind", [](const Type& t) { return typeKindToString(t.kind()); })
       .def(
           "dim",
-          [](const Type& t) {
-            return t.expect<DimensionedTensorType>()->dim();
+          [](Type& t) {
+            auto vshape =
+                ProfiledTensorType::create(t.shared_from_this())->sizes();
+            return vshape.size() ? py::cast(*vshape.size())
+                                 : py::cast<py::none>(Py_None);
           })
       .def(
           "sizes",
@@ -642,7 +645,9 @@ void initPythonIRBindings(PyObject* module_) {
       .def(
           "scalarType",
           [](Type& t) {
-            return toString(t.expect<DimensionedTensorType>()->scalarType());
+            auto scalar_type =
+                ProfiledTensorType::create(t.shared_from_this())->scalarType();
+            return (scalar_type) ? toString(*scalar_type) : nullptr;
           })
       .def(
           "__eq__",
1 change: 0 additions & 1 deletion torch/jit/__init__.py
@@ -35,7 +35,6 @@
 if sys.version_info[0] > 2:
     import pathlib
 
-
 def _parse_env(name, default, true_message, false_message):
     value = os.environ.get(name)
     if value is None:
17 changes: 13 additions & 4 deletions torch/onnx/symbolic_helper.py
@@ -39,11 +39,18 @@
 # TensorType - This is a Tensor, but we don't know anything about its
 #              properties (e.g. scalar type, # dims, shapes).
 #              Appears as `Tensor` in graph print-outs.
+# ProfiledTensorType <: TensorType - Denotes a Tensor for which we know the
+#                                    concrete sizes in addition to the information
+#                                    contained in TensorTyper. This adds a sizes()
+#                                    method which can be used to retrieve the
+#                                    concrete sizes.
+# @deprecated
 # DimensionedTensorType <: TensorType - Denotes a Tensor for which we know the scalar
 #                                       type and number of dimensions, but not the concrete
 #                                       shapes. For example, appears as 'Float(*, *)' in
 #                                       graph print-outs. Useful accessor methods include
 #                                       dim() and scalarType()
+# @deprecated
 # CompleteTensorType <: DimensionedTensorType - Denotes a Tensor for which we know the
 #                                               concrete sizes in addition to the information
 #                                               contained in TensorTyper. This adds a sizes()
@@ -166,11 +173,13 @@ def _if_scalar_type_as(g, self, tensor):
     """
     if isinstance(self, torch._C.Value):
         return self
-    elif _is_complete_or_dimensioned_tensor_type(tensor):
-        ty = tensor.type().scalarType().lower()
+
+    scalar_type = tensor.type().scalarType()
+    if scalar_type:
+        ty = scalar_type.lower()
         return getattr(self, ty)()
-    else:
-        return self
+
+    return self
 
 
 def _is_value(x):
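
The hierarchy documented in the symbolic_helper.py comment above is queried from C++ through the same optional accessors. A hypothetical sketch in the spirit of the isFusableDevice and isDecomposableNorm changes in this commit (the helper name is illustrative, not part of the codebase):

#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/ir.h>

// Hypothetical helper: true only when v is a tensor whose device is known
// to be CUDA; mirrors the check order used in graph_fuser.cpp above.
bool isKnownCudaTensor(torch::jit::Value* v) {
  // Reject anything that is not a tensor at all.
  if (!v->type()->isSubtypeOf(c10::TensorType::get())) {
    return false;
  }
  // ProfiledTensorType::create gives a uniform view over the legacy tensor
  // types; device() is empty when the device was never recorded.
  auto device = c10::ProfiledTensorType::create(v->type())->device();
  return device && device->is_cuda();
}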
28 changes: 15 additions & 13 deletions torch/onnx/symbolic_opset9.py
@@ -65,7 +65,6 @@ def unused(g):
     n.setType(OptionalType.ofTensor())
     return n
 
-
 def _shape_as_tensor(g, input):
     return g.op('Shape', input)
 
@@ -419,14 +418,15 @@ def squeeze(g, self, dim=None):
     # Handle negative dims
     for i, dim in enumerate(dims):
         if dim < 0:
-            if sym_help._is_complete_or_dimensioned_tensor_type(self):
+            rank = self.type().dim()
+            if rank:
                 warnings.warn("ONNX export squeeze with negative axis " + str(dim) +
                               " might cause the onnx model to be incorrect. " +
                               "Negative axis is not supported in ONNX. " +
-                              "Axis is converted to " + str(dim + self.type().dim()) +
+                              "Axis is converted to " + str(dim + rank) +
                               " based on input shape at export time. " +
                               "Passing an tensor of different rank in execution will be incorrect.")
-                dims[i] += self.type().dim()
+                dims[i] += rank
             else:
                 return _unimplemented('squeeze', 'negative axis with unknown input rank')
 
@@ -498,15 +498,17 @@ def softmax(g, input, dim, dtype=None):
     # their semantics are equivalent.
     # So use softmax when dim and axis both equal to ndim - 1
     # otherwise compute softmax using a subgraph with other operators
-    if sym_help._is_complete_or_dimensioned_tensor_type(input):
+    input_dim = input.type().dim()
+    if input_dim:
         if dim < 0:
-            dim = input.type().dim() + dim
-        if input.type().dim() == dim + 1:
+            dim = input_dim + dim
+        if input_dim == dim + 1:
             softmax = g.op('Softmax', input, axis_i=dim)
             if dtype and dtype.node().kind() != 'prim::Constant':
                 parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
                 softmax = g.op("Cast", softmax, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
             return softmax
+
     exp = g.op('Exp', input)
     sum = g.op('ReduceSum', exp, axes_i=[dim])
     softmax = g.op('Div', exp, sum)
@@ -515,7 +517,6 @@
             softmax = g.op("Cast", softmax, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
     return softmax
 
-
 @parse_args('v', 't', 'v')
 def softplus(g, self, beta, threshold):
     if beta != 1:
@@ -977,14 +978,14 @@ def selu(g, input):
 def index_select(g, self, dim, index):
     # In case of a scaler index, index_select returns a tensor with the same rank as the input.
     # To match this bahavior in ONNX, we make index a 1D tensor so that the following gather
-    # also produces a tensor with the same rank as the input. 
+    # also produces a tensor with the same rank as the input.
     index_const = sym_help._maybe_get_scalar(index)
     if not sym_help._is_value(index_const):
         # Index is a constant scalar. Make it a size 1 constant tensor.
         index = g.op("Constant", value_t=torch.LongTensor([index_const]))
     elif sym_help._is_complete_or_dimensioned_tensor_type(index):
         if index.type().dim() == 0:
-            # Index is a scalar. Reshape it to a size 1 tensor. 
+            # Index is a scalar. Reshape it to a size 1 tensor.
             index = g.op("Reshape", index, g.op("Constant", value_t=torch.LongTensor([1])))
     return g.op("Gather", self, index, axis_i=dim)
 
@@ -1231,14 +1232,15 @@ def alias(g, self):
 def unsqueeze(g, self, dim):
     # Handle negative dim
     if dim < 0:
-        if sym_help._is_complete_or_dimensioned_tensor_type(self):
+        rank = self.type().dim()
+        if rank:
             warnings.warn("ONNX export unsqueeze with negative axis " + str(dim) +
                           " might cause the onnx model to be incorrect. " +
                           "Negative axis is not supported in ONNX. " +
-                          "Axis is converted to " + str(dim + self.type().dim() + 1) +
+                          "Axis is converted to " + str(dim + rank + 1) +
                           " based on input shape at export time. " +
                           "Passing an tensor of different rank in execution will be incorrect.")
-            dim = dim + self.type().dim() + 1
+            dim = dim + rank + 1
         else:
             return _unimplemented('unsqueeze', 'negative axis with unknown input rank')
 
