diff --git a/include/nnvm/node.h b/include/nnvm/node.h
index c3698ed75..2c7d0ef30 100644
--- a/include/nnvm/node.h
+++ b/include/nnvm/node.h
@@ -146,11 +146,9 @@ inline NodeEntry MakeNode(
   NodePtr p = Node::Create();
   p->attrs.op = nnvm::Op::Get(op_name);
   p->attrs.name = std::move(node_name);
-  if (attrs.size() != 0) {
-    p->attrs.dict = attrs;
-    if (p->attrs.op->attr_parser) {
-      p->attrs.op->attr_parser(&(p->attrs));
-    }
+  p->attrs.dict = attrs;
+  if (p->attrs.op->attr_parser) {
+    p->attrs.op->attr_parser(&(p->attrs));
   }
   p->inputs = std::move(inputs);
   return NodeEntry{p, 0, 0};
diff --git a/src/pass/plan_memory.cc b/src/pass/plan_memory.cc
index b1f2a37d4..f96f061b5 100644
--- a/src/pass/plan_memory.cc
+++ b/src/pass/plan_memory.cc
@@ -151,6 +151,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
                    GraphAllocator* allocator) {
   static auto& finplace_option = Op::GetAttr<FInplaceOption>("FInplaceOption");
   static auto& finplace_identity = Op::GetAttr<FInplaceIdentity>("FInplaceIdentity");
+  static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs");
   // Get reference
   auto &storage = *storage_ptr;
@@ -189,10 +190,13 @@
         uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
         auto sid_out = storage[eid_out];
         auto sid_in = storage[eid_in];
+        bool ignore_all_inputs = (fignore_inputs.count(inode.source->op()) != 0 &&
+                                  fignore_inputs[inode.source->op()](
+                                      inode.source->attrs).size() == inode.source->num_inputs());
         if (taken[kv.first] == false &&
             sid_out == GraphAllocator::kBadStorageID &&
             sid_in >= 0 &&
-            (storage_ref_count[sid_in] == 1 || identity[ipair]) &&
+            (storage_ref_count[sid_in] == 1 && !ignore_all_inputs || identity[ipair]) &&
             entry_ref_count[eid_out] > 0 &&
             shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
             dtype_vec[eid_out] == dtype_vec[eid_in]) {
@@ -230,7 +234,6 @@
       storage[eid] = sid;
     }
     // check if certain inputs is ignored.
-    static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs");
     std::vector<uint32_t> ignore_inputs;
     if (fignore_inputs.count(inode.source->op()) != 0) {
       ignore_inputs = fignore_inputs[inode.source->op()](inode.source->attrs);
diff --git a/src/top/nn/nn.cc b/src/top/nn/nn.cc
index 24510d252..79b4c8587 100644
--- a/src/top/nn/nn.cc
+++ b/src/top/nn/nn.cc
@@ -134,8 +134,11 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(relu)
     NodeEntry zero = MakeNode("zeros_like", n->attrs.name + "_grad_zero",
                               {n->inputs[0]});
     return std::vector<NodeEntry>{
-      MakeNode("greater", n->attrs.name + "_grad",
-               {n->inputs[0], zero}, {{"exclude", "true"}})
+      MakeNode("elemwise_mul", n->attrs.name + "_grad", {
+        ograds[0],
+        MakeNode("greater", n->attrs.name + "_grad_mask",
+                 {n->inputs[0], zero}, {{"exclude", "true"}})
+      })
     };
   })
 .set_support_level(1);
@@ -249,7 +252,7 @@ axis to be the last item in the input shape.
 .set_attr<FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) {
     return 1;
   })
-.set_attr<FMutateInputs>("FListMutateInputs", [](const NodeAttrs& attrs) {
+.set_attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
     return std::vector<uint32_t>{3, 4};
   })
 .set_support_level(1);
diff --git a/src/top/tensor/matrix_op.cc b/src/top/tensor/matrix_op.cc
index 5c1632801..149c609ee 100644
--- a/src/top/tensor/matrix_op.cc
+++ b/src/top/tensor/matrix_op.cc
@@ -33,9 +33,9 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(lshape[lshape.ndim() - 1], rshape[0])
       << "dot shape inconsistent: " << lshape << " X " << rshape;
