Skip to content

Commit

Permalink
clean up NamedTuple creation API (#28189)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch/pytorch#28189

This makes it a separate createNamed function. The existing API resulted
in poor usage in fbcode, which in turn caused bugs in TorchScript programs.

Test Plan: Imported from OSS

Differential Revision: D17970220

Pulled By: zdevito

fbshipit-source-id: 59b082a726f56bec1c8d10d410db829f4aa271ea
  • Loading branch information
zdevito authored and facebook-github-bot committed Oct 22, 2019
1 parent 03d24db commit 6d689e2
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 49 deletions.
19 changes: 8 additions & 11 deletions aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -800,19 +800,15 @@ using TupleTypePtr = std::shared_ptr<TupleType>;
using NameList = std::vector<std::string>;
// This type represents a Tuple
struct CAFFE2_API TupleType : public NamedType {
static std::shared_ptr<FunctionSchema> namedTupleSchemaFromNamesAndTypes(
c10::QualifiedName,
std::vector<std::string>,
std::vector<TypePtr>);

static TupleTypePtr createNamed(const c10::optional<c10::QualifiedName>& name,
const std::vector<std::string>& field_names,
const std::vector<TypePtr>& types);
static TupleTypePtr create(
std::vector<TypePtr> types,
c10::optional<c10::QualifiedName> name = c10::nullopt,
std::shared_ptr<FunctionSchema> schema = nullptr) {
std::vector<TypePtr> types) {
return TupleTypePtr(new TupleType(
std::move(types),
std::move(name),
std::move(schema))); // NOLINT(modernize-make-shared)
c10::nullopt,
nullptr)); // NOLINT(modernize-make-shared)
}

