Skip to content

Commit

Permalink
Add init member to ReduceNode (apache#6138)
Browse files Browse the repository at this point in the history
- This patch adds a new member to ReduceNode called init which allows
  initialization with a custom ProducerLoad or a Float/Int immediate.
- This allows initialization of the output Tensor of a reduction with
  another Tensor instead of the `identity_element` defined in the
  CommReducer
- One example use case for this node is to initialize the Output of a
  convolution reduction with the Bias values thereby saving the
  Bias-add computation.
  • Loading branch information
quic-sanirudh authored Aug 27, 2020
1 parent 415c088 commit c6dd26b
Show file tree
Hide file tree
Showing 20 changed files with 268 additions and 64 deletions.
9 changes: 7 additions & 2 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,8 @@ class ReduceNode : public PrimExprNode {
CommReducer combiner;
/*! \brief The source operand */
Array<PrimExpr> source;
/*! \brief The init operand */
Array<PrimExpr> init;
/*! \brief The reduction axis */
Array<IterVar> axis;
/*!
Expand All @@ -1040,6 +1042,7 @@ class ReduceNode : public PrimExprNode {
v->Visit("dtype", &dtype);
v->Visit("combiner", &combiner);
v->Visit("source", &source);
v->Visit("init", &init);
v->Visit("axis", &axis);
v->Visit("condition", &condition);
v->Visit("value_index", &value_index);
Expand All @@ -1049,14 +1052,16 @@ class ReduceNode : public PrimExprNode {
// check axis first so IterVars can define the necessary variables.
return equal(dtype, other->dtype) && equal(axis, other->axis) &&
equal(combiner, other->combiner) && equal(source, other->source) &&
equal(condition, other->condition) && equal(value_index, other->value_index);
equal(init, other->init) && equal(condition, other->condition) &&
equal(value_index, other->value_index);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(axis);
hash_reduce(combiner);
hash_reduce(source);
hash_reduce(init);
hash_reduce(condition);
hash_reduce(value_index);
}
Expand All @@ -1072,7 +1077,7 @@ class ReduceNode : public PrimExprNode {
class Reduce : public PrimExpr {
public:
TVM_DLL Reduce(CommReducer combiner, Array<PrimExpr> src, Array<IterVar> rdom, PrimExpr condition,
int value_index);
int value_index, Array<PrimExpr> init);

TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode);
};
Expand Down
18 changes: 12 additions & 6 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,48 +464,54 @@ TVM_DLL PrimExpr isinf(PrimExpr x);
* \brief sum of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
* \param init The value with which to initialize the output.
* \return The result.
*/
TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis);
TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});

/*!
* \brief logical And of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
* \param init The value with which to initialize the output.
*/
TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis);
TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});

/*!
* \brief logical Or of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
* \param init The value with which to initialize the output.
* \return The result.
*/
TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis);
TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});

/*!
* \brief max of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
* \param init The value with which to initialize the output.
* \return The result.
*/
TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis);
TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});

/*!
* \brief max of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
* \param init The value with which to initialize the output.
* \return The result.
*/
TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis);
TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});

/*!
* \brief product of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
* \param init The value with which to initialize the output.
* \return The result.
*/
TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis);
TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});

/*!
* \brief Calculate floor(x)
Expand Down
19 changes: 11 additions & 8 deletions include/tvm/topi/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ namespace topi {
using namespace tvm::te;

/*! \brief The operation to use for CommReduce */
using FReduce = std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis)>;
using FReduce =
std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis, Array<PrimExpr> init)>;

/*! \brief The operation to use for CommReduceIdx */
using FCommReduce = std::function<Array<PrimExpr>(Array<PrimExpr> exprs, const Array<IterVar>& axis,
Expand Down Expand Up @@ -158,7 +159,7 @@ inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array<PrimExp
arg_counter++;
}

return func(data(eval_range), r_axes);
return func(data(eval_range), r_axes, {});
};

return tvm::te::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
Expand Down Expand Up @@ -284,23 +285,25 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity,
auto combiner = tvm::tir::CommReducer(lhs, rhs, result, id_elem);
Array<PrimExpr> outputs;
for (size_t i = 0; i < exprs.size(); ++i) {
outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast<int>(i)));
outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast<int>(i), {}));
}
return outputs;
};
}

/*! \brief Wrap tvm::min to ensure we get the correct overload */
inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis) { return tvm::min(source, axis); }
inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {}) {
return tvm::min(source, axis, init);
}