-  TShape oshape(lshape.ndim() + rshape.ndim() - 1);
-  for (size_t i = 0; i < lshape.ndim() - 1; i++) oshape[i] = lshape[i];
-  for (size_t i = 1; i < rshape.ndim(); i++) oshape[i + lshape.ndim() - 1] = rshape[i];
+  TShape oshape(lshape.ndim() + rshape.ndim() - 2);
+  for (int i = 0; i < lshape.ndim() - 1; i++) oshape[i] = lshape[i];
+  for (int i = 1; i < rshape.ndim(); i++) oshape[i + lshape.ndim() - 2] = rshape[i];
   NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape);
   return true;
diff --git a/src/top/tensor/transform.cc b/src/top/tensor/transform.cc
index 0bf1a91ec..dcdedbb86 100644
--- a/src/top/tensor/transform.cc
+++ b/src/top/tensor/transform.cc
@@ -574,7 +574,7 @@ the input array into an output array with the same shape as the second input arr
 )code" NNVM_ADD_FILELINE)
 .add_argument("data", "Tensor", "Input data.")
 .add_argument("shape_like", "Tensor", "Input data.")
-.set_num_inputs(1)
+.set_num_inputs(2)
 .set_num_outputs(1)
 .set_attr<FInferShape>(
   "FInferShape", [](const NodeAttrs& attrs,
@@ -585,7 +585,7 @@ the input array into an output array with the same shape as the second input arr
     NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, in_attrs->at(1));
     return true;
   })
-.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py
index 7af65c616..226dedd3d 100644
--- a/tests/python/unittest/test_infer_shape.py
+++ b/tests/python/unittest/test_infer_shape.py
@@ -23,6 +23,36 @@ def test_dense():
     assert(sdict["fc_bias"][0] == [30])
 
 
+def test_matmul():
+    a = sym.Variable('a', shape=(10, 20))
+    b = sym.Variable('b', shape=(20, 30))
+    c = sym.matmul(a, b, name="matmul")
+    sdict = infer_shape(c)
+    assert(sdict["matmul"][0] == [10, 30])
+    a = sym.Variable('a', shape=(20, 10))
+    c = sym.matmul(a, b, name="matmul", transpose_a=True)
+    sdict = infer_shape(c)
+    assert(sdict["matmul"][0] == [10, 30])
+    b = sym.Variable('b', shape=(30, 20))
+    c = sym.matmul(a, b, name="matmul", transpose_a=True, transpose_b=True)
+    sdict = infer_shape(c)
+    assert(sdict["matmul"][0] == [10, 30])
+    a = sym.Variable('a', shape=(10, 20))
+    c = sym.matmul(a, b, name="matmul", transpose_b=True)
+    sdict = infer_shape(c)
+    assert(sdict["matmul"][0] == [10, 30])
+    a = sym.Variable('a', shape=(10, 20, 30))
+    b = sym.Variable('b', shape=(30, 40, 50))
+    c = sym.matmul(a, b, name="matmul")
+    sdict = infer_shape(c)
+    assert(sdict["matmul"][0] == [10, 20, 40, 50])
+    a = sym.Variable('a', shape=(30, 20, 10))
+    b = sym.Variable('b', shape=(50, 40, 30))
+    c = sym.matmul(a, b, name="matmul", transpose_a=True, transpose_b=True)
+    sdict = infer_shape(c)
+    assert(sdict["matmul"][0] == [10, 20, 40, 50])
+
+
 def test_concatenate():
     x1 = sym.Variable("x", shape=(10, 20))
     x2 = sym.Variable("y", shape=(10, 30))
@@ -275,6 +305,7 @@ def check(in_shape, out_shape, **kwargs):
 if __name__ == "__main__":
     test_expand_dims()
     test_dense()
+    test_matmul()
     test_concatenate()
     test_split()
     test_batchnorm()
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 4a8ae58c9..93d4fea26 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -6,6 +6,13 @@ def test_dense():
     y = sym.dense(x, units=30, name="fc")
     assert y.list_input_names() == ["x", "fc_weight", "fc_bias"]
 
+def test_batch_norm():
+    x = sym.Variable('x')
+    y = sym.dense(x, units=30, name="fc")
+    z = sym.batch_norm(x, name='bn')
+    assert z.list_input_names('aux_state') == ['bn_moving_mean', 'bn_moving_var']
+    assert z.list_input_names('read_only') == ['x', 'bn_gamma', 'bn_beta']
+
 def test_compose():
     x = sym.Variable('x')
     z = sym.Variable('z')
@@ -51,3 +58,4 @@ def test_op_name():
     test_copy()
     test_default_input()
     test_compose()
+    test_batch_norm()