Skip to content

Commit

Permalink
[PYTHON] Check in a symbolic construction interface in python, start … (
Browse files Browse the repository at this point in the history
dmlc#4)

* [PYTHON] Check in a symbolic construction interface in python, start add graph API

* Graph API
  • Loading branch information
tqchen authored Jul 11, 2016
1 parent a5aee3a commit adc18b6
Show file tree
Hide file tree
Showing 18 changed files with 1,145 additions and 17 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ build
lib
*~
dmlc-core
cli_test
cli_test
*.pyc
74 changes: 70 additions & 4 deletions include/nnvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ typedef unsigned int nn_uint;
typedef void *AtomicSymbolCreator;
/*! \brief handle to a symbol that can be bind as operator */
typedef void *SymbolHandle;
/*! \brief handle to a AtomicSymbol */
typedef void *AtomicSymbolHandle;
/*! \brief handle to Graph */
typedef void *GraphHandle;

/*!
* \brief return str message of the last error
Expand Down Expand Up @@ -71,7 +71,7 @@ NNVM_DLL int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type = NULL);
const char **return_type);
/*!
* \brief Create an AtomicSymbol functor.
* \param creator the AtomicSymbolCreator
Expand Down Expand Up @@ -123,7 +123,18 @@ NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str);

/*!
* \brief Get string attribute from symbol
* \param symbol the source symbol
* \param key The key of the symbol.
* \param out The result attribute, can be NULL if the attribute do not exist.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol,
const char* key,
const char** out,
int *success);
/*!
* \brief Set string attribute from symbol.
* NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph.
Expand Down Expand Up @@ -216,4 +227,59 @@ NNVM_DLL int NNSymbolCompose(SymbolHandle sym,
const char** keys,
SymbolHandle* args);

// Graph IR API
/*!
* \brief create a graph handle from symbol
* \param symbol The symbol representing the graph.
* \param graph The graph handle created.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph);
/*!
* \brief free the graph handle
* \param handle The handle to be freed.
*/
NNVM_DLL int NNGraphFree(GraphHandle handle);
/*!
* \brief Get a new symbol from the graph.
* \param graph The graph handle.
* \param symbol The corresponding symbol
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
/*!
* \brief Get Set a std::string typed attribute to graph.
* \param handle The graph handle.
* \param key The key to the attribute.
* \param value The value to be exposed.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphSetStrAttr(GraphHandle handle,
const char* key,
const char* value);
/*!
* \brief Get Set a std::string typed attribute from graph attribute.
* \param handle The graph handle.
* \param key The key to the attribute.
* \param out The result attribute, can be NULL if the attribute do not exist.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphGetStrAttr(SymbolHandle handle,
const char* key,
const char** out,
int *success);
/*!
* \brief Apply pass on the src graph.
* \param src The source graph handle.
* \param num_pass The number of pass to be applied.
* \param pass_names The names of the pass.
* \param dst The result graph.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphApplyPass(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst);

#endif // NNVM_C_API_H_
4 changes: 2 additions & 2 deletions include/nnvm/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,10 @@ inline Op& Op::attr( // NOLINT(*)
vec.resize(index_ + 1,
std::make_pair(ValueType(), 0));
std::pair<ValueType, int>& p = vec[index_];
CHECK(p.second == 0 || p.first == value)
CHECK(p.second == 0)
<< "Attribute " << attr_name
<< " of operator " << this->name
<< " is already registered to a different value";
<< " is already registered.";
vec[index_] = std::make_pair(value, 1);
});
return *this;
Expand Down
9 changes: 9 additions & 0 deletions include/nnvm/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ class Symbol {
* \param attrs The attributes to set.
*/
void SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs);
/*!
* \brief Get attributes from the symbol.
* This only works for symbol with outputs from single operators.
* For grouped sybmbol, an error will be raised.
* \param key Key of the attribute. When key == "name", it returns the name attirbute.
* \param out the output value of the attribute.
* \return true if the attribute exists, false if the attribute do not exist.
*/
bool GetAttr(const std::string& key, std::string* out) const;
/*!
* \brief Get attribute dictionary from the symbol.
* For grouped sybmbol, an error will be raised.
Expand Down
10 changes: 10 additions & 0 deletions python/nnvm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/usr/bin/env python
# coding: utf-8
"""NNVM python API for ease of use and help new framework establish python API. """
from __future__ import absolute_import

from . import base
from . import symbol as sym
from . import symbol

__version__ = base.__version__
62 changes: 62 additions & 0 deletions python/nnvm/attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# coding: utf-8
"""Attribute scoping support for symbolic API."""
from __future__ import absolute_import

from .base import string_types

class AttrScope(object):
"""Attribute manager for scoping.
User can also inherit this object to change naming behavior.
Parameters
----------
kwargs
The attributes to set for all symbol creations in the scope.
"""
current = None

def __init__(self, **kwargs):
self._old_scope = None
for value in kwargs.values():
if not isinstance(value, string_types):
raise ValueError("Attributes need to be string")
self._attr = kwargs

def get(self, attr):
"""
Get the attribute dict given the attribute set by the symbol.
Parameters
----------
attr : dict of string to string
The attribute passed in by user during symbol creation.
Returns
-------
attr : dict of string to string
Updated attributes to add other scope related attributes.
"""
if self._attr:
ret = self._attr.copy()
if attr:
ret.update(attr)
return ret
else:
return attr

def __enter__(self):
# pylint: disable=protected-access
self._old_scope = AttrScope.current
attr = AttrScope.current._attr.copy()
attr.update(self._attr)
self._attr = attr
AttrScope.current = self
return self

def __exit__(self, ptype, value, trace):
assert self._old_scope
AttrScope.current = self._old_scope

AttrScope.current = AttrScope()

Loading

0 comments on commit adc18b6

Please sign in to comment.