[Relay][Topi][Op] Advanced indexing (apache#6388)
* Add Relay adv_index op

* Support single index tensor dynamic shape

* Support more dynamic index

* Fix lint

* Minor fix for comment

* Fix lint

* Fix lint

* Fix test

* Fix
kevinthesun authored Sep 11, 2020
1 parent 355720e commit 1228111
Showing 10 changed files with 321 additions and 38 deletions.
80 changes: 80 additions & 0 deletions include/tvm/topi/transform.h
@@ -26,6 +26,7 @@

#include <tvm/te/operation.h>
#include <tvm/tir/data_layout.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/detail/ravel_unravel.h>
#include <tvm/topi/detail/tensor_utils.h>
@@ -1551,6 +1552,85 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal,
      name, tag);
}

/*!
 * \brief Numpy style advanced indexing with tensor.
 * \param data The input data tensor.
 * \param indices The list of index tensors.
 * \param name The name of the output tensor.
 * \param tag The tag of the output tensor.
 * \return The computed result.
 */
inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
                        const std::string name = "advanced_index",
                        const std::string tag = kInjective) {
  Array<PrimExpr> oshape;
  Array<PrimExpr> broadcast_shape;
  Array<Tensor> bindices;
  std::vector<int64_t> flatten_shape_lens;
  int64_t num_picked_elems = 1;
  bool has_dyn_shape = false;

  if (indices.size() == 1) {
    broadcast_shape = indices[0]->shape;
    bindices = indices;
  } else {
    for (const auto& index : indices) {
      int64_t flatten_len = 1;
      for (const auto& dim : index->shape) {
        const IntImmNode* axis_len = dim.as<IntImmNode>();
        if (!axis_len) {
          // If a dynamic shape appears, just use this index's shape.
          broadcast_shape = index->shape;
          has_dyn_shape = true;
          break;
        }
        flatten_len *= axis_len->value;
      }
      if (has_dyn_shape) break;
      flatten_shape_lens.push_back(flatten_len);
      if (flatten_len > num_picked_elems) {
        num_picked_elems = flatten_len;
        broadcast_shape = index->shape;
      }
    }

    // Do broadcast for indices
    for (size_t i = 0; i < indices.size(); ++i) {
      if (!has_dyn_shape && flatten_shape_lens[i] < num_picked_elems) {
        bindices.push_back(broadcast_to(indices[i], broadcast_shape));
      } else {
        bindices.push_back(indices[i]);
      }
    }
  }

  for (const auto& dim : broadcast_shape) {
    oshape.push_back(dim);
  }
  for (size_t i = indices.size(); i < data->shape.size(); ++i) {
    oshape.push_back(data->shape[i]);
  }

  return compute(
      oshape,
      [&](const Array<Var>& iter_var) {
        Array<PrimExpr> tensor_indices;
        for (size_t i = 0; i < broadcast_shape.size(); ++i) {
          tensor_indices.push_back(iter_var[i]);
        }

        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < bindices.size(); ++i) {
          real_indices.push_back(bindices[i](tensor_indices));
        }
        for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
          real_indices.push_back(iter_var[i]);
        }

        return data(real_indices);
      },
      name, tag);
}

} // namespace topi
} // namespace tvm
#endif // TVM_TOPI_TRANSFORM_H_
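For context, the broadcasting behavior this kernel reproduces is NumPy's advanced indexing: the index tensors are broadcast to a common shape, and the unindexed trailing axes of `data` are appended to the result. A minimal NumPy sketch of the expected semantics (illustrative only, not part of the patch):

import numpy as np

data = np.arange(10 * 5 * 15).reshape(10, 5, 15)
idx0 = np.random.randint(0, 10, size=(3, 4))  # shape (3, 4)
idx1 = np.random.randint(0, 5, size=(3, 1))   # shape (3, 1), broadcasts to (3, 4)

# The two index tensors broadcast to (3, 4); the unindexed last axis of
# data (length 15) is appended, giving an output of shape (3, 4, 15).
out = data[idx0, idx1]
assert out.shape == (3, 4, 15)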
40 changes: 2 additions & 38 deletions python/tvm/relay/frontend/pytorch.py
@@ -1816,44 +1816,8 @@ def _impl(inputs, input_types):
 def _index():
     def _impl(inputs, input_types):
         data = inputs[0]
-        indices = []
-        raw_indices = []
-        max_indices_len = -1
-        for index in inputs[1]:
-            if not isinstance(index, _expr.Constant):
-                try:
-                    index = _expr.const(_infer_value(index, {}))
-                except Exception:
-                    raise RuntimeError("Only supports constant indices for "
-                                       "pytorch advanced indexing ")
-            raw_indices.append(index)
-            cindex_len = index.data.shape[0]
-            if cindex_len > max_indices_len:
-                max_indices_len = cindex_len
-
-        for index in raw_indices:
-            cnp = index.data.asnumpy()
-            cindex_len = cnp.shape[0]
-            if cindex_len < max_indices_len:
-                cnp = np.tile(cnp, max_indices_len // cindex_len)
-            indices.append(cnp)
-
-        ret = []
-        slice_map = {}
-        for i in range(indices[0].shape[0]):
-            tmp = data
-            current_indices = []
-            for index in indices:
-                current_indices.append(index[i])
-            index_key = tuple(current_indices)
-            if index_key in slice_map:
-                tmp = slice_map[index_key]
-            else:
-                tmp = _op.take(tmp, _expr.const(index[i]), axis=0)
-                slice_map[index_key] = tmp
-            ret.append(_op.expand_dims(tmp, axis=0))
-
-        return _op.concatenate(ret, axis=0)
+        indices = inputs[1]
+        return _op.adv_index([data] + indices)
     return _impl


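With the converter reduced to a direct adv_index call, the tensor indices no longer need to be constants. A hypothetical PyTorch sketch of the pattern that now converts (the subsequent relay.frontend.from_pytorch call is elided, since its exact input-info format varies by frontend version):

import torch

class AdvIndex(torch.nn.Module):
    def forward(self, x, i, j):
        # aten::index with two tensor indices; the frontend now lowers
        # this directly to relay.adv_index([x, i, j]).
        return x[i, j]

traced = torch.jit.trace(
    AdvIndex(),
    (torch.rand(10, 5, 15),
     torch.randint(0, 10, (3, 4)),
     torch.randint(0, 5, (3, 1))))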
33 changes: 33 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -62,6 +62,7 @@
_reg.register_injective_schedule("unravel_index")
_reg.register_injective_schedule("sparse_to_dense")
_reg.register_injective_schedule("matrix_set_diag")
_reg.register_injective_schedule("adv_index")

# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
@@ -661,3 +662,35 @@ def split_shape_func(attrs, inputs, _):
                              convert(i),
                              convert(indices_or_sections),
                              convert(axis)) for i in range(num_out)]

@script
def _adv_index_shape_func(inputs):
    index_rank = inputs[1].shape[0]
    data_rank = inputs[0].shape[0]
    out = output_tensor((data_rank + index_rank - len(inputs) + 1,), "int64")

    max_flatten_len = int64(1)
    for i in const_range(index_rank):
        max_flatten_len *= inputs[1][i]
        out[i] = inputs[1][i]
    for i in const_range(len(inputs) - 2):
        flatten_len = int64(1)
        for j in const_range(index_rank):
            flatten_len *= inputs[i + 2][j]
        if flatten_len > max_flatten_len:
            max_flatten_len = flatten_len
            for k in const_range(index_rank):
                out[k] = inputs[i + 2][k]

    for i in const_range(data_rank - len(inputs) + 1):
        out[i + index_rank] = inputs[0][i + len(inputs) - 1]

    return out

@_reg.register_shape_func("adv_index", False)
def adv_index_shape_func(attrs, inputs, _):
    """
    Shape function for adv_index.
    All index tensors must share the same rank; the index shape with the
    largest number of elements is taken as the broadcast shape.
    """
    return [_adv_index_shape_func(inputs)]
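The hybrid-script function above selects, among the index tensors, the shape with the largest flattened element count as the broadcast shape, then appends the remaining data axes. The same logic in plain Python, as an illustrative sketch (the helper name is made up):

import numpy as np

def adv_index_out_shape(data_shape, index_shapes):
    # All index tensors share the same rank; the one with the most
    # elements supplies the broadcast shape.
    broadcast = max(index_shapes, key=np.prod)
    # Data axes not covered by an index tensor pass through unchanged.
    return tuple(broadcast) + tuple(data_shape[len(index_shapes):])

assert adv_index_out_shape((10, 5, 15), [(3, 4), (3, 1)]) == (3, 4, 15)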
18 changes: 18 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -1213,3 +1213,21 @@ def matrix_set_diag(data, diagonal):
        [7, 7, 6, 7]]]
    """
    return _make.matrix_set_diag(data, diagonal)


def adv_index(inputs):
    """Numpy style advanced indexing. Index with a list of tensors.

    Parameters
    ----------
    inputs : Union(List[relay.Expr], Tuple[relay.Expr])
        Input tensor and indices.
        The first tensor is the input data; the rest are index tensors.

    Returns
    -------
    result : relay.Expr
        Output tensor.
    """
    return _make.adv_index(Tuple(inputs))
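A minimal usage sketch of the new Relay API (shapes and names are illustrative):

import tvm
from tvm import relay

data = relay.var("data", shape=(10, 5), dtype="float32")
index = relay.var("index", shape=(3, 4), dtype="int64")
out = relay.adv_index([data, index])  # NumPy equivalent: data[index]
func = relay.Function([data, index], out)
print(tvm.IRModule.from_expr(func))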
18 changes: 18 additions & 0 deletions python/tvm/topi/transform.py
@@ -838,3 +838,21 @@ def matrix_set_diag(data, diagonal):
        [7, 7, 6, 7]]]
    """
    return cpp.matrix_set_diag(data, diagonal)

def adv_index(data, indices):
    """Numpy style indexing with tensors.

    Parameters
    ----------
    data : tvm.te.Tensor
        Input data.

    indices : list of tvm.te.Tensor
        Index tensors.

    Returns
    -------
    result : tvm.te.Tensor
        Output tensor.
    """
    return cpp.adv_index(data, indices)
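The corresponding TOPI-level sketch with te.placeholder inputs (again illustrative):

import tvm
from tvm import te, topi

data = te.placeholder((10, 5, 15), name="data", dtype="float32")
idx0 = te.placeholder((3, 4), name="idx0", dtype="int64")
idx1 = te.placeholder((3, 1), name="idx1", dtype="int64")

# idx1 is broadcast to (3, 4); the result has shape (3, 4, 15).
out = topi.adv_index(data, [idx0, idx1])
print(out.shape)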
83 changes: 83 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -3163,5 +3163,88 @@ RELAY_REGISTER_OP("matrix_set_diag")
    .set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
    .set_attr<TOpPattern>("TOpPattern", kInjective);

// adv_index
bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
  CHECK_EQ(num_inputs, 1);
  auto inputs = types[0].as<TupleTypeNode>();
  if (inputs == nullptr) {
    return false;
  }
  auto data = inputs->fields[0].as<TensorTypeNode>();
  if (data == nullptr) {
    return false;
  }

  Array<IndexExpr> oshape;
  Array<IndexExpr> broadcast_shape;
  int64_t num_picked_elems = 1;

  if (inputs->fields.size() == 2) {
    broadcast_shape = inputs->fields[1].as<TensorTypeNode>()->shape;
  } else {
    for (size_t i = 1; i < inputs->fields.size(); ++i) {
      auto index_type = inputs->fields[i].as<TensorTypeNode>();
      if (index_type == nullptr) {
        return false;
      }
      CHECK(index_type->dtype.is_int()) << "indices must be tensor of integers";

      int64_t flatten_len = 1;
      bool has_dyn_shape = false;
      for (const auto& dim : index_type->shape) {
        const IntImmNode* axis_len = dim.as<IntImmNode>();
        if (!axis_len) {
          // If a dynamic shape appears, just use this index's shape.
          broadcast_shape = index_type->shape;
          has_dyn_shape = true;
          break;
        }
        flatten_len *= axis_len->value;
      }
      if (has_dyn_shape) break;
      if (flatten_len > num_picked_elems) {
        num_picked_elems = flatten_len;
        broadcast_shape = index_type->shape;
      }
    }
  }

  for (const auto& dim : broadcast_shape) {
    oshape.push_back(dim);
  }
  for (size_t i = inputs->fields.size() - 1; i < data->shape.size(); ++i) {
    oshape.push_back(data->shape[i]);
  }
  reporter->Assign(types[1], TensorType(oshape, data->dtype));
  return true;
}

Array<te::Tensor> AdvIndexCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                  const Type& out_type) {
  Array<te::Tensor> indices;
  for (size_t i = 1; i < inputs.size(); ++i) {
    indices.push_back(inputs[i]);
  }
  return {topi::adv_index(inputs[0], indices)};
}

Expr MakeAdvIndex(Expr inputs) {
  static const Op& op = Op::Get("adv_index");
  return Call(op, {inputs}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.adv_index").set_body_typed(MakeAdvIndex);

RELAY_REGISTER_OP("adv_index")
    .describe(R"code(Numpy style advanced indexing. Index with a list of tensors.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .set_support_level(3)
    .add_argument("inputs", "Tuple of Tensors", "Input tensor and indices.")
    .add_type_rel("AdvIndex", AdvIndexRel)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    .set_attr<TOpPattern>("TOpPattern", kInjective)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
    .set_attr<FTVMCompute>("FTVMCompute", AdvIndexCompute);

} // namespace relay
} // namespace tvm
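The type relation can be exercised from Python by running InferType over a call with static shapes; a sketch using the same shapes as the tests below:

import tvm
from tvm import relay

data = relay.var("data", shape=(10, 5, 15), dtype="float32")
idx0 = relay.var("idx0", shape=(3, 4), dtype="int64")
idx1 = relay.var("idx1", shape=(3, 1), dtype="int64")

mod = tvm.IRModule.from_expr(
    relay.Function([data, idx0, idx1], relay.adv_index([data, idx0, idx1])))
mod = relay.transform.InferType()(mod)
# AdvIndexRel picks (3, 4) as the broadcast shape and appends the
# unindexed data axis, so this should print Tensor[(3, 4, 15), float32].
print(mod["main"].ret_type)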
4 changes: 4 additions & 0 deletions src/topi/transform.cc
@@ -180,5 +180,9 @@ TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValu
  *rv = matrix_set_diag(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = adv_index(args[0], args[1]);
});

} // namespace topi
} // namespace tvm
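Since the kernel is registered as a global packed function, it is reachable without the Python wrapper as well; a quick sketch (this assumes the usual FFI conversion of Python lists to Array, which is the same path the python/tvm/topi/transform.py wrapper takes):

import tvm
from tvm import te

f = tvm.get_global_func("topi.adv_index")
data = te.placeholder((10, 5), name="data", dtype="float32")
index = te.placeholder((3,), name="index", dtype="int64")
out = f(data, [index])
print(out.shape)  # [3, 5]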
15 changes: 15 additions & 0 deletions tests/python/relay/test_any.py
@@ -892,5 +892,20 @@ def test_reshape_concat():
                              np.reshape(np_data1, np_shape_like1.shape)], axis=0)
    check_result([np_data0, np_data1, np_shape_like0, np_shape_like1], mod, ref_res)

def test_any_adv_index():
    data = relay.var("data", shape=(5, relay.Any(), relay.Any()), dtype='float32')
    index0 = relay.var("index0", shape=(1, relay.Any()), dtype='int64')
    index1 = relay.var("index1", shape=(1, relay.Any()), dtype='int64')
    out = relay.adv_index([data, index0, index1])
    mod = tvm.IRModule()
    mod['main'] = relay.Function([data, index0, index1], out)
    np_data_shape = (5, 5, 10)
    np_index_shape = (1, 4)
    np_data = np.random.uniform(size=np_data_shape).astype('float32')
    np_index = np.random.uniform(0, np_data_shape[0], size=np_index_shape).astype('int64')
    ref_res = np_data[tuple([np_index, np_index])]
    check_result([np_data, np_index, np_index], mod, ref_res)


if __name__ == "__main__":
    pytest.main([__file__])
26 changes: 26 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -1091,6 +1091,31 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_
    #sparse_indices should not be > 2d tensor
    #verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])

def test_adv_index():
    def verify_adv_index(data_shape, index_shapes):
        dtype = "float32"
        inputs = [relay.var("data", relay.TensorType(data_shape, dtype))]
        np_data = np.random.uniform(size=data_shape).astype(dtype)
        np_indices = []
        for i, index_shape in enumerate(index_shapes):
            limit = data_shape[i]
            np_indices.append(np.random.uniform(0, limit - 1, size=index_shape).astype("int64"))
            inputs.append(relay.var("index_{}".format(i), relay.TensorType(index_shape, "int64")))
        np_out = np_data[tuple(np_indices)]
        np_args = [np_data] + np_indices
        out = relay.op.adv_index(inputs)

        func = relay.Function(inputs, out)
        for target, ctx in tvm.testing.enabled_targets():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(*np_args)
                tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=1e-5)

    verify_adv_index((10, 5), [(3, 4), (3, 1)])
    verify_adv_index((10, 5), [(2,),])
    verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])

if __name__ == "__main__":
    test_cast()
    test_zeros_ones()
@@ -1127,3 +1152,4 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_
    test_unravel_index()
    test_sparse_to_dense()
    test_fixed_point_multiply()
    test_adv_index()