diff --git a/tflite/kernels/internal/reference/integer_ops/fully_connected.h b/tflite/kernels/internal/reference/integer_ops/fully_connected.h index cbcd54b2..dc700c2a 100644 --- a/tflite/kernels/internal/reference/integer_ops/fully_connected.h +++ b/tflite/kernels/internal/reference/integer_ops/fully_connected.h @@ -42,12 +42,13 @@ void FullyConnectedPerChannel( const int32_t output_activation_min = params.quantized_activation_min; const int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); - TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); const int filter_dim_count = filter_shape.DimensionsCount(); - const int batches = output_shape.Dims(0); - const int output_depth = output_shape.Dims(1); + const int output_dim_count = output_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = output_shape.Dims(output_dim_count - 1); TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2)); const int accum_depth = filter_shape.Dims(filter_dim_count - 1); for (int b = 0; b < batches; ++b) {