diff --git a/amalgamation/.gitignore b/amalgamation/.gitignore new file mode 100644 index 000000000..e808ea276 --- /dev/null +++ b/amalgamation/.gitignore @@ -0,0 +1,2 @@ +nnvm.d +nnvm.cc diff --git a/include/nnvm/symbolic.h b/include/nnvm/symbolic.h index 4d26947ca..ab836ad5e 100644 --- a/include/nnvm/symbolic.h +++ b/include/nnvm/symbolic.h @@ -137,6 +137,12 @@ class Symbol { * including input variables and intermediate outputs. */ Symbol GetInternals() const; + /* + * \brief Get the direct inputs of the head node(s) of this symbol. + * \return symbol A new symbol whose output contains all the inputs of the head + * node(s). + */ + Symbol GetChildren() const; /*! * \brief Set additional attributes to current node. * diff --git a/src/core/symbolic.cc b/src/core/symbolic.cc index a101deed7..51cc1fa9e 100644 --- a/src/core/symbolic.cc +++ b/src/core/symbolic.cc @@ -435,6 +435,19 @@ Symbol Symbol::GetInternals() const { return ret; } +Symbol Symbol::GetChildren() const { + static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); + Symbol ret; + std::unordered_set visited; + for (const auto& p : this->outputs) { + Node* node = p.node.get(); + if (visited.count(node)) continue; + visited.insert(node); + ret.outputs.insert(ret.outputs.end(), node->inputs.begin(), node->inputs.end()); + } + return ret; +} + void Symbol::SetAttrs(const std::vector >& attrs) { Node* node = outputs[0].node.get(); for (const NodeEntry& e : outputs) {