Skip to content

Commit

Permalink
Add gradient graph (dmlc#280)
Browse files Browse the repository at this point in the history
* Add creating gradient symbol

* Fix lint

* Address comments

* Fix typo

* Address comment
  • Loading branch information
kevinthesun authored and ZihengJiang committed Dec 30, 2017
1 parent 8e353d2 commit 5bb53ad
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 15 deletions.
16 changes: 12 additions & 4 deletions docs/top.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.elemwise_sub
nnvm.symbol.elemwise_mul
nnvm.symbol.elemwise_div
nnvm.symbol.fill
nnvm.symbol.fill_like
nnvm.symbol.full
nnvm.symbol.full_like
nnvm.symbol.ones
nnvm.symbol.ones_like
nnvm.symbol.zeros
nnvm.symbol.zeros_like
nnvm.symbol.flatten
nnvm.symbol.concatenate
nnvm.symbol.expand_dims
Expand Down Expand Up @@ -113,8 +117,12 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.elemwise_sub
.. autofunction:: nnvm.symbol.elemwise_mul
.. autofunction:: nnvm.symbol.elemwise_div
.. autofunction:: nnvm.symbol.fill
.. autofunction:: nnvm.symbol.fill_like
.. autofunction:: nnvm.symbol.full
.. autofunction:: nnvm.symbol.full_like
.. autofunction:: nnvm.symbol.ones
.. autofunction:: nnvm.symbol.ones_like
.. autofunction:: nnvm.symbol.zeros
.. autofunction:: nnvm.symbol.zeros_like
.. autofunction:: nnvm.symbol.flatten
.. autofunction:: nnvm.symbol.concatenate
.. autofunction:: nnvm.symbol.expand_dims
Expand Down
36 changes: 36 additions & 0 deletions python/nnvm/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ._base import GraphHandle, SymbolHandle
from ._base import check_call
from .symbol import Variable, Symbol, Group as _Group
from .symbol import ones_like

class GraphIndex(object):
"""Index for quickly accessing graph attributes.
Expand Down Expand Up @@ -270,3 +271,38 @@ def create(symbol):
check_call(_LIB.NNGraphCreate(
symbol.handle, ctypes.byref(ghandle)))
return Graph(ghandle)


def gradients(ys, xs, grad_ys=None):
    """Build the gradient symbols of ys with respect to xs.

    Parameters
    ----------
    ys : Symbol or list of Symbol
        Symbols from which the gradient is calculated.
        A list is wrapped into a single group symbol first.

    xs : Symbol or list of Symbol
        Symbols the gradient is taken with respect to.
        For a group symbol, gradients for all outputs will be calculated.

    grad_ys : Symbol or list of Symbol, optional
        Head gradients for ys. When omitted, ``ones_like`` of every
        output of ys is used.

    Returns
    -------
    ret : list of Symbol
        Generated gradient symbols, one entry per output of xs.
        For each xs, all gradients from ys are merged into a single symbol.
    """
    if isinstance(ys, list):
        ys = _Group(ys)

    # Attach the differentiation targets to a graph built from ys.
    grad_graph = create(ys)
    grad_graph._set_symbol_list_attr('grad_ys', ys)
    grad_graph._set_symbol_list_attr('grad_xs', xs)

    num_ys = len(ys.list_output_names())
    if grad_ys is None:
        # Default head gradient: ones shaped like each output of ys.
        grad_ys = [ones_like(ys[idx]) for idx in range(num_ys)]
    grad_graph._set_symbol_list_attr('grad_ys_out_grad', grad_ys)

    grad_sym = grad_graph.apply('Gradient').symbol

    # Only the number of outputs of xs is needed to slice the result.
    xs_sym = _Group(xs) if isinstance(xs, list) else xs
    num_xs = len(xs_sym.list_output_names())
    return [grad_sym[idx] for idx in range(num_xs)]
4 changes: 2 additions & 2 deletions src/pass/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
return std::move(v[0]);
} else if (v.size() == 0) {
NodePtr zero_node = Node::Create();
zero_node->attrs.op = Op::Get("__zero__");
zero_node->attrs.op = Op::Get("_zeros");
return NodeEntry{zero_node, 0, 0};
} else {
NodePtr sum_node = Node::Create();
sum_node->attrs.op = Op::Get("__ewise_sum__");
sum_node->attrs.op = Op::Get("elemwise_sum");
sum_node->inputs = std::move(v);
return NodeEntry{sum_node, 0, 0};
}
Expand Down
120 changes: 111 additions & 9 deletions src/top/tensor/elemwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_div)
// grad_1 = - grad_y * n0 / n1^2
NodeEntry sub0 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_0",
{ograds[0], n->inputs[0]});
NodeEntry sub1 = MakeNode("negative", n->attrs.name + "_grad_sub_1", {sub0});
NodeEntry sub1 = MakeNode("negative", n->attrs.name + "_grad_sub_1",
{sub0});
NodeEntry sub2 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_2",
{n->inputs[1], n->inputs[1]});
return std::vector<NodeEntry>{
Expand Down Expand Up @@ -240,15 +241,27 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)

DMLC_REGISTER_PARAMETER(InitOpParam);

// fill
NNVM_REGISTER_INIT_OP(fill)
// full
NNVM_REGISTER_INIT_OP(full)
.describe(R"code(Fill array with scalar value
)code" NNVM_ADD_FILELINE)
.set_support_level(1);

// fill_like
NNVM_REGISTER_ELEMWISE_UNARY_OP(fill_like)
NNVM_REGISTER_INIT_OP(zeros)
.describe(R"code(Fill target with zeros
)code" NNVM_ADD_FILELINE)
.set_support_level(1);

NNVM_REGISTER_INIT_OP(ones)
.describe(R"code(Fill target with ones
)code" NNVM_ADD_FILELINE)
.set_support_level(1);

// full_like
NNVM_REGISTER_ELEMWISE_UNARY_OP(full_like)
.describe(R"code(Return an scalar value array with the same shape and type
as the input array
Expand All @@ -260,8 +273,38 @@ as the input array
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{
MakeNode("fill_like", n->attrs.name + "_zero",
{n->inputs[0]}, {{"value", "0"}})
MakeNode("zeros_like", n->attrs.name + "_grad",
{n->inputs[0]})
};
});

// zeros_like
NNVM_REGISTER_ELEMWISE_UNARY_OP(zeros_like)
.describe(R"code(Return an array of zeros with the same shape and type
as the input array.
)code")
.add_argument("data", "Symbol", "The input")
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
    // The gradient handed back for the single input is a zeros_like
    // node built on that same input; ograds is not consulted.
    const NodeEntry zero_grad =
      MakeNode("zeros_like", n->attrs.name + "_grad", {n->inputs[0]});
    return std::vector<NodeEntry>{zero_grad};
});

// ones_like
NNVM_REGISTER_ELEMWISE_UNARY_OP(ones_like)
.describe(R"code(Return an array of ones with the same shape and type
as the input array.
)code")
.add_argument("data", "Symbol", "The input")
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
    // The gradient handed back for the single input is a zeros_like
    // node built on that same input; ograds is not consulted.
    const NodeEntry zero_grad =
      MakeNode("zeros_like", n->attrs.name + "_grad", {n->inputs[0]});
    return std::vector<NodeEntry>{zero_grad};
});

Expand Down Expand Up @@ -353,8 +396,10 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rdiv_scalar__)
// y = scalar / n0
// grad_0 = - grad_y * scalar / n0^2
NodeEntry sub0 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_0",
{ograds[0]}, {{"scalar", n->attrs.dict["scalar"]}});
NodeEntry sub1 = MakeNode("negative", n->attrs.name + "_grad_sub_1", {sub0});
{ograds[0]},
{{"scalar", n->attrs.dict["scalar"]}});
NodeEntry sub1 = MakeNode("negative", n->attrs.name + "_grad_sub_1",
{sub0});
NodeEntry sub2 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_2",
{n->inputs[0], n->inputs[0]});
return std::vector<NodeEntry>{
Expand Down Expand Up @@ -407,6 +452,63 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rpow_scalar__)
});


// Parameter set parsed for the elemwise_sum operator.
struct ElementWiseSumParam : dmlc::Parameter<ElementWiseSumParam> {
  // Number of inputs to be summed; declared below with a lower bound of 1.
  int num_args;

  DMLC_DECLARE_PARAMETER(ElementWiseSumParam) {
    DMLC_DECLARE_FIELD(num_args)
      .set_lower_bound(1)
      .describe("Number of inputs to be summed.");
  }
};

// Register the parameter struct with the dmlc parameter machinery.
DMLC_REGISTER_PARAMETER(ElementWiseSumParam);

// Shape inference entry point registered for elemwise_sum.
// Requires exactly one output slot, then defers to the shared ElemwiseAttr
// helper, passing a default-constructed TShape as its last argument.
// Returns whatever ElemwiseAttr reports.
bool ElementWiseSumShape(const NodeAttrs& attrs,
                         std::vector<TShape>* in_attrs,
                         std::vector<TShape>* out_attrs) {
  CHECK_EQ(out_attrs->size(), 1);
  const bool inferred =
    ElemwiseAttr<TShape, shape_is_none, shape_assign, true, shape_string>(
      attrs, in_attrs, out_attrs, TShape());
  return inferred;
}

// Type inference entry point registered for elemwise_sum.
// Requires exactly one output slot, then defers to the shared ElemwiseAttr
// helper, passing -1 as its last argument.
// Returns whatever ElemwiseAttr reports.
bool ElementWiseSumType(const NodeAttrs& attrs,
                        std::vector<int>* in_attrs,
                        std::vector<int>* out_attrs) {
  CHECK_EQ(out_attrs->size(), 1);
  const bool inferred =
    ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
      attrs, in_attrs, out_attrs, -1);
  return inferred;
}

std::vector<NodeEntry> ElementWiseSumGrad(
const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
// identity constraints in the beginning for easier shape inference.
const Op* copy_op = Op::Get("identity");
CHECK_EQ(ograds.size(), 1);
std::vector<NodeEntry> ret;
NodeEntry n_out{n, 0, 0};
for (size_t i = 0; i < n->inputs.size(); i++) {
NodePtr id_node = Node::Create();
id_node->attrs.op = copy_op;
id_node->inputs = {ograds[0]};
ret.push_back(NodeEntry{id_node, 0, 0});
}
return ret;
}


// elemwise_sum: takes a variable number of inputs, as given by the parsed
// num_args parameter.
NNVM_REGISTER_OP(elemwise_sum)
.describe(R"code(Adds all input arguments element-wise.
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ElementWiseSumParam>)
.set_num_inputs([](const NodeAttrs& attrs) {
    // Input count is read from the parameters parsed by the attr parser.
    const ElementWiseSumParam& param =
      dmlc::get<ElementWiseSumParam>(attrs.parsed);
    return static_cast<uint32_t>(param.num_args);
  })
.set_attr<nnvm::FInferShape>("FInferShape", ElementWiseSumShape)
.set_attr<nnvm::FInferType>("FInferType", ElementWiseSumType)
.set_attr<nnvm::FGradient>("FGradient", ElementWiseSumGrad)
.add_argument("args", "Symbol[]", "Positional input arguments");

} // namespace top
} // namespace nnvm
18 changes: 18 additions & 0 deletions tests/python/unittest/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,23 @@ def test_print_graph_ir():
assert("y_bias" in ir1)
assert("shape=" in ir2)

def test_gradient():
    """Exercise graph.gradients with default and explicit head gradients."""
    x = sym.Variable("x")
    y = sym.Variable("y")
    out_add = sym.elemwise_add(x, sym.sqrt(y))
    out_log = sym.log(x)

    # Plain lists for ys and xs, head gradients left at their default.
    grads = graph.gradients([out_add, out_log], [x, y])
    assert len(grads) == 2

    # Group symbols for ys and xs, with explicit head gradient variables.
    head_grads = [sym.Variable("g1"), sym.Variable("g2")]
    grads = graph.gradients(sym.Group([out_add, out_log]),
                            sym.Group([x, y]), grad_ys=head_grads)
    ir_text = graph.create(sym.Group(grads)).ir()
    assert len(grads) == 2
    # Both head gradient variables must show up in the printed graph IR.
    assert "g1" in ir_text
    assert "g2" in ir_text

if __name__ == "__main__":
test_print_graph_ir()
Expand All @@ -123,3 +140,4 @@ def test_print_graph_ir():
test_infer_type()
test_plan_memory()
test_list_args()
test_gradient()

0 comments on commit 5bb53ad

Please sign in to comment.