
Commit

[Converter:Bugfix] Support Onnx::TopK for dynamic shape
xiaying committed Dec 10, 2021
1 parent a14ef5e commit b3c5fee
Showing 7 changed files with 156 additions and 55 deletions.
26 changes: 21 additions & 5 deletions schema/current/TensorflowOp_generated.h
@@ -1499,9 +1499,11 @@ struct TopKV2T : public flatbuffers::NativeTable {
typedef TopKV2 TableType;
DataType T;
bool sorted;
bool largest;
TopKV2T()
: T(DataType_DT_FLOAT),
sorted(false) {
sorted(false),
largest(true) {
}
};

@@ -1516,10 +1518,14 @@ struct TopKV2 FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool sorted() const {
return GetField<uint8_t>(6, 0) != 0;
}
bool largest() const {
return GetField<uint8_t>(8, 1) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, 4) &&
VerifyField<uint8_t>(verifier, 6) &&
VerifyField<uint8_t>(verifier, 8) &&
verifier.EndTable();
}
TopKV2T *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -1536,6 +1542,9 @@ struct TopKV2Builder {
void add_sorted(bool sorted) {
fbb_.AddElement<uint8_t>(6, static_cast<uint8_t>(sorted), 0);
}
void add_largest(bool largest) {
fbb_.AddElement<uint8_t>(8, static_cast<uint8_t>(largest), 1);
}
explicit TopKV2Builder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1551,9 +1560,11 @@ struct TopKV2Builder {
inline flatbuffers::Offset<TopKV2> CreateTopKV2(
flatbuffers::FlatBufferBuilder &_fbb,
DataType T = DataType_DT_FLOAT,
bool sorted = false) {
bool sorted = false,
bool largest = true) {
TopKV2Builder builder_(_fbb);
builder_.add_T(T);
builder_.add_largest(largest);
builder_.add_sorted(sorted);
return builder_.Finish();
}
@@ -3806,6 +3817,7 @@ inline void TopKV2::UnPackTo(TopKV2T *_o, const flatbuffers::resolver_function_t
(void)_resolver;
{ auto _e = T(); _o->T = _e; };
{ auto _e = sorted(); _o->sorted = _e; };
{ auto _e = largest(); _o->largest = _e; };
}

inline flatbuffers::Offset<TopKV2> TopKV2::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TopKV2T* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -3818,10 +3830,12 @@ inline flatbuffers::Offset<TopKV2> CreateTopKV2(flatbuffers::FlatBufferBuilder &
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TopKV2T* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _T = _o->T;
auto _sorted = _o->sorted;
auto _largest = _o->largest;
return MNN::CreateTopKV2(
_fbb,
_T,
_sorted);
_sorted,
_largest);
}

inline CropAndResizeT *CropAndResize::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -5078,17 +5092,19 @@ inline const flatbuffers::TypeTable *UnaryOpTypeTable() {
inline const flatbuffers::TypeTable *TopKV2TypeTable() {
static const flatbuffers::TypeCode type_codes[] = {
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_BOOL, 0, -1 },
{ flatbuffers::ET_BOOL, 0, -1 }
};
static const flatbuffers::TypeFunction type_refs[] = {
DataTypeTypeTable
};
static const char * const names[] = {
"T",
"sorted"
"sorted",
"largest"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_TABLE, 2, type_codes, type_refs, nullptr, names
flatbuffers::ST_TABLE, 3, type_codes, type_refs, nullptr, names
};
return &tt;
}
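The new largest field sits at vtable offset 8 with a default of true, so buffers serialized before this change still read as largest-k. A minimal sketch of writing and reading the field through the generated API (the function name roundTripLargest and the standalone usage are illustrative, not part of this commit):

    #include "flatbuffers/flatbuffers.h"
    #include "TensorflowOp_generated.h"

    bool roundTripLargest() {
        flatbuffers::FlatBufferBuilder fbb;
        // Request the k smallest values by clearing the new flag.
        auto off = MNN::CreateTopKV2(fbb, MNN::DataType_DT_FLOAT,
                                     /*sorted=*/true, /*largest=*/false);
        fbb.Finish(off);
        // An older buffer has no field at offset 8, so largest() falls back
        // to the schema default of true; the buffer built above reads false.
        auto topk = flatbuffers::GetRoot<MNN::TopKV2>(fbb.GetBufferPointer());
        return topk->largest();
    }
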
1 change: 1 addition & 0 deletions schema/default/TensorflowOp.fbs
@@ -152,6 +152,7 @@ table UnaryOp {
table TopKV2 {
T:DataType=DT_FLOAT;
sorted:bool=false;
largest:bool=true;
}
enum CropAndResizeMethod : byte{
BILINEAR=0,
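Because the schema default is largest:bool=true, existing models keep their behavior; a converter only needs to forward the ONNX TopK (opset 11+) largest attribute into the new field. A hedged sketch of that mapping via the generated native table (the helper name buildTopKParam and the onnxLargest parameter are assumptions, not code from this commit):

    #include <memory>
    #include "TensorflowOp_generated.h"

    // Hypothetical converter helper; onnxLargest is the parsed ONNX `largest`
    // attribute, which defaults to 1 when absent.
    std::unique_ptr<MNN::TopKV2T> buildTopKParam(int64_t onnxLargest) {
        auto param = std::make_unique<MNN::TopKV2T>();
        param->T       = MNN::DataType_DT_FLOAT;
        param->sorted  = true;
        param->largest = (onnxLargest != 0);   // lands in the new schema field
        return param;
    }
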
31 changes: 21 additions & 10 deletions source/backend/cpu/CPUTopKV2.cpp
@@ -18,7 +18,7 @@ template <typename T>
class TopContainer {
public:
TopContainer() = delete;
TopContainer(int32_t k, int32_t rowSize) : mK(k) {
TopContainer(int32_t k, int32_t rowSize, bool largest) : mK(k), mLargest(largest) {
mContainer.reserve(std::min(k, rowSize) + 1);
}

@@ -27,7 +27,12 @@ class TopContainer {
mContainer.clear();
}
void push(int32_t a) {
auto comparator = [this](int32_t a, int32_t b) { return compareFunc(a, b); };
std::function<bool(int32_t, int32_t)> comparator;
if (mLargest) {
comparator = [this](int32_t a, int32_t b) { return compareFunc(a, b); };
} else {
comparator = [this](int32_t a, int32_t b) { return !compareFunc(a, b); };
}
if (mContainer.size() <= mK) {
mContainer.push_back(a);
if (mContainer.size() == mK + 1) {
@@ -42,7 +47,12 @@ }
}

const std::vector<int32_t>& sortedResult() {
auto comparator = [this](int32_t a, int32_t b) { return compareFunc(a, b); };
std::function<bool(int32_t, int32_t)> comparator;
if (mLargest) {
comparator = [this](int32_t a, int32_t b) { return compareFunc(a, b); };
} else {
comparator = [this](int32_t a, int32_t b) { return !compareFunc(a, b); };
}
if (mContainer.size() <= mK) {
std::sort(mContainer.begin(), mContainer.end(), comparator);
} else {
@@ -54,6 +64,7 @@

private:
int32_t mK;
bool mLargest;
std::vector<int32_t> mContainer;
const T* mValues = nullptr;

@@ -69,8 +80,8 @@ class TopContainer {
};

template <typename T>
void findTopK(int32_t rowSize, int32_t numRows, const T* data, int32_t k, int32_t* outputIndexes, T* outputValues) {
TopContainer<T> topc(k, rowSize);
void findTopK(int32_t rowSize, int32_t numRows, const T* data, int32_t k, int32_t* outputIndexes, T* outputValues, bool largest) {
TopContainer<T> topc(k, rowSize, largest);
for (int row = 0; row < numRows; row++) {
const T* valuesRow = data + row * rowSize;
topc.startCollecting(valuesRow);
@@ -87,7 +98,7 @@ void findTopK(int32_t rowSize, int32_t numRows, const T* data, int32_t k, int32_
}
}

CPUTopKV2::CPUTopKV2(Backend* b) : MNN::Execution(b) {
CPUTopKV2::CPUTopKV2(Backend* b, const Op* op) : MNN::Execution(b), mLargest(op->main_as_TopKV2()->largest()) {
// nothing to do
}

@@ -106,7 +117,7 @@ ErrorCode CPUTopKV2::onExecute(const std::vector<Tensor*>& inputs, const std::ve
MNN_ASSERT(k <= rowSize);
const int numRows = inputTensor->elementSize() / rowSize;

if (k == 1) {
if (k == 1 && mLargest) {
if (halide_type_float == inputTensor->getType().code) {
float* inputData = inputTensor->host<float>();
float* topkData = outputData->host<float>();
@@ -158,12 +169,12 @@ ErrorCode CPUTopKV2::onExecute(const std::vector<Tensor*>& inputs, const std::ve
auto inputData = inputTensor->host<float>();
auto topkData = outputData->host<float>();
int* indicesData = outputIndices->host<int32_t>();
findTopK<float>(rowSize, numRows, inputData, k, indicesData, topkData);
findTopK<float>(rowSize, numRows, inputData, k, indicesData, topkData, mLargest);
} else if(halide_type_int == inputTensor->getType().code && 32 == inputTensor->getType().bits) {
auto inputData = inputTensor->host<int32_t>();
auto topkData = outputData->host<int32_t>();
int* indicesData = outputIndices->host<int32_t>();
findTopK<int32_t>(rowSize, numRows, inputData, k, indicesData, topkData);
findTopK<int32_t>(rowSize, numRows, inputData, k, indicesData, topkData, mLargest);
} else {
MNN_PRINT("TODO\n");
MNN_ASSERT(false);
@@ -175,7 +186,7 @@ class CPUTopKV2Creator : public CPUBackend::Creator {
public:
virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
const MNN::Op* op, Backend* backend) const override {
return new CPUTopKV2(backend);
return new CPUTopKV2(backend, op);
}
};

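The kernel keeps a single compareFunc and negates it when the smallest values are requested, so one code path serves both modes. Stripped of the TopContainer plumbing, the selection amounts to choosing the comparison direction up front; a standalone sketch of that idea (not MNN code, and it flips the operands rather than negating, which keeps the comparator a strict weak ordering):

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Return the indices of the k largest (or smallest) values, sorted by rank.
    std::vector<int32_t> topKIndices(const std::vector<float>& values, int k, bool largest) {
        std::vector<int32_t> idx(values.size());
        std::iota(idx.begin(), idx.end(), 0);
        auto cmp = [&](int32_t a, int32_t b) {
            return largest ? values[a] > values[b] : values[a] < values[b];
        };
        if (k > static_cast<int>(idx.size())) {
            k = static_cast<int>(idx.size());
        }
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(), cmp);
        idx.resize(k);
        return idx;
    }
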
3 changes: 2 additions & 1 deletion source/backend/cpu/CPUTopKV2.hpp
@@ -15,11 +15,12 @@
namespace MNN {
class CPUTopKV2 : public Execution {
public:
CPUTopKV2(Backend *b);
CPUTopKV2(Backend *b, const Op* op);
virtual ~CPUTopKV2() = default;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

private:
bool mLargest;
};
} // namespace MNN

2 changes: 2 additions & 0 deletions source/geometry/GeometryOPRegister.cpp
@@ -7,6 +7,7 @@ extern void ___GeometryTile___create__();
extern void ___GeometryReshape___create__();
extern void ___GeometryReduce___create__();
extern void ___GeometryInnerProduct___create__();
extern void ___GeometryTopK___create__();
extern void ___GeometryDepthToSpace___create__();
extern void ___GeometryBroadcastTo___create__();
extern void ___GeometryConvert___create__();
@@ -44,6 +45,7 @@ ___GeometryTile___create__();
___GeometryReshape___create__();
___GeometryReduce___create__();
___GeometryInnerProduct___create__();
___GeometryTopK___create__();
___GeometryDepthToSpace___create__();
___GeometryBroadcastTo___create__();
___GeometryConvert___create__();
97 changes: 97 additions & 0 deletions source/geometry/GeometryTopK.cpp
@@ -0,0 +1,97 @@
//
// GeometryTopK.cpp
// MNN
//
// Created by MNN on 2020/06/09.
// Copyright © 2018, Alibaba Group Holding Limited
//

#include <numeric>
#include "geometry/GeometryComputer.hpp"
#include "geometry/GeometryComputerUtils.hpp"
#include "core/OpCommonUtils.hpp"
namespace MNN {
class GeometryTopK : public GeometryComputer {
public:
virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
Context& context, CommandBuffer& res) const override {
if (outputs.size() != 2 || inputs.size() < 2 || inputs.size() > 3) {
MNN_ERROR("TopK should have 2 output and 2~3 input, get %lu in and %lu out\n", inputs.size(), outputs.size());
return false;
}
if (inputs.size() == 2) {
SharedPtr<Command> cmdP(new Command);
auto& cmd = *cmdP;
cmd.op = op;
cmd.inputs = std::move(inputs);
cmd.outputs = std::move(outputs);
res.command.emplace_back(std::move(cmdP));
return true;
}
if (inputs[1]->host<int32_t>() == nullptr || inputs[2]->host<int32_t>() == nullptr) {
MNN_ERROR("Invalid k or axis\n");
return false;
}
int k = inputs[1]->host<int32_t>()[0], axis = inputs[2]->host<int32_t>()[0];
auto shape = inputs[0]->shape();
// Normalize a negative axis before it is used to split the shape.
if (axis < 0) {
axis = axis + shape.size();
}
int outside = std::accumulate(shape.begin(), shape.begin() + axis, 1, [](int a, int b) { return a * b; });
int inside = std::accumulate(shape.begin() + axis + 1, shape.end(), 1, [](int a, int b) { return a * b; });
std::shared_ptr<Tensor> transInput, transVal, transInd;
{ // transpose TopK's axis to last axis
transInput.reset(Tensor::createDevice({outside * inside, shape[axis]}, inputs[0]->getType(), inputs[0]->getDimensionType()));
auto des = TensorUtils::getDescribe(transInput.get());
des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
Tensor::InsideDescribe::Region reg;
reg.origin = inputs[0];
reg.src.stride[0] = reg.dst.stride[0] = inside * shape[axis];
reg.src.stride[2] = inside;
reg.dst.stride[1] = shape[axis];
reg.size[0] = outside;
reg.size[1] = inside;
reg.size[2] = shape[axis];
des->regions.assign({reg});
res.extras.emplace_back(transInput);
}
{ // transpose TopK's axis from last axis
transVal.reset(Tensor::createDevice({outside * inside, k}, outputs[0]->getType(), outputs[0]->getDimensionType()));
transInd.reset(Tensor::createDevice({outside * inside, k}, outputs[1]->getType(), outputs[1]->getDimensionType()));
Tensor::InsideDescribe::Region reg;
reg.src.stride[0] = reg.dst.stride[0] = inside * k;
reg.src.stride[2] = k;
reg.dst.stride[1] = inside;
reg.size[0] = outside;
reg.size[1] = k;
reg.size[2] = inside;
auto des = TensorUtils::getDescribe(outputs[0]);
des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
reg.origin = transVal.get();
des->regions.assign({reg});
res.extras.emplace_back(transVal);
des = TensorUtils::getDescribe(outputs[1]);
des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
reg.origin = transInd.get();
des->regions.assign({reg});
res.extras.emplace_back(transInd);
}
{ // do TopK on last axis
SharedPtr<Command> cmdP(new Command);
auto& cmd = *cmdP;
cmd.op = op;
cmd.inputs.assign({transInput.get(), inputs[1]});
cmd.outputs.assign({transVal.get(), transInd.get()});
res.command.emplace_back(std::move(cmdP));
}
return true;
}
};
static void _create() {
std::shared_ptr<GeometryComputer> comp(new GeometryTopK);
GeometryComputer::registerGeometryComputer(comp, {OpType_TopKV2});
}

REGISTER_GEOMETRY(GeometryTopK, _create);

} // namespace MNN
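GeometryTopK lowers an arbitrary-axis TopK to a last-axis TopK by describing strided copies: one Region moves the chosen axis to the last dimension before the kernel runs, and two more move the k values and indices back. A self-contained sketch of the first copy for a shape split into (outside, len, inside), assuming the Region strides not set explicitly default to 1 (this mirrors the index arithmetic only, not MNN's Region execution):

    #include <vector>

    // Copy src laid out as (outside, len, inside) into dst laid out as
    // (outside * inside, len), i.e. move the TopK axis to the last dimension.
    std::vector<float> axisToLast(const std::vector<float>& src,
                                  int outside, int len, int inside) {
        std::vector<float> dst(src.size());
        for (int o = 0; o < outside; ++o) {        // size[0], stride inside*len on both sides
            for (int i = 0; i < inside; ++i) {     // size[1]: dst stride len
                for (int a = 0; a < len; ++a) {    // size[2]: src stride inside
                    dst[o * inside * len + i * len + a] =
                        src[o * inside * len + a * inside + i];
                }
            }
        }
        return dst;
    }
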
