Skip to content

Commit

Permalink
add symbol::GetChildren (dmlc#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong authored Feb 25, 2017
1 parent 767f818 commit 9d6b4e4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
2 changes: 2 additions & 0 deletions amalgamation/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
nnvm.d
nnvm.cc
6 changes: 6 additions & 0 deletions include/nnvm/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ class Symbol {
* including input variables and intermediate outputs.
*/
Symbol GetInternals() const;
/*
* \brief Get the direct inputs of the head node(s) of this symbol.
* \return symbol A new symbol whose output contains all the inputs of the head
* node(s).
*/
Symbol GetChildren() const;
/*!
* \brief Set additional attributes to current node.
*
Expand Down
13 changes: 13 additions & 0 deletions src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,19 @@ Symbol Symbol::GetInternals() const {
return ret;
}

Symbol Symbol::GetChildren() const {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol ret;
std::unordered_set<Node*> visited;
for (const auto& p : this->outputs) {
Node* node = p.node.get();
if (visited.count(node)) continue;
visited.insert(node);
ret.outputs.insert(ret.outputs.end(), node->inputs.begin(), node->inputs.end());
}
return ret;
}

void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs) {
Node* node = outputs[0].node.get();
for (const NodeEntry& e : outputs) {
Expand Down

0 comments on commit 9d6b4e4

Please sign in to comment.