Skip to content

Commit

Permalink
[RELAY] BiasAdd, MLP, Resnet testing (apache#1969)
Browse files Browse the repository at this point in the history
* [RELAY] BiasAdd, MLP, Resnet testing

* fix review comments
  • Loading branch information
tqchen authored Oct 24, 2018
1 parent 399b39f commit c76fce9
Show file tree
Hide file tree
Showing 28 changed files with 1,160 additions and 237 deletions.
13 changes: 13 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ This level enables fully connected multi-layer perceptron.
tvm.relay.nn.relu
tvm.relay.nn.dropout
tvm.relay.nn.batch_norm
tvm.relay.nn.bias_add



**Level 2: Convolutions**
Expand Down Expand Up @@ -85,8 +87,13 @@ This level enables additional math and transform operators.
tvm.relay.abs
tvm.relay.negative
tvm.relay.take
tvm.relay.zeros
tvm.relay.zeros_like
tvm.relay.ones
tvm.relay.ones_like
tvm.relay.full
tvm.relay.full_like
tvm.relay.cast


**Level 4: Broadcast and Reductions**
Expand Down Expand Up @@ -151,6 +158,9 @@ Level 1 Definitions
.. autofunction:: tvm.relay.nn.softmax
.. autofunction:: tvm.relay.nn.log_softmax
.. autofunction:: tvm.relay.nn.relu
.. autofunction:: tvm.relay.nn.dropout
.. autofunction:: tvm.relay.nn.batch_norm
.. autofunction:: tvm.relay.nn.bias_add


Level 2 Definitions
Expand Down Expand Up @@ -185,6 +195,9 @@ Level 3 Definitions
.. autofunction:: tvm.relay.zeros_like
.. autofunction:: tvm.relay.ones
.. autofunction:: tvm.relay.ones_like
.. autofunction:: tvm.relay.full
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast


Level 4 Definitions
Expand Down
17 changes: 17 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,23 @@
namespace tvm {
namespace relay {

/*!
 * \brief Attributes for the bias_add operator: add a 1D bias tensor
 * along one axis of the input data.
 *
 * \note bias_add is a special add operator that lives in nn
 * because it enables automatic derivation of the bias tensor's shape.
 * You can directly use add for the more generalized case.
 */
struct BiasAddAttrs : public tvm::AttrsNode<BiasAddAttrs> {
  int axis;  // The data axis along which the bias is added.

  TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") {
    TVM_ATTR_FIELD(axis)
      .describe("The axis to add the bias")
      .set_default(1);  // NOTE(review): default 1 is presumably the channel axis (NCHW) — confirm with callers.
  }
};

/*! \brief Attributes used in convolution operators */
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<IndexExpr> strides;
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@
namespace tvm {
namespace relay {

/*! \brief Attributes for the cast operator: convert data to another dtype. */
struct CastAttrs : public tvm::AttrsNode<CastAttrs> {
  DataType dtype;  // The target data type of the cast; no default, must be supplied.

  TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") {
    TVM_ATTR_FIELD(dtype)
        .describe("Target data type");
  }
};  // struct CastAttrs.

/*! \brief Attributes used in expand_dims operators */
struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
int axis;
Expand Down
30 changes: 19 additions & 11 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,17 @@ class ExprFunctor<R(const Expr& n, Args...)> {
}
};

/*! \brief A simple visitor wrapper around ExprFunctor.
/*!
* \brief A simple visitor wrapper around ExprFunctor.
* Recursively visit the content.
*
* Exposes two visitors with default traversal strategies, one
* which doesn't compute a result but can mutate internal state,
* and another which functionally builds a new Expr.
* ExprVisitor treats Expr as dataflow graph,
* and only visit each Expr node once.
*/

class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
class ExprVisitor
: public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
public:
void VisitExpr(const Expr& expr) override;
void VisitExpr_(const VarNode* op) override;
void VisitExpr_(const GlobalVarNode* op) override;
void VisitExpr_(const ConstantNode* op) override;
Expand All @@ -132,13 +134,19 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
virtual void VisitType(const Type& t);

private:
// internal visited flag.
std::unordered_set<const Node*> visited_;
};

/*! \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
/*!
* \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once.
* The mutated results are memoized in a map and reused so that
* local transformation on the dataflow preserves the graph structure.
*/
class ExprMutator
: public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
public:
Expand Down
29 changes: 10 additions & 19 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,35 +102,26 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*/
bool WellFormed(const Expr& e);

/*! \brief Get free variables from expression e.
/*! \brief Get free Vars from expr in PostDFS order.
*
* Free variables are variables that are not bound by a let or a function parameter in the context.
* Free variables are variables that are not bound by a
* let or a function parameter in the context.
*
* \param e the expression.
* \param expr the expression.
*
* \return the set of free variable.
* \return List of free vars, in the PostDFS order visited by expr.
*/
tvm::Array<Var> FreeVariables(const Expr& e);
tvm::Array<Var> FreeVars(const Expr& expr);

/*! \brief Get free type parameters from expression e.
/*! \brief Get free TypeVars from expression expr.
*
* Free type parameters are type parameters that are not bound by a function type in the context.
*
* \param e the expression.
* \param expr the expression.
*
* \return the set of free type variables.
* \return List of free vars, in the PostDFS order visited by expr.
*/
tvm::Array<TypeVar> FreeTypeVariables(const Expr& e);

/*! \brief Get free type parameters from type t.
*
* Free type parameters are type parameters that are not bound by a function type in the context.
*
* \param t the type.
*
* \return the set of free type variables.
*/
tvm::Array<TypeVar> FreeTypeVariables(const Type& t);
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);

