Preserve method parameter names (#16750)
Summary:
Fixes #16591

This uses Value::uniqueNameBase() so that parameters do not end up with numeric suffixes. It also changes next_id from a single global counter to a per-base-name counter, which fixes the name jittering seen when re-importing an already re-numbered graph.
Pull Request resolved: pytorch/pytorch#16750

Differential Revision: D13960282

Pulled By: zdevito

fbshipit-source-id: 2156f581d9b95d77bf1f1252074e800b19116555
zdevito authored and facebook-github-bot committed Feb 5, 2019
1 parent f8d4a14 commit 6efa40e
Showing 6 changed files with 65 additions and 66 deletions.
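
Before the per-file diffs, a minimal, self-contained sketch of the per-base-name naming scheme the summary describes. The struct and function names here (NameGenerator, genName) are invented for illustration only; the real logic is genNameImpl in torch/csrc/jit/passes/python_print.cpp, shown in the last diff below.

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Illustrative stand-in for the printer's name generation, not the actual PythonPrintPass.
struct NameGenerator {
  std::unordered_set<std::string> used;
  std::unordered_set<std::string> reserved = {"def", "return", "while"};
  // One counter per base name (as in this commit); before the change a single
  // global size_t meant a collision on any base name bumped the shared counter.
  std::unordered_map<std::string, size_t> next_id;

  std::string genName(const std::string& candidate) {
    std::string name = candidate;
    // On collision, suffix the candidate with a counter keyed per name
    // rather than with one global counter.
    while (used.count(name) || reserved.count(name)) {
      name = candidate + std::to_string(next_id[name]++);
    }
    used.insert(name);
    return name;
  }
};

int main() {
  NameGenerator g;
  std::cout << g.genName("input") << "\n";   // input
  std::cout << g.genName("input") << "\n";   // input0
  std::cout << g.genName("weight") << "\n";  // weight
  std::cout << g.genName("weight") << "\n";  // weight0 (a global counter would give weight1)
  return 0;
}

Because each base name owns its own counter, suffixes stay stable across re-import/re-print round trips, which is the jittering the summary mentions; the expect-file diffs below show the resulting input/input0, weight/weight0 style names.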
34 changes: 17 additions & 17 deletions test/expect/TestFuser.test_lstm_concat_cuda.expect
@@ -1,15 +1,15 @@
-graph(%input_1 : Float(*, *)
-      %input : Float(*, *)
+graph(%input : Float(*, *)
+      %input0 : Float(*, *)
       %cx : Float(*, *)
-      %weight_1 : Float(*, *)
       %weight : Float(*, *)
-      %bias_1 : Float(*)
-      %bias : Float(*)) {
-  %7 : Float(*, *) = aten::t(%weight_1)
-  %8 : Float(*, *) = aten::mm(%input_1, %7)
-  %9 : Float(*, *) = aten::t(%weight)
-  %10 : Float(*, *) = aten::mm(%input, %9)
-  %11 : Tensor[] = prim::ListConstruct(%bias_1, %8, %bias, %10)
+      %weight0 : Float(*, *)
+      %bias : Float(*)
+      %bias0 : Float(*)) {
+  %7 : Float(*, *) = aten::t(%weight)
+  %8 : Float(*, *) = aten::mm(%input, %7)
+  %9 : Float(*, *) = aten::t(%weight0)
+  %10 : Float(*, *) = aten::mm(%input0, %9)
+  %11 : Tensor[] = prim::ListConstruct(%bias, %8, %bias0, %10)
   %12 : Tensor[] = aten::broadcast_tensors(%11)
   %13 : Tensor, %14 : Tensor, %15 : Tensor, %16 : Tensor = prim::ListUnpack(%12)
   %17 : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
@@ -37,15 +37,15 @@ with prim::FusionGroup_0 = graph(%0 : Float(*, *)
   %31 : Float(*, *) = aten::add(%27, %23, %21)
   %32 : Float(*, *) = aten::add(%28, %24, %21)
   %33 : Float(*, *) = aten::add(%29, %25, %21)
-  %ingate : Float(*, *) = aten::sigmoid(%30)
-  %forgetgate : Float(*, *) = aten::sigmoid(%31)
-  %cellgate : Float(*, *) = aten::tanh(%32)
-  %outgate : Float(*, *) = aten::sigmoid(%33)
-  %38 : Float(*, *) = aten::mul(%forgetgate, %0)
-  %39 : Float(*, *) = aten::mul(%ingate, %cellgate)
+  %ingate0 : Float(*, *) = aten::sigmoid(%30)
+  %forgetgate0 : Float(*, *) = aten::sigmoid(%31)
+  %cellgate0 : Float(*, *) = aten::tanh(%32)
+  %outgate0 : Float(*, *) = aten::sigmoid(%33)
+  %38 : Float(*, *) = aten::mul(%forgetgate0, %0)
+  %39 : Float(*, *) = aten::mul(%ingate0, %cellgate0)
   %cy : Float(*, *) = aten::add(%38, %39, %21)
   %41 : Float(*, *) = aten::tanh(%cy)
-  %hy : Float(*, *) = aten::mul(%outgate, %41)
+  %hy : Float(*, *) = aten::mul(%outgate0, %41)
   %43 : Float(*, *) = prim::FusedConcat[dim=0](%hy, %cy)
   return (%43);
 }
34 changes: 17 additions & 17 deletions test/expect/TestFuser.test_lstm_traced_cuda.expect
@@ -1,15 +1,15 @@
-graph(%input_1 : Float(*, *)
-      %input : Float(*, *)
+graph(%input : Float(*, *)
+      %input0 : Float(*, *)
       %cx : Float(*, *)
-      %weight_1 : Float(*, *)
       %weight : Float(*, *)
-      %bias_1 : Float(*)
-      %bias : Float(*)) {
-  %7 : Float(*, *) = aten::t(%weight_1)
-  %8 : Float(*, *) = aten::mm(%input_1, %7)
-  %9 : Float(*, *) = aten::t(%weight)
-  %10 : Float(*, *) = aten::mm(%input, %9)
-  %11 : Tensor[] = prim::ListConstruct(%bias_1, %8, %bias, %10)
+      %weight0 : Float(*, *)
+      %bias : Float(*)
+      %bias0 : Float(*)) {
+  %7 : Float(*, *) = aten::t(%weight)
+  %8 : Float(*, *) = aten::mm(%input, %7)
+  %9 : Float(*, *) = aten::t(%weight0)
+  %10 : Float(*, *) = aten::mm(%input0, %9)
+  %11 : Tensor[] = prim::ListConstruct(%bias, %8, %bias0, %10)
   %12 : Tensor[] = aten::broadcast_tensors(%11)
   %13 : Tensor, %14 : Tensor, %15 : Tensor, %16 : Tensor = prim::ListUnpack(%12)
   %17 : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
@@ -38,14 +38,14 @@ with prim::FusionGroup_0 = graph(%0 : Float(*, *)
   %31 : Float(*, *) = aten::add(%27, %23, %21)
   %32 : Float(*, *) = aten::add(%28, %24, %21)
   %33 : Float(*, *) = aten::add(%29, %25, %21)
-  %ingate : Float(*, *) = aten::sigmoid(%30)
-  %forgetgate : Float(*, *) = aten::sigmoid(%31)
-  %cellgate : Float(*, *) = aten::tanh(%32)
-  %outgate : Float(*, *) = aten::sigmoid(%33)
-  %38 : Float(*, *) = aten::mul(%forgetgate, %0)
-  %39 : Float(*, *) = aten::mul(%ingate, %cellgate)
+  %ingate0 : Float(*, *) = aten::sigmoid(%30)
+  %forgetgate0 : Float(*, *) = aten::sigmoid(%31)
+  %cellgate0 : Float(*, *) = aten::tanh(%32)
+  %outgate0 : Float(*, *) = aten::sigmoid(%33)
+  %38 : Float(*, *) = aten::mul(%forgetgate0, %0)
+  %39 : Float(*, *) = aten::mul(%ingate0, %cellgate0)
   %cy : Float(*, *) = aten::add(%38, %39, %21)
   %41 : Float(*, *) = aten::tanh(%cy)
-  %42 : Float(*, *) = aten::mul(%outgate, %41)
+  %42 : Float(*, *) = aten::mul(%outgate0, %41)
   return (%42, %cy);
 }
16 changes: 8 additions & 8 deletions test/expect/TestJit.test_pretty_printer-loop_use_test.expect
@@ -1,10 +1,10 @@
 def graph(self,
-    y_1: Tensor) -> Tuple[Tensor, Tensor]:
-  x = torch.add(y_1, 1, 1)
-  z_1 = torch.add(x, 5, 1)
-  y, z = y_1, z_1
-  _0 = bool(torch.lt(y_1, 8))
+    y: Tensor) -> Tuple[Tensor, Tensor]:
+  x = torch.add(y, 1, 1)
+  z = torch.add(x, 5, 1)
+  y0, z0 = y, z
+  _0 = bool(torch.lt(y, 8))
   while _0:
-    y_2 = torch.add_(y, 1, 1)
-    _0, y, z = bool(torch.lt(y_2, 8)), y_2, x
-  return (x, z)
+    y1 = torch.add_(y0, 1, 1)
+    _0, y0, z0 = bool(torch.lt(y1, 8)), y1, x
+  return (x, z0)
22 changes: 11 additions & 11 deletions test/expect/TestJit.test_pretty_printer-while_if_test.expect
@@ -1,14 +1,14 @@
 def graph(self,
-    a_1: Tensor,
-    b_1: Tensor) -> Tensor:
-  a, b, c = a_1, b_1, 0
-  _0 = bool(torch.lt(a_1, 10))
+    a: Tensor,
+    b: Tensor) -> Tensor:
+  a0, b0, c = a, b, 0
+  _0 = bool(torch.lt(a, 10))
   while _0:
-    a_2 = torch.add(a, 1, 1)
-    b_2 = torch.add(b, 1, 1)
-    if bool(torch.gt(a_2, b_2)):
-      c_2 = 2
+    a1 = torch.add(a0, 1, 1)
+    b1 = torch.add(b0, 1, 1)
+    if bool(torch.gt(a1, b1)):
+      c0 = 2
     else:
-      c_2 = 3
-    _0, a, b, c = bool(torch.lt(a_2, 10)), a_2, b_2, c_2
-  return torch.add(torch.add(a, 1, 1), c, 1)
+      c0 = 3
+    _0, a0, b0, c = bool(torch.lt(a1, 10)), a1, b1, c0
+  return torch.add(torch.add(a0, 1, 1), c, 1)
16 changes: 8 additions & 8 deletions test/expect/TestJit.test_pretty_printer-while_test.expect
@@ -1,10 +1,10 @@
 def graph(self,
-    a_1: Tensor,
-    i_1: Tensor) -> Tensor:
-  a, i = a_1, i_1
-  _0 = bool(torch.lt(i_1, 3))
+    a: Tensor,
+    i: Tensor) -> Tensor:
+  a0, i0 = a, i
+  _0 = bool(torch.lt(i, 3))
   while _0:
-    a_2 = torch.mul_(a, a)
-    i_2 = torch.add_(i, 1, 1)
-    _0, a, i = bool(torch.lt(i_2, 3)), a_2, i_2
-  return a
+    a1 = torch.mul_(a0, a0)
+    i1 = torch.add_(i0, 1, 1)
+    _0, a0, i0 = bool(torch.lt(i1, 3)), a1, i1
+  return a0
9 changes: 4 additions & 5 deletions torch/csrc/jit/passes/python_print.cpp
@@ -359,16 +359,17 @@ struct PythonPrintPass {
       buildConstantList(n, constants);
     buildConstantList(b->return_node(), constants);
   }
 
   // get a new name unique across calls to uniqueName() and
   // anything we have used.
-  size_t next_id = 0;
+  std::unordered_map<std::string, size_t> next_id;
+
   std::string genNameImpl(
       const std::string& candidate,
       std::unordered_set<std::string>& used) {
     std::string name = candidate;
     while (used.count(name) || reserved_names.count(name)) {
-      name = candidate + std::to_string(next_id++);
+      name = candidate + std::to_string(next_id[name]++);
     }
     used.insert(name);
     return name;
@@ -402,7 +403,7 @@ struct PythonPrintPass {
   // use the uniqueName if it was set, otherwise generate a name.
   std::string genUniqueNameFor(Value* v) {
     return genName(
-        v->hasUniqueName() ? makeValidIdentifier(v->uniqueName()) : "_");
+        v->hasUniqueName() ? makeValidIdentifier(v->uniqueNameBase()) : "_");
   }
 
   // map from Value to how it should be printed at each use
@@ -1006,7 +1007,6 @@ struct PythonPrintPass {
   }
   void printMethod(script::Method& method) {
     std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
-    ;
     createTensorToParameterNameMap(
         method.owner(), QualifiedName::create("self"), parameter_names);
     printMethod(method, parameter_names);
@@ -1027,7+1027,6 @@
   }
   void printModule(script::Module& module) {
     std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
-    ;
     createTensorToParameterNameMap(
         module, QualifiedName::create("self"), parameter_names);
     for (auto& method : module.get_methods()) {
