Skip to content

Commit

Permalink
merge interp3d_op_param into interp_op_param
Browse files Browse the repository at this point in the history
merge interp3d_op_param into interp_op_param
  • Loading branch information
wtiandong committed Sep 27, 2022
1 parent e7e3d13 commit 9e28435
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 394 deletions.
359 changes: 49 additions & 310 deletions schema/current/CaffeOp_generated.h

Large diffs are not rendered by default.

62 changes: 8 additions & 54 deletions schema/current/MNN_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -1158,12 +1158,11 @@ enum OpParameter {
OpParameter_LoopParam = 92,
OpParameter_ImageProcessParam = 93,
OpParameter_CumSum = 94,
OpParameter_Interp3D = 95,
OpParameter_MIN = OpParameter_NONE,
OpParameter_MAX = OpParameter_Interp3D
OpParameter_MAX = OpParameter_CumSum
};

inline const OpParameter (&EnumValuesOpParameter())[96] {
inline const OpParameter (&EnumValuesOpParameter())[95] {
static const OpParameter values[] = {
OpParameter_NONE,
OpParameter_QuantizedAdd,
Expand Down Expand Up @@ -1259,8 +1258,7 @@ inline const OpParameter (&EnumValuesOpParameter())[96] {
OpParameter_GridSample,
OpParameter_LoopParam,
OpParameter_ImageProcessParam,
OpParameter_CumSum,
OpParameter_Interp3D
OpParameter_CumSum
};
return values;
}
Expand Down Expand Up @@ -1362,14 +1360,13 @@ inline const char * const *EnumNamesOpParameter() {
"LoopParam",
"ImageProcessParam",
"CumSum",
"Interp3D",
nullptr
};
return names;
}

inline const char *EnumNameOpParameter(OpParameter e) {
if (e < OpParameter_NONE || e > OpParameter_Interp3D) return "";
if (e < OpParameter_NONE || e > OpParameter_CumSum) return "";
const size_t index = static_cast<int>(e);
return EnumNamesOpParameter()[index];
}
Expand Down Expand Up @@ -1754,10 +1751,6 @@ template<> struct OpParameterTraits<CumSum> {
static const OpParameter enum_value = OpParameter_CumSum;
};

template<> struct OpParameterTraits<Interp3D> {
static const OpParameter enum_value = OpParameter_Interp3D;
};

struct OpParameterUnion {
OpParameter type;
void *value;
Expand Down Expand Up @@ -2541,14 +2534,6 @@ struct OpParameterUnion {
return type == OpParameter_CumSum ?
reinterpret_cast<const CumSumT *>(value) : nullptr;
}
Interp3DT *AsInterp3D() {
return type == OpParameter_Interp3D ?
reinterpret_cast<Interp3DT *>(value) : nullptr;
}
const Interp3DT *AsInterp3D() const {
return type == OpParameter_Interp3D ?
reinterpret_cast<const Interp3DT *>(value) : nullptr;
}
};

bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj, OpParameter type);
Expand Down Expand Up @@ -3600,9 +3585,6 @@ struct Op FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const CumSum *main_as_CumSum() const {
return main_type() == OpParameter_CumSum ? static_cast<const CumSum *>(main()) : nullptr;
}
const Interp3D *main_as_Interp3D() const {
return main_type() == OpParameter_Interp3D ? static_cast<const Interp3D *>(main()) : nullptr;
}
const flatbuffers::String *name() const {
return GetPointer<const flatbuffers::String *>(10);
}
Expand Down Expand Up @@ -4011,10 +3993,6 @@ template<> inline const CumSum *Op::main_as<CumSum>() const {
return main_as_CumSum();
}

template<> inline const Interp3D *Op::main_as<Interp3D>() const {
return main_as_Interp3D();
}

