Skip to content

Commit

Permalink
allow missing FCorrectLayout (dmlc#457)
Browse files Browse the repository at this point in the history
* allow missing FCorrectLayout

* misunderstood OpMap[], fix
  • Loading branch information
yzhliu authored and tqchen committed Apr 28, 2018
1 parent 03827b1 commit d011150
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/pass/correct_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ nnvm::NodePtr CreateLayoutTransformNode(const Layout& src,
using LayoutAttrDict = std::unordered_map<const Node*, std::vector<Layout> >;

/*!
* \brief A simple layout infer pass that will
* \brief A simple layout infer & correct pass that will
* insert layout transform nodes automatically.
*/
nnvm::Graph CorrectLayout(nnvm::Graph src) {
static auto& op_infer_layout =
static auto& op_correct_layout =
nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout");

const IndexedGraph& idx = src.indexed_graph();
Expand Down Expand Up @@ -91,14 +91,13 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
}
}

const auto& flayout = op_infer_layout[new_node->op()];
CHECK(flayout != nullptr) << "Attribute FCorrectLayout"
<< " is not registered by op " << inode.source->op()->name
<< " we are not able to complete layout transform.";
CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts))
if (op_correct_layout.count(new_node->op())) {
const auto &flayout = op_correct_layout[new_node->op()];
CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts))
<< "Layout infer fail";
CHECK_EQ(request_ilayouts.size(), num_inputs);
CHECK_EQ(produce_olayouts.size(), num_outputs);
CHECK_EQ(request_ilayouts.size(), num_inputs);
CHECK_EQ(produce_olayouts.size(), num_outputs);
}

// update new layouts
new_layouts[new_node.get()] = std::move(produce_olayouts);
Expand Down

0 comments on commit d011150

Please sign in to comment.