at::ArrayRef<TypePtr> elements() const {
Expand All @@ -832,7 +828,8 @@ struct CAFFE2_API TupleType : public NamedType {
}
TypePtr createWithContained(
std::vector<TypePtr> contained_types) const override {
return create(std::move(contained_types));
return std::shared_ptr<TupleType>(
new TupleType(std::move(contained_types), name(), schema()));
}
const std::shared_ptr<FunctionSchema>& schema() const {
return schema_;
Expand Down
13 changes: 7 additions & 6 deletions aten/src/ATen/core/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,10 +440,10 @@ std::ostream& operator<<(std::ostream & out, const VaryingShape & vs) {
return out;
}

std::shared_ptr<FunctionSchema> TupleType::namedTupleSchemaFromNamesAndTypes(
c10::QualifiedName qualName,
std::vector<std::string> field_names,
std::vector<TypePtr> field_types) {
TupleTypePtr TupleType::createNamed(
const c10::optional<c10::QualifiedName>& qualName,
const std::vector<std::string>& field_names,
const std::vector<TypePtr>& field_types) {
TORCH_INTERNAL_ASSERT(field_names.size() == field_types.size());
std::vector<Argument> arguments;
for (size_t i = 0; i < field_names.size(); ++i) {
Expand All @@ -454,11 +454,12 @@ std::shared_ptr<FunctionSchema> TupleType::namedTupleSchemaFromNamesAndTypes(
}

auto schema = std::make_shared<FunctionSchema>(
/*name=*/qualName.name(),
/*name=*/qualName.value_or(c10::QualifiedName()).name(),
/*overload_name=*/std::string(""),
/*arguments=*/arguments,
/*returns=*/std::vector<Argument>{});
return schema;
return std::shared_ptr<TupleType>(new TupleType(
field_types, qualName, schema)); // NOLINT(modernize-make-shared)
}

TupleType::TupleType(
Expand Down
6 changes: 1 addition & 5 deletions torch/csrc/jit/import_source.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,7 @@ struct SourceImporterImpl : public Resolver,
field_types.emplace_back(std::move(type));
}

auto tt = TupleType::create(
field_types,
qualified_name,
TupleType::namedTupleSchemaFromNamesAndTypes(
qualified_name, field_names, field_types));
auto tt = TupleType::createNamed(qualified_name, field_names, field_types);
cu_->register_type(tt);
}

Expand Down
17 changes: 9 additions & 8 deletions torch/csrc/jit/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1384,15 +1384,16 @@ Node* Graph::createWithSubgraph(Symbol kind) {
return n;
}

Node* Graph::createTuple(
at::ArrayRef<Value*> values,
c10::optional<c10::QualifiedName> qualname,
std::shared_ptr<FunctionSchema> schema) {
auto types = fmap(values, [](Value* v) { return v->type(); });
auto tt = TupleType::create(
std::move(types), std::move(qualname), std::move(schema));
Node* Graph::createTuple(at::ArrayRef<Value*> values, TupleTypePtr tuple_type) {
TORCH_INTERNAL_ASSERT(
!tuple_type || tuple_type->schema(),
"only pass tuple_type when creating a named tuple");
if (!tuple_type) {
auto types = fmap(values, [](Value* v) { return v->type(); });
tuple_type = TupleType::create(std::move(types));
}
auto n = create(prim::TupleConstruct, values);
n->output()->setType(tt);
n->output()->setType(tuple_type);
return n;
}

Expand Down
3 changes: 1 addition & 2 deletions torch/csrc/jit/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1064,8 +1064,7 @@ struct Graph {
TORCH_API Node* createDifferentiableSubgraph();
TORCH_API Node* createTuple(
at::ArrayRef<Value*> values,
c10::optional<c10::QualifiedName> qualname = c10::nullopt,
std::shared_ptr<FunctionSchema> schema=nullptr);
TupleTypePtr optional_named_tuple = nullptr);
TORCH_API Node* createTupleUnpack(Value* v);
TORCH_API Node* createTupleIndex(
Value* tup,
Expand Down
8 changes: 4 additions & 4 deletions torch/csrc/jit/passes/shape_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,10 +608,10 @@ class ShapePropagator {
// We refresh the tuple type, because the input types could have been
// refined.
auto orig_type = node->output()->type()->expect<TupleType>();
node->output()->setType(TupleType::create(
fmap(node->inputs(), [](Value* v) { return v->type(); }),
orig_type->name(),
orig_type->schema()));
auto new_types =
fmap(node->inputs(), [](Value* v) { return v->type(); });
node->output()->setType(
orig_type->createWithContained(std::move(new_types)));
return;
}
case prim::TupleUnpack: {
Expand Down
6 changes: 1 addition & 5 deletions torch/csrc/jit/script/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,7 @@ struct PythonResolver : public Resolver {
std::tuple<std::string, decltype(fields), decltype(annotations)>>(
props);

auto tt = TupleType::create(
annotations,
qualifiedName,
TupleType::namedTupleSchemaFromNamesAndTypes(
qualifiedName, fields, annotations));
auto tt = TupleType::createNamed(qualifiedName, fields, annotations);
if (auto type = get_python_cu()->get_type(qualifiedName)) {
TORCH_CHECK(
type->isSubtypeOf(tt),
Expand Down
10 changes: 5 additions & 5 deletions torch/csrc/jit/script/schema_matching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,13 +513,13 @@ static Value* packOutputs(
return values[0];
}
std::shared_ptr<FunctionSchema> schema;
TupleTypePtr named_tuple = nullptr;
if (field_names) {
schema = TupleType::namedTupleSchemaFromNamesAndTypes(c10::QualifiedName(), field_names.value(), fmap(values, [](Value* v) { return v->type(); }));
auto types = fmap(values, [](Value* v) { return v->type(); });
named_tuple = TupleType::createNamed(
c10::nullopt, field_names.value(), std::move(types));
}
return g
.insertNode(
g.createTuple(values, c10::nullopt, std::move(schema)))
->output();
return g.insertNode(g.createTuple(values, named_tuple))->output();
}

// Given a successful match between operator schema and symbol, emit a node
Expand Down
4 changes: 1 addition & 3 deletions torch/csrc/jit/script/sugared_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,7 @@ std::shared_ptr<SugaredValue> NamedTupleConstructor::call(

auto self =
g.insertNode(
g.createTuple(
matched_schema.inputs, std::move(qualname), std::move(schema))
->setSourceRange(loc))
g.createTuple(matched_schema.inputs, type_)->setSourceRange(loc))
->output();
self->setType(type_);

Expand Down

0 comments on commit 6d689e2

Please sign in to comment.