struct OpBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
Expand Down Expand Up @@ -5633,10 +5611,6 @@ inline bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj,
auto ptr = reinterpret_cast<const CumSum *>(obj);
return verifier.VerifyTable(ptr);
}
case OpParameter_Interp3D: {
auto ptr = reinterpret_cast<const Interp3D *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
Expand Down Expand Up @@ -6031,10 +6005,6 @@ inline void *OpParameterUnion::UnPack(const void *obj, OpParameter type, const f
auto ptr = reinterpret_cast<const CumSum *>(obj);
return ptr->UnPack(resolver);
}
case OpParameter_Interp3D: {
auto ptr = reinterpret_cast<const Interp3D *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
Expand Down Expand Up @@ -6417,10 +6387,6 @@ inline flatbuffers::Offset<void> OpParameterUnion::Pack(flatbuffers::FlatBufferB
auto ptr = reinterpret_cast<const CumSumT *>(value);
return CreateCumSum(_fbb, ptr, _rehasher).Union();
}
case OpParameter_Interp3D: {
auto ptr = reinterpret_cast<const Interp3DT *>(value);
return CreateInterp3D(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
Expand Down Expand Up @@ -6803,10 +6769,6 @@ inline OpParameterUnion::OpParameterUnion(const OpParameterUnion &u) FLATBUFFERS
value = new CumSumT(*reinterpret_cast<CumSumT *>(u.value));
break;
}
case OpParameter_Interp3D: {
value = new Interp3DT(*reinterpret_cast<Interp3DT *>(u.value));
break;
}
default:
break;
}
Expand Down Expand Up @@ -7284,11 +7246,6 @@ inline void OpParameterUnion::Reset() {
delete ptr;
break;
}
case OpParameter_Interp3D: {
auto ptr = reinterpret_cast<Interp3DT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;
Expand Down Expand Up @@ -7754,8 +7711,7 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
{ flatbuffers::ET_SEQUENCE, 0, 90 },
{ flatbuffers::ET_SEQUENCE, 0, 91 },
{ flatbuffers::ET_SEQUENCE, 0, 92 },
{ flatbuffers::ET_SEQUENCE, 0, 93 },
{ flatbuffers::ET_SEQUENCE, 0, 94 }
{ flatbuffers::ET_SEQUENCE, 0, 93 }
};
static const flatbuffers::TypeFunction type_refs[] = {
QuantizedAddTypeTable,
Expand Down Expand Up @@ -7851,8 +7807,7 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
GridSampleTypeTable,
LoopParamTypeTable,
ImageProcessParamTypeTable,
CumSumTypeTable,
Interp3DTypeTable
CumSumTypeTable
};
static const char * const names[] = {
"NONE",
Expand Down Expand Up @@ -7949,11 +7904,10 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
"GridSample",
"LoopParam",
"ImageProcessParam",
"CumSum",
"Interp3D"
"CumSum"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_UNION, 96, type_codes, type_refs, nullptr, names
flatbuffers::ST_UNION, 95, type_codes, type_refs, nullptr, names
};
return &tt;
}
Expand Down
16 changes: 1 addition & 15 deletions schema/default/CaffeOp.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -320,24 +320,10 @@ table Interp {
heightOffset:float;
cubicCoeffA:float = -0.75;
ctm:CoordinateTransformationMode = NotSet;
}

table Interp3D {
widthScale:float;
heightScale:float;
depthScale:float;
outputWidth:int;
outputHeight:int;
outputDepth:int;
resizeType:int;
alignCorners:bool;
halfPixelCenters:bool = false;
widthOffset:float;
heightOffset:float;
depthOffset:float;
cubicCoeffA:float = -0.75;
ctm:CoordinateTransformationMode = NotSet;
}
}

