diff --git a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp index 7bdf46daa87..eb0956c4f8e 100644 --- a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp +++ b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp @@ -30,6 +30,24 @@ bool CircleTransposeConvGraphBuilder::validate(const ValidateArgs &args) const if (args.op.inputs.size() != 3) return false; + const auto &inputs = args.op.inputs; + const auto &tensors = args.reader.tensors(); + const auto &filter_tensor = tensors.at(inputs[1]); + const auto &filter_shape = filter_tensor.get()->shape; + const auto &ifm_tensor = tensors.at(inputs[2]); + const auto &ifm_shape = ifm_tensor.get()->shape; + + // ifm and filters must be 4-D tensor + if (ifm_shape.size() != 4) + return false; + if (filter_shape.size() != 4) + return false; + + // input shape : [batch, height, width, in_channels] + // filters shape : [output_channels, height, weight, in_channels] + if (ifm_tensor.get()->shape.at(3) != filter_tensor.get()->shape.at(3)) + return false; + return true; }