/*! \brief Remove expressions which does not effect the program result.
*
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ def __init__(self, dtype, value):
self.__init_handle_by_constructor__(
_make.IntImm, dtype, value)

def __int__(self):
    """Convert this integer immediate to a plain Python int (its stored value)."""
    return self.value


@register_node
class UIntImm(ConstExpr):
Expand Down
32 changes: 28 additions & 4 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .base import RelayNode, register_relay_node
from . import _make
from . import ty as _ty
from .._ffi import base as _base, node as _node
from .._ffi import base as _base
from .. import nd as _nd
from .. import convert

Expand All @@ -28,6 +28,25 @@ def checked_type(self):
" the checked_type for this node")
return ret

def astype(self, dtype):
    """Cast the content type of the current data to dtype.

    Parameters
    ----------
    dtype : str
        The target data type.

    Note
    ----
    This function only works for TensorType Exprs.

    Returns
    -------
    result : tvm.relay.Expr
        The result expression.
    """
    return _make.dtype_cast(self, dtype)


@register_relay_node
class Constant(Expr):
Expand Down Expand Up @@ -62,6 +81,9 @@ def __getitem__(self, index):
def __len__(self):
    """Return the number of fields in this tuple expression."""
    return len(self.fields)

def astype(self, _):
    """Reject casting: a tuple has no single tensor type to cast."""
    raise TypeError("astype cannot be used on tuple")


@register_relay_node
class Var(Expr):
Expand Down Expand Up @@ -238,7 +260,7 @@ def __init__(self, tuple_value, index):
_make.TupleGetItem, tuple_value, index)


class TupleWrapper(_node.NodeGeneric):
class TupleWrapper(object):
"""TupleWrapper.
This class is a Python wrapper for a Relay tuple of known size.
Expand All @@ -257,10 +279,9 @@ def __init__(self, tuple_value, size):
self.tuple_value = tuple_value
self.size = size

def asnode(self):
def astuple(self):
    """Return the underlying Relay tuple expression.

    Use this when the wrapper must be passed as an argument to an
    FFI function that expects the raw tuple value.
    """
    return self.tuple_value

def __getitem__(self, index):
Expand All @@ -275,6 +296,9 @@ def __repr__(self):
return ("TupleWrapper(" + self.tuple_value.__repr__() +
", " + self.size + ")")

def astype(self, _):
    """Reject casting: a tuple wrapper has no single tensor type to cast."""
    raise TypeError("astype cannot be used on tuple")


def var(name_hint,
type_annotation=None,
Expand Down
40 changes: 23 additions & 17 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ def infer_type(expr, env=None):
Parameters
----------
expr: tvm.relay.Expr
The input expression.
The input expression.
env: Optional[tvm.relay.Environment]
The global environment.
The global environment.
Returns
-------
checked_expr : tvm.relay.Expr
The checked expression.
The checked expression.
"""
return _ir_pass.infer_type(expr, env)

Expand All @@ -35,12 +35,12 @@ def well_formed(expr):
Parameters
----------
expr: tvm.relay.Expr
The input expression
The input expression
Returns
-------
well_form : bool
whether the input expression is well formed
Whether the input expression is well formed
"""
return _ir_pass.well_formed(expr)

Expand All @@ -52,15 +52,15 @@ def check_kind(t, env=None):
Parameters
----------
t: tvm.relay.Type
The type to check
The type to check
env: tvm.relay.Environment, optional
The global environment
The global environment
Returns
-------
well_kinded : bool
whether the input type is well kinded.
whether the input type is well kinded.
Examples
--------
Expand All @@ -75,20 +75,26 @@ def check_kind(t, env=None):
return _ir_pass.check_kind(t)


def free_vars(e):
"""Get free variables from expression e.
def free_vars(expr):
"""Get free Vars from expression expr in Post DFS order.
Parameters
----------
e: tvm.relay.Expr
The input expression
expr: tvm.relay.Expr
The input expression
Returns
-------
free : List[tvm.relay.Var]
The list of free variables
The list of free variables in post DFS order.
Note
----
The fact that Vars are in post-DFS order is useful for
neural networks: it usually means that the weights of
earlier layers are ordered first.
"""
return _ir_pass.free_vars(e)
return _ir_pass.free_vars(expr)


def free_type_vars(expr):
Expand Down Expand Up @@ -130,15 +136,15 @@ def alpha_equal(lhs, rhs):
Parameters
----------
lhs: tvm.relay.Expr
One of the input Expression.
One of the input Expression.
rhs: tvm.relay.Expr
One of the input Expression.
One of the input Expression.
Returns
-------
result: bool
True iff lhs is alpha equal to rhs.
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))

Expand Down
Loading

0 comments on commit c76fce9

Please sign in to comment.