Skip to content

Commit

Permalink
[ATTR/SYMBOL] Expose op_name attr to python (dmlc#132)
Browse files Browse the repository at this point in the history
* [ATTR/SYMBOL] Expose op_name attr to python

* fix xcode
  • Loading branch information
tqchen authored Jul 28, 2017
1 parent cbe2070 commit e9c2672
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ os:
- linux
- osx

osx_image: xcode8

env:
# code analysis
- TASK=lint
Expand Down
3 changes: 3 additions & 0 deletions src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,9 @@ bool Symbol::GetAttr(const std::string& key, std::string* out) const {
if (key == "name") {
*out = node->attrs.name;
return true;
} else if (key == "op_name") {
*out = node->attrs.op->name;
return true;
}
auto it = node->attrs.dict.find(key);
if (it == node->attrs.dict.end()) return false;
Expand Down
10 changes: 10 additions & 0 deletions tests/python/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ def test_copy():
name='exp', gpu=1, attr={"kk": "1"})
assert y.__copy__().debug_str() == y.debug_str()


def test_op_name():
x = sym.Variable('x')
y = sym.exp(x)
op_name = y.attr("op_name")
op_func = sym.__dict__[op_name]
z = op_func(x)


def test_control_dep():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv')
Expand All @@ -53,6 +62,7 @@ def test_control_dep():
t._add_control_deps([z, y])

if __name__ == "__main__":
test_op_name()
test_copy()
test_default_input()
test_compose()
Expand Down

0 comments on commit e9c2672

Please sign in to comment.