Skip to content

Commit

Permalink
fix restore layout in AlterOpLayout (dmlc#460)
Browse files Browse the repository at this point in the history
* fix restore layout in AlterOpLayout

* lint test case
  • Loading branch information
yzhliu authored and tqchen committed Apr 29, 2018
1 parent d011150 commit 60c685d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/compiler/alter_op_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ Graph AlterOpLayout(const Graph& src) {
if (new_nodes.count(inode.source)) {
const std::vector<Layout>& in_layouts =
in_layouts_of_node[new_nodes[inode.source]];
for (const auto& e : inode.inputs) {
ret_layouts[ret_idx.entry_id(e)] = in_layouts[e.index];
for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
ret_layouts[ret_idx.entry_id(e)] = in_layouts[i];
}
const std::vector<Layout>& out_layouts =
out_layouts_of_node[new_nodes[inode.source]];
Expand Down
6 changes: 5 additions & 1 deletion tests/python/compiler/test_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ def test_alter_conv2d_layout():
conv = sym.conv2d(data, name="conv", channels=16,
kernel_size=(3,3), padding=(1,1),
use_bias=False, layout="NCHW")
relu = sym.relu(conv, name="relu")
# split here
convs = sym.split(conv, indices_or_sections=2)
relus = [sym.relu(x, name="relu") for x in convs]
relu = sym.concatenate(*relus)
flatten = sym.flatten(relu, name="flatten")
softmax = sym.softmax(flatten, name="softmax")
g = graph.create(softmax)

g = g.apply("CorrectLayout")
g = graph_attr.set_dtype_inputs(g, "float32")
g = g.apply(["InferShape", "InferType"])
Expand Down

0 comments on commit 60c685d

Please sign in to comment.