/*! \brief Wrap tvm::max to ensure we get the correct overload */
inline PrimExpr MaxOp(PrimExpr source, Array<IterVar> axis) {
return tvm::max(source, axis); // NOLINT(*)
inline PrimExpr MaxOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {}) {
return tvm::max(source, axis, init); // NOLINT(*)
}

/*! \brief Wrap tvm::prod to ensure we get the correct overload */
inline PrimExpr ProdOp(PrimExpr source, Array<IterVar> axis) {
return tvm::prod(source, axis); // NOLINT(*)
inline PrimExpr ProdOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {}) {
return tvm::prod(source, axis, init); // NOLINT(*)
}

/*!
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,11 +433,14 @@ class Reduce(PrimExprWithOp):
value_index : int
The value index.
init : list of Expr
The initial value for output. This can be an int, float or ProducerLoad
"""
def __init__(self, combiner, src, rdom, condition, value_index):
def __init__(self, combiner, src, rdom, condition, value_index, init=None):
self.__init_handle_by_constructor__(
_ffi_api.Reduce, combiner, src, rdom,
condition, value_index)
condition, value_index, init)


@tvm._ffi.register_object
Expand Down
29 changes: 24 additions & 5 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,10 +1239,12 @@ def _reduce_directly(*args):
res = fcombine(res, args[i+1])
return res

def _make_reduce(expr, axis, where=None):
def _make_reduce(expr, axis, where=None, init=None):
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
expr = convert(expr)
if init is not None:
init = convert(init)
if isinstance(expr, Array):
size = len(expr)
larr = []
Expand All @@ -1255,6 +1257,16 @@ def _make_reduce(expr, axis, where=None):
larr.append(Var(lname, dtype))
rname = code.co_varnames[1] + "_" + str(i)
rarr.append(Var(rname, dtype))
if init is not None:
init = convert(init)
assert isinstance(init, Array)
assert len(init) == size
for init_i in range(size):
init_i = convert(init_i)
assert isinstance(init_i,
(tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm))
else:
init = convert([])
lhs = convert(larr)
rhs = convert(rarr)
result = fcombine(lhs, rhs)
Expand All @@ -1270,21 +1282,28 @@ def _make_reduce(expr, axis, where=None):
lhs = convert([lvar])
rhs = convert([rvar])
expr = convert([expr])
if init is not None:
assert isinstance(init, (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm))
init = convert([init])
result = convert(result)
id_elem = convert(id_elem)
combiner = CommReducer(lhs, rhs, result, id_elem)
axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
if where is None:
where = convert(True)
outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i)
for i in range(size))
if init is None:
outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, convert([]))
for i in range(size))
else:
outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, init)
for i in range(size))
return outputs[0] if size == 1 else outputs

# pylint: disable=keyword-arg-before-vararg
def reducer(expr, axis, where=None, *args):
def reducer(expr, axis, where=None, init=None, *args):
if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
assert not args
return _make_reduce(expr, axis, where)
return _make_reduce(expr, axis, where, init)
if where is None:
assert not args
return _reduce_directly(expr, axis)
Expand Down
12 changes: 10 additions & 2 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,8 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
for (size_t i = 0; i < used.size(); ++i) {
if (SideEffect(op->source[i]) > CallEffectKind::kReadState ||
SideEffect(op->combiner->identity_element[i]) > CallEffectKind::kReadState ||
SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState) {
SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState ||
(!op->init.empty() && SideEffect(op->init[i]) > CallEffectKind::kReadState)) {
mark_used(i);
}
}
Expand All @@ -1024,6 +1025,7 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
Array<Var> new_lhs;
Array<Var> new_rhs;
Array<PrimExpr> new_source;
Array<PrimExpr> new_init;

// new stuff is old stuff which is used
for (size_t i = 0; i < used.size(); ++i) {
Expand All @@ -1034,14 +1036,15 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
new_lhs.push_back(op->combiner->lhs[i]);
new_rhs.push_back(op->combiner->rhs[i]);
new_source.push_back(op->source[i]);
if (!op->init.empty()) new_init.push_back(op->init[i]);
} else if (static_cast<int>(i) < op->value_index) {
// value_index should also be adjusted
new_value_index--;
}
}

CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity);
return Reduce(new_combiner, new_source, op->axis, op->condition, new_value_index);
return Reduce(new_combiner, new_source, op->axis, op->condition, new_value_index, new_init);
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
Expand All @@ -1051,6 +1054,11 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
// already been simplified by const reduction axis removal
if (op == nullptr) return ret;
if (op->axis.empty()) {
if (!op->init.empty()) {
return this->VisitExpr(Select(op->condition,
(*op->combiner.get())(op->init, op->source)[op->value_index],
op->init[op->value_index]));
}
// Note that here we assume that the identity element is indeed identity. Without this
// assumption we would have to perform a single iteration of the loop, i.e. use
// `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]`
Expand Down
2 changes: 1 addition & 1 deletion src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ Doc TIRTextPrinter::VisitExpr_(const ShuffleNode* op) {
Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) {
Doc doc;
doc << "reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis)
<< ", " << op->value_index << ")";
<< ", " << op->value_index << ", " << Print(op->init) << ")";
return doc;
}

Expand Down
21 changes: 12 additions & 9 deletions src/te/autodiff/ad_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -806,9 +806,9 @@ PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& ou

// Perform simplification mainly to remove a possibly empty reduction.
arith::Analyzer analyzer;
return analyzer.Simplify(
Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
kSimplifyRewriteCanonicalRewrite);
return analyzer.Simplify(Reduce(red->combiner, new_source, new_axis, All(res->dst->relations),
red->value_index, red->init),
kSimplifyRewriteCanonicalRewrite);
} else {
return expr;
}
Expand Down Expand Up @@ -938,6 +938,7 @@ class RemoveRedundantInequalitiesMutator : public ExprMutator {

virtual PrimExpr VisitExpr_(const ReduceNode* op) {
Array<PrimExpr> known_with_axes = known_;
CHECK(op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
for (const PrimExpr& axis_cond : IterVarsToInequalities(op->axis)) {
known_with_axes.push_back(axis_cond);
}
Expand All @@ -956,7 +957,7 @@ class RemoveRedundantInequalitiesMutator : public ExprMutator {
new_source.push_back(new_mutator(src));
}

return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index);
return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index, op->init);
}

virtual PrimExpr VisitExpr_(const EQNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
Expand Down Expand Up @@ -1068,13 +1069,14 @@ class ReductionAsTensorAccessMutator : public ExprMutator {
ReductionAsTensorAccessMutator new_mutator(Concat(IterVarsToVars(op->axis), outer_axis_),
Merge(vranges_, IterVarsToMap(op->axis)), name_);

CHECK(op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
Array<PrimExpr> new_source;
for (const PrimExpr& src : op->source) {
new_source.push_back(new_mutator(src));
}

PrimExpr new_reduce =
Reduce(op->combiner, new_source, op->axis, op->condition, op->value_index);
Reduce(op->combiner, new_source, op->axis, op->condition, op->value_index, op->init);

Array<Var> undefined_vars = UndefinedVars(new_reduce);
std::unordered_set<const VarNode*> undefined_var_set;
Expand Down Expand Up @@ -1133,7 +1135,7 @@ PrimExpr LiftReductions(const PrimExpr& expr, const Array<Var>& outer_axis,
}
PrimExpr new_condition = ReductionAsTensorAccess(red->condition, new_outer_axis, new_vranges);

return Reduce(red->combiner, new_source, red->axis, new_condition, red->value_index);
return Reduce(red->combiner, new_source, red->axis, new_condition, red->value_index, red->init);
} else {
return ReductionAsTensorAccess(expr, outer_axis, vranges);
}
Expand All @@ -1150,6 +1152,7 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A
PrimExpr expr = analyzer.Simplify(expr_orig, kSimplifyRewriteCanonicalRewrite);

if (const ReduceNode* red = expr.as<ReduceNode>()) {
CHECK(red->init.empty()) << "Derivative of Reduction with initialization is not implemented";
// TODO(sgrechanik-h): There are some other operations which behave like sum
bool is_sum = IsSumCombiner(red->combiner, vranges);
if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index, vranges)) {
Expand All @@ -1167,7 +1170,7 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A
source.Set(0, nz.value);
}

new_red = Reduce(red->combiner, source, red->axis, cond, red->value_index);
new_red = Reduce(red->combiner, source, red->axis, cond, red->value_index, red->init);
new_red = SimplifyReductionDomain(new_red, combined_vranges);
// If the reduction disappears completely then transform the result as a non-reduction
if (!new_red.as<ReduceNode>()) {
Expand All @@ -1193,8 +1196,8 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A
new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype())));
}

PrimExpr new_reduce =
Reduce(red->combiner, new_source, red->axis, new_reduce_cond, red->value_index);
PrimExpr new_reduce = Reduce(red->combiner, new_source, red->axis, new_reduce_cond,
red->value_index, red->init);
new_reduce =
TrySimplifyCompute(new_reduce, new_outer_cond, IterVarsToVars(axis), combined_vranges);
result = Select(new_outer_cond, new_reduce, make_zero(new_reduce.dtype()));
Expand Down
Loading

0 comments on commit c6dd26b

Please sign in to comment.