Skip to content

Commit

Permalink
!45166 refine Transpose emitter
Browse files Browse the repository at this point in the history
Merge pull request !45166 from looop5/refine_bprop_transpose
  • Loading branch information
it-is-a-robot authored and gitee-org committed Nov 8, 2022
2 parents 44ad905 + 733903a commit 2f58693
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,10 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
}
out_shp = ib->GetShape(dout);
auto perm_1 = GenerateShapeIndex(out_shp, ind_shp, axis_v);
auto values_transpose = ib->Emit("Transpose", {dout, ib->Value<ShapeVector>(perm_1)});
auto values_transpose = ib->Transpose(dout, perm_1);
auto tmp = ib->Emit("UnsortedSegmentSum", {values_transpose, indices, ib->Value<int64_t>(x_shp[axis_v])});
auto perm_2 = GenerateInverseIndex(x_shp, axis_v);
auto params_grad = ib->Emit("Transpose", {tmp, ib->Value<ShapeVector>(perm_2)});
auto params_grad = ib->Transpose(tmp, perm_2);
return {params_grad, ib->ZerosLike(orig_indices), ib->ZerosLike(axis)};
}

Expand Down
13 changes: 13 additions & 0 deletions mindspore/ccsrc/common/graph_kernel/bprop/expander/emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ NodePtr Emitter::BatchMatMul(const NodePtr &a, const NodePtr &b, bool transpose_
{"transpose_b", MakeValue(transpose_b)}});
}

NodePtr Emitter::Transpose(const NodePtr &node, const ShapeVector &perm) const {
  // Emitting a Transpose op is unnecessary when `perm` is the identity
  // permutation (e.g. [0, 1, 2, 3]); the input node is returned unchanged.
  const auto rank = SizeToLong(perm.size());
  bool is_identity = true;
  for (size_t axis = 0; axis < perm.size(); ++axis) {
    // Entries may be negative and index from the back: for rank 4,
    // [0, -3, 2, 3] normalizes to [0, 1, 2, 3].
    const auto dst = perm[axis] < 0 ? perm[axis] + rank : perm[axis];
    if (dst != SizeToLong(axis)) {
      is_identity = false;
      break;
    }
  }
  return is_identity ? node : Emit(kTransposeOpName, {node, Value(perm)});
}

NodePtr Emitter::ZerosLike(const NodePtr &node) const {
if (node->isa<ValueNode>()) {
auto value_node = node->get<ValueNodePtr>();
Expand Down
4 changes: 1 addition & 3 deletions mindspore/ccsrc/common/graph_kernel/bprop/expander/emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ class Emitter {
NodePtr Reciprocal(const NodePtr &node) const { return Emit(prim::kReciprocal, {node}); }
NodePtr Square(const NodePtr &node) const { return Emit(prim::kSquare, {node}); }
NodePtr Sign(const NodePtr &node) const { return Emit(prim::kPrimSign->name(), {node}); }
NodePtr Transpose(const NodePtr &node, const ShapeVector &perm) const {
return Emit(kTransposeOpName, {node, Value(perm)});
}
NodePtr Transpose(const NodePtr &node, const ShapeVector &perm) const;
NodePtr Tile(const NodePtr &node, const ShapeVector &multiples) const {
return Emit(kTileOpName, {node, Value(multiples)});
}
Expand Down
14 changes: 7 additions & 7 deletions mindspore/ccsrc/common/graph_kernel/bprop/grad/grad_array_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,10 @@ REG_BPROP_BUILDER("SparseGatherV2").SetBody([](const BpropIRBuilder *ib) -> Node
out_shp = ib->GetShape(dout);
ind_shp = ib->GetShape(indices);
auto perm_1 = GenerateShapeIndex(out_shp, ind_shp, axis_int);
auto values_transpose = ib->Emit("Transpose", {dout, ib->Value<ShapeVector>(perm_1)});
auto values_transpose = ib->Transpose(dout, perm_1);
auto params_grad = ib->Emit("UnsortedSegmentSum", {values_transpose, indices, ib->Value<int64_t>(x_shp[axis_int])});
auto perm_2 = GenerateInverseIndex(x_shp, axis_int);
params_grad = ib->Emit("Transpose", {params_grad, ib->Value<ShapeVector>(perm_2)});
params_grad = ib->Transpose(params_grad, perm_2);
return {params_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
});

Expand All @@ -207,7 +207,7 @@ REG_BPROP_BUILDER("Sort").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
auto top_k_input = input_x;
if ((static_cast<size_t>(axis + 1) != rank)) {
transposition = GetTransposition(axis, rank);
top_k_input = ib->Emit("Transpose", {input_x, ib->Value<ShapeVector>(transposition)});
top_k_input = ib->Transpose(input_x, transposition);
}
auto tmp = ib->Emit("TopK", {top_k_input, ib->Value<int64_t>(k)});
auto indices = ib->TupleGetItem(tmp, 1);
Expand All @@ -228,12 +228,12 @@ REG_BPROP_BUILDER("Sort").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
auto x_shape_1d = ib->Value<ShapeVector>({x_size});
NodePtr dx = nullptr;
if (!transposition.empty()) {
auto invert_perm = ib->Value<ShapeVector>(InvertPermutation(transposition));
dvalue = ib->Emit("Transpose", {dvalue, invert_perm});
auto invert_perm = InvertPermutation(transposition);
dvalue = ib->Transpose(dvalue, invert_perm);
auto ind_expand = ib->Emit("ExpandDims", {ind, ib->Value<int64_t>(-1)});
auto scatter = ib->Emit("ScatterNd", {ind_expand, ib->Reshape(dvalue, {-1}), x_shape_1d});
auto out_grad = ib->Reshape(scatter, top_k_input_shape);
dx = ib->Emit("Transpose", {out_grad, invert_perm});
dx = ib->Transpose(out_grad, invert_perm);
} else {
auto ind_expand = ib->Emit("ExpandDims", {ind, ib->Value<int64_t>(-1)});
auto scatter = ib->Emit("ScatterNd", {ind_expand, ib->Reshape(dvalue, {-1}), x_shape_1d});
Expand Down Expand Up @@ -817,7 +817,7 @@ REG_BPROP_BUILDER("Transpose").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
(void)std::transform(tmp_perm.begin(), tmp_perm.end(), std::back_inserter(new_perm),
[&tmp_perm](const int64_t v) { return v >= 0 ? v : v + tmp_perm.size(); });
auto res_perm = InvertPermutation(new_perm);
return {ib->Emit("Transpose", {dout, ib->Value<ShapeVector>(res_perm)}), ib->ZerosLike(perm)};
return {ib->Transpose(dout, res_perm), ib->ZerosLike(perm)};
});

REG_BPROP_BUILDER("Slice").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
Expand Down
16 changes: 8 additions & 8 deletions mindspore/ccsrc/common/graph_kernel/bprop/grad/grad_math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,8 @@ REG_BPROP_BUILDER("Cdist").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
}
perm.push_back(dout_dim - 1);
perm.push_back(dout_dim - 2);
auto dout_transpose = ib->Emit("Transpose", {dout, ib->Tensor(perm)});
auto out_transpose = ib->Emit("Transpose", {out, ib->Tensor(perm)});
auto dout_transpose = ib->Transpose(dout, perm);
auto out_transpose = ib->Transpose(out, perm);
auto dx = ib->Emit("CdistGrad", {dout, input_x, input_y, out}, {{"p", ib->GetAttr("p")}});
auto dy = ib->Emit("CdistGrad", {dout_transpose, input_y, input_x, out_transpose}, {{"p", ib->GetAttr("p")}});
return {dx, dy};
Expand Down Expand Up @@ -926,15 +926,15 @@ REG_BPROP_BUILDER("ReduceProd").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
auto tile_scaling = TupleDiv(input_shape, output_shape_kept_dims);
auto grad = ib->Emit("Tile", {dout, ib->Value<ShapeVector>(tile_scaling)});
auto [pack_shape, perm] = SplitShapeIndex(input_shape, GetAxisValue(axis));
auto permuted = ib->Emit("Transpose", {x, ib->Value<ShapeVector>(perm)});
auto permuted = ib->Transpose(x, perm);
auto permuted_shape = ib->GetShape(permuted);
auto reshaped = ib->Reshape(permuted, pack_shape);
auto left = ib->Emit("CumProd", {reshaped, ib->Tensor(0, ib->GetDtype(axis))},
{{"exclusive", MakeValue(true)}, {"reverse", MakeValue(false)}});
auto right = ib->Emit("CumProd", {reshaped, ib->Tensor(0, ib->GetDtype(axis))},
{{"exclusive", MakeValue(true)}, {"reverse", MakeValue(true)}});
auto y = ib->Reshape(ib->Mul(left, right), permuted_shape);
out = ib->Mul((ib->Emit("Transpose", {y, ib->Value<ShapeVector>(InvertPermutation(perm))})), grad);
out = ib->Mul(ib->Transpose(y, InvertPermutation(perm)), grad);
auto dx = ib->Reshape(out, input_shape);
return {dx, ib->ZerosLike(axis)};
});
Expand Down Expand Up @@ -1143,7 +1143,7 @@ REG_BPROP_BUILDER("MatrixExp").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
auto input_perm = Range(x_len);
input_perm[x_len - 2] = x_len - 1;
input_perm[x_len - 1] = x_len - 2;
auto x_transpose = ib->Emit("Transpose", {x, ib->EmitValue(MakeValue(input_perm))});
auto x_transpose = ib->Transpose(x, input_perm);
auto zero_matrix = ib->ZerosLike(x);
zero_matrix = ib->Cast(zero_matrix, ib->GetDtype(dout));
auto meta_grad_up = ib->Emit("Concat", {ib->MakeTuple({x_transpose, dout})}, {{"axis", MakeValue<int64_t>(-1)}});
Expand Down Expand Up @@ -1186,14 +1186,14 @@ REG_BPROP_BUILDER("CholeskyInverse").SetBody([](const BpropIRBuilder *ib) -> Nod
input_x = ib->Cast(input_x, kFloat32);
out = ib->Cast(out, kFloat32);
dout = ib->Cast(dout, kFloat32);
auto common_term = ib->Add(dout, ib->Emit("Transpose", {dout, ib->EmitValue(MakeValue(input_perm))}));
auto common_term = ib->Add(dout, ib->Transpose(dout, input_perm));
common_term = ib->Cast(common_term, kFloat32);
common_term = ib->MatMul(out, ib->MatMul(common_term, out, false, false), false, false);
DealWithUpper(common_term);
dx = ib->Cast(dx, kFloat64);
return {dx};
}
auto common_term = ib->Add(dout, ib->Emit("Transpose", {dout, ib->EmitValue(MakeValue(input_perm))}));
auto common_term = ib->Add(dout, ib->Transpose(dout, input_perm));
common_term = ib->MatMul(out, ib->MatMul(common_term, out, false, false), false, false);
DealWithUpper(common_term);
return {dx};
Expand Down Expand Up @@ -1386,7 +1386,7 @@ REG_BPROP_BUILDER("TridiagonalMatMul").SetBody([](const BpropIRBuilder *ib) -> N
}
perm.emplace_back(rank - 1);
perm.emplace_back(rank - 2);
return ib->Emit("Transpose", {x, ib->Value<ShapeVector>(perm)});
return ib->Transpose(x, perm);
};
auto superdiag = ib->GetInput(kIndex0);
auto maindiag = ib->GetInput(kIndex1);
Expand Down
6 changes: 3 additions & 3 deletions mindspore/ccsrc/common/graph_kernel/bprop/grad/grad_nn_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -978,11 +978,11 @@ REG_BPROP_BUILDER("Softmax").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
auto dout = ib->GetInput(kIndex2);
auto shp = ib->GetShape(x);
auto reverse_axis = GetTransposeAxis(shp, one_axis);
out = ib->Emit("Transpose", {out, ib->Value<ShapeVector>(reverse_axis)});
dout = ib->Emit("Transpose", {dout, ib->Value<ShapeVector>(reverse_axis)});
out = ib->Transpose(out, reverse_axis);
dout = ib->Transpose(dout, reverse_axis);
ShapeVector reduce_axis = {-1};
auto dx = ib->Mul(out, ib->Sub(dout, ib->ReduceSum(ib->Mul(out, dout), reduce_axis, true)));
dx = ib->Emit("Transpose", {dx, ib->Value<ShapeVector>(reverse_axis)});
dx = ib->Transpose(dx, reverse_axis);
return {dx};
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ REG_BPROP_BUILDER("SparseTensorDenseMatmul").SetBody([](const BpropIRBuilder *ib
auto dout = ib->GetInput(kIndex5);
auto dense_grad = ib->Emit("SparseTensorDenseMatmul", {indices, values, dense_shape, dout},
{{"adjoint_st", MakeValue(!adj_s)}, {"adjoint_dt", MakeValue(adj_d)}});
std::vector<int64_t> perm_value{1, 0};
auto perm = ib->Tensor(perm_value);
auto perm = ib->Value<ShapeVector>({1, 0});
if (adj_d) {
dense_grad = ib->Emit("Transpose", {dense_grad, perm});
}
Expand Down

0 comments on commit 2f58693

Please sign in to comment.