Skip to content

Commit

Permalink
[API] Change attr to explicit name set_attr (dmlc#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Sep 14, 2016
1 parent 7bf1999 commit 8b15f64
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
36 changes: 18 additions & 18 deletions example/src/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ NNVM_REGISTER_OP(reshape)
CHECK(is >> target);
attrs->parsed = std::move(target);
})
.attr<FInferShape>(
.set_attr<FInferShape>(
"FInferShape", [] (const NodeAttrs& attrs,
std::vector<TShape> *ishape,
std::vector<TShape> *oshape) {
Expand All @@ -78,7 +78,7 @@ NNVM_REGISTER_OP(reshape)
<< "Reshape op: source target shape mismatch";
return true;
})
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0);
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0);


NNVM_REGISTER_OP(cast)
Expand All @@ -92,8 +92,8 @@ NNVM_REGISTER_OP(cast)
CHECK(is >> dtype);
attrs->parsed = std::move(dtype);
})
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInferType>(
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInferType>(
"FInferType", [](const NodeAttrs& attrs,
std::vector<int> *itype,
std::vector<int> *otype) {
Expand All @@ -104,8 +104,8 @@ NNVM_REGISTER_OP(cast)
NNVM_REGISTER_OP(exp)
.describe("take exponential")
.set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FGradient>(
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
Expand All @@ -117,8 +117,8 @@ NNVM_REGISTER_OP(exp)
NNVM_REGISTER_OP(identity)
.describe("identity function")
.set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FGradient>(
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{ograds[0]};
Expand All @@ -128,9 +128,9 @@ NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.add_alias("__add_symbol__")
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>(
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{ograds[0], ograds[0]};
Expand All @@ -139,9 +139,9 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP(mul)
.describe("multiply two data together")
.set_num_inputs(2)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>(
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{
Expand All @@ -167,23 +167,23 @@ NNVM_REGISTER_OP(__one__)
NNVM_REGISTER_OP(cross_device_copy)
.describe("Copy data across device.")
.set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape);
.set_attr<FInferShape>("FInferShape", SameShape);


NNVM_REGISTER_OP(conv2d)
.describe("take conv of input")
.set_num_inputs(2)
.attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "weight"};
});

NNVM_REGISTER_OP(add)
.attr<std::string>("nick_name", "plus");
.set_attr<std::string>("nick_name", "plus");

NNVM_REGISTER_OP(assign)
.set_num_inputs(2)
.set_num_outputs(1)
.attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
.set_attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0};
});

Expand Down
12 changes: 6 additions & 6 deletions include/nnvm/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* NNVM_REGISTER_OP(add)
* .describe("add two inputs together")
* .set_num_inputs(2)
* .attr<OpKernel>("gpu_kernel", AddKernel);
* .set_attr<OpKernel>("gpu_kernel", AddKernel);
*
* NNVM_REGISTER_OP(sub)
* .describe("substract one tensor from another")
Expand All @@ -53,7 +53,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* // Can call regster multiple times in different files
* // to register different part of information
* NNVM_REGISTER_OP(sub)
* .attr<OpKernel>("gpu_kernel", SubKernel);
* .set_attr<OpKernel>("gpu_kernel", SubKernel);
*
* // get operators from registry.
* void my_function() {
Expand Down Expand Up @@ -213,8 +213,8 @@ class Op {
* \tparam ValueType The type of the value to be set.
*/
template<typename ValueType>
inline Op& attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value);
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
Expand Down Expand Up @@ -300,7 +300,7 @@ class OpMap {
* NNVM_REGISTER_OP(add)
* .describe("add two inputs together")
* .set_num_inputs(2)
* .attr<OpKernel>("gpu_kernel", AddKernel);
* .set_attr<OpKernel>("gpu_kernel", AddKernel);
*
* \endcode
*/
Expand Down Expand Up @@ -329,7 +329,7 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
}

template<typename ValueType>
inline Op& Op::attr( // NOLINT(*)
inline Op& Op::set_attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value) {
// update the attribute map of the key by creating new empty if needed.
UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) {
Expand Down

0 comments on commit 8b15f64

Please sign in to comment.