table Resize {
xScale:float;
Expand Down
3 changes: 1 addition & 2 deletions schema/default/MNN.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,7 @@ union OpParameter {
GridSample,
LoopParam,
ImageProcessParam,
CumSum,
Interp3D
CumSum
}

table Op {
Expand Down
2 changes: 1 addition & 1 deletion source/backend/cpu/CPUInterp3D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class CPUInterp3DCreator : public CPUBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const MNN::Op *op, Backend *backend) const {
auto interp3D = op->main_as_Interp3D();
auto interp3D = op->main_as_Interp();
return new CPUInterp3D(backend, interp3D->resizeType(),
interp3D->widthScale(), interp3D->heightScale(), interp3D->depthScale(),
interp3D->widthOffset(), interp3D->heightOffset(), interp3D->depthOffset());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ namespace OpenCL {
Interp3DBufExecution::Interp3DBufExecution(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend) : Execution(backend) {
mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
auto runtime = mOpenCLBackend->getOpenCLRuntime();
auto interp3DParam = op->main_as_Interp3D();
auto interp3DParam = op->main_as_Interp();
mCordTransform[0] = interp3DParam->widthScale();
mCordTransform[1] = interp3DParam->widthOffset();
mCordTransform[2] = interp3DParam->heightScale();
mCordTransform[3] = interp3DParam->heightOffset();
mCordTransform[4] = interp3DParam->depthScale();
mCordTransform[5] = interp3DParam->depthOffset();
std::set<std::string> buildOptions;
if (op->main_as_Interp3D()->resizeType() == 1) {
if (op->main_as_Interp()->resizeType() == 1) {
mKernelName = "nearest3D_buf";
mKernel = runtime->buildKernel("interp_buf", mKernelName, buildOptions);
} else {
Expand Down
4 changes: 2 additions & 2 deletions source/backend/opencl/execution/image/Interp3DExecution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Interp3DExecution::Interp3DExecution(const std::vector<Tensor *> &inputs, const
: Execution(backend) {
mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
auto runtime = mOpenCLBackend->getOpenCLRuntime();
auto interp3DParam = op->main_as_Interp3D();
auto interp3DParam = op->main_as_Interp();
mCordTransform[0] = interp3DParam->widthScale();
mCordTransform[1] = interp3DParam->widthOffset();
mCordTransform[2] = interp3DParam->heightScale();
Expand All @@ -26,7 +26,7 @@ Interp3DExecution::Interp3DExecution(const std::vector<Tensor *> &inputs, const

std::set<std::string> buildOptions;
std::string kernelName = "interp3D";
if (op->main_as_Interp3D()->resizeType() == 1) {
if (op->main_as_Interp()->resizeType() == 1) {
mKernel = runtime->buildKernel("nearest", kernelName, buildOptions);
} else {
MNN_ERROR("Resize types other than nearest are not supported in Interp3D opencl! Using nearest instead\n");
Expand Down
8 changes: 4 additions & 4 deletions source/geometry/GeometryImageOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ static void _ConverterInterp(const Interp* resize, InterpInfo* dstInfo, int inW,
break;
}
}
static void _ConverterInterp3D(const Interp3D* resize, Interp3DInfo* dstInfo, int inW, int inH, int inD, int outW, int outH, int outD, bool computeScale = true) {
static void _ConverterInterp3D(const Interp* resize, Interp3DInfo* dstInfo, int inW, int inH, int inD, int outW, int outH, int outD, bool computeScale = true) {
switch (resize->ctm()) {
case CoordinateTransformationMode_NotSet:
{
Expand Down Expand Up @@ -320,7 +320,7 @@ static flatbuffers::Offset<Op> makeInterp3D(flatbuffers::FlatBufferBuilder& buil
if (nullptr != op->name()) {
temp = builder.CreateString(op->name()->str());
}
Interp3DBuilder intp3DB(builder);
InterpBuilder intp3DB(builder);
intp3DB.add_resizeType(resizeType);
intp3DB.add_widthScale(info->widthScale);
intp3DB.add_heightScale(info->heightScale);
Expand All @@ -332,7 +332,7 @@ static flatbuffers::Offset<Op> makeInterp3D(flatbuffers::FlatBufferBuilder& buil
OpBuilder opB(builder);
opB.add_type(OpType_Interp3D);
opB.add_main(offsetInterp3D);
opB.add_main_type(OpParameter_Interp3D);
opB.add_main_type(OpParameter_Interp);
if (nullptr != op->name()) {
opB.add_name(temp);
}
Expand Down Expand Up @@ -390,7 +390,7 @@ class GeometryImageOp : public GeometryComputer {
res.command.emplace_back(GeometryComputerUtils::makeCommand(builder, {newInputs[0]}, newOutputs));
} else if (OpType_Interp3D == op->type()) {
// Compute cord transform for interp
auto resize = op->main_as_Interp3D();
auto resize = op->main_as_Interp();
auto inShape = inputs[0]->shape();
auto outShape = outputs[0]->shape();
auto inW = inShape[4];
Expand Down
4 changes: 2 additions & 2 deletions source/shape/ShapeInterp3D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class Interp3DComputer : public SizeComputer {
}
}
if (1 == inputSize) {
auto interp3D = op->main_as_Interp3D();
auto interp3D = op->main_as_Interp();
// get output dims
w = interp3D->outputWidth();
h = interp3D->outputHeight();
Expand Down Expand Up @@ -106,7 +106,7 @@ class Interp3DComputer : public SizeComputer {
virtual float onComputeFlops(const MNN::Op* op, const std::vector<Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const override {
auto elementInM = (float)outputs[0]->elementSize() / 1024.0f / 1024.0f;
auto interp3D = op->main_as_Interp3D();
auto interp3D = op->main_as_Interp();
auto unit = 0;
switch (interp3D->resizeType()) {
case 1:
Expand Down
4 changes: 2 additions & 2 deletions tools/converter/source/optimizer/onnxextra/OnnxUpsample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ static EXPRP _transformResize3D(EXPRP expr) {

std::unique_ptr<OpT> mergeredResize3D(new OpT);
mergeredResize3D->type = OpType_Interp3D;
mergeredResize3D->main.type = OpParameter_Interp3D;
mergeredResize3D->main.type = OpParameter_Interp;

std::unique_ptr<Interp3DT> resize3DParam(new Interp3DT);
std::unique_ptr<InterpT> resize3DParam(new InterpT);
// 1:near 2: bilinear 3: cubic
if (resizeMode == "nearest") {
if (nearestMode == "round_prefer_floor") {
Expand Down

0 comments on commit 9e28435

Please sign in to comment.