Skip to content

Commit

Permalink
Add MODNet(Matting) model support (PaddlePaddle#76)
Browse files Browse the repository at this point in the history
* Add MODNet(Matting) model support

* from pre_commit.main import main
  • Loading branch information
DefTruth authored Aug 8, 2022
1 parent 9c67653 commit 0e45a7e
Show file tree
Hide file tree
Showing 24 changed files with 990 additions and 14 deletions.
6 changes: 4 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ cmake-build-debug
cmake-build-release
.vscode
FastDeploy.cmake
fastdeploy/core/config.h
build-debug.sh
*dist
fastdeploy.egg-info
.setuptools-cmake-build
fastdeploy/version.py
fastdeploy/core/config.h
fastdeploy/c_lib_wrap.py
fastdeploy/LICENSE*
fastdeploy/ThirdPartyNotices*
*.so*
fastdeploy/libs/third_libs
csrcs/fastdeploy/core/config.h
csrcs/fastdeploy/core/config.h
csrcs/fastdeploy/pybind/main.cc
5 changes: 4 additions & 1 deletion csrcs/fastdeploy/backends/paddle/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor) {
void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
FDTensor* fd_tensor) {
auto fd_dtype = PaddleDataTypeToFD(tensor->type());
fd_tensor->Allocate(tensor->shape(), fd_dtype, tensor->name());
std::vector<int64_t> shape;
auto tmp_shape = tensor->shape();
shape.assign(tmp_shape.begin(), tmp_shape.end());
fd_tensor->Allocate(shape, fd_dtype, tensor->name());
if (fd_tensor->dtype == FDDataType::FP32) {
tensor->CopyToCpu(static_cast<float*>(fd_tensor->MutableData()));
return;
Expand Down
1 change: 1 addition & 0 deletions csrcs/fastdeploy/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "fastdeploy/vision/wongkinyiu/scaledyolov4.h"
#include "fastdeploy/vision/wongkinyiu/yolor.h"
#include "fastdeploy/vision/wongkinyiu/yolov7.h"
#include "fastdeploy/vision/zhkkke/modnet.h"
#endif

#include "fastdeploy/vision/visualize/visualize.h"
77 changes: 77 additions & 0 deletions csrcs/fastdeploy/vision/common/result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,82 @@ std::string FaceRecognitionResult::Str() {
return out;
}

MattingResult::MattingResult(const MattingResult& res) {
alpha.assign(res.alpha.begin(), res.alpha.end());
foreground.assign(res.foreground.begin(), res.foreground.end());
shape.assign(res.shape.begin(), res.shape.end());
contain_foreground = res.contain_foreground;
}

void MattingResult::Clear() {
std::vector<float>().swap(alpha);
std::vector<float>().swap(foreground);
std::vector<int64_t>().swap(shape);
contain_foreground = false;
}

void MattingResult::Reserve(int size) {
alpha.reserve(size);
if (contain_foreground) {
FDASSERT((shape.size() == 3),
"Please initial shape (h,w,c) before call Reserve.");
int c = static_cast<int>(shape[3]);
foreground.reserve(size * c);
}
}

void MattingResult::Resize(int size) {
alpha.resize(size);
if (contain_foreground) {
FDASSERT((shape.size() == 3),
"Please initial shape (h,w,c) before call Resize.");
int c = static_cast<int>(shape[3]);
foreground.resize(size * c);
}
}

std::string MattingResult::Str() {
std::string out;
out = "MattingResult[";
if (contain_foreground) {
out += "Foreground(true)";
} else {
out += "Foreground(false)";
}
out += ", Alpha(";
size_t numel = alpha.size();
if (numel <= 0) {
return out + "[Empty Result]";
}
// max, min, mean
float min_val = alpha.at(0);
float max_val = alpha.at(0);
float total_val = alpha.at(0);
for (size_t i = 1; i < numel; ++i) {
float val = alpha.at(i);
total_val += val;
if (val < min_val) {
min_val = val;
}
if (val > max_val) {
max_val = val;
}
}
float mean_val = total_val / static_cast<float>(numel);
// shape
std::string shape_str = "Shape(";
for (size_t i = 0; i < shape.size(); ++i) {
if ((i + 1) != shape.size()) {
shape_str += std::to_string(shape[i]) + ",";
} else {
shape_str += std::to_string(shape[i]) + ")";
}
}
out = out + "Numel(" + std::to_string(numel) + "), " + shape_str + ", Min(" +
std::to_string(min_val) + "), " + "Max(" + std::to_string(max_val) +
"), " + "Mean(" + std::to_string(mean_val) + "))]\n";
return out;
}

} // namespace vision
} // namespace fastdeploy
27 changes: 26 additions & 1 deletion csrcs/fastdeploy/vision/common/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ enum FASTDEPLOY_DECL ResultType {
DETECTION,
SEGMENTATION,
FACE_DETECTION,
FACE_RECOGNITION
FACE_RECOGNITION,
MATTING
};

struct FASTDEPLOY_DECL BaseResult {
Expand Down Expand Up @@ -119,5 +120,29 @@ struct FASTDEPLOY_DECL FaceRecognitionResult : public BaseResult {
std::string Str();
};

struct FASTDEPLOY_DECL MattingResult : public BaseResult {
// alpha matte and fgr (predicted foreground: HWC/BGR float32)
std::vector<float> alpha; // h x w
std::vector<float> foreground; // h x w x c (c=3 default)
// height, width, channel for foreground and alpha
// must be (h,w,c) and setup before Reserve and Resize
// c is only for foreground if contain_foreground is true.
std::vector<int64_t> shape;
bool contain_foreground = false;

ResultType type = ResultType::MATTING;

MattingResult() {}
MattingResult(const MattingResult& res);

void Clear();

void Reserve(int size);

void Resize(int size);

std::string Str();
};

} // namespace vision
} // namespace fastdeploy
11 changes: 11 additions & 0 deletions csrcs/fastdeploy/vision/vision_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void BindLinzaer(pybind11::module& m);
void BindBiubug6(pybind11::module& m);
void BindPpogg(pybind11::module& m);
void BindDeepInsight(pybind11::module& m);
void BindZHKKKe(pybind11::module& m);
#ifdef ENABLE_VISION_VISUALIZE
void BindVisualize(pybind11::module& m);
#endif
Expand Down Expand Up @@ -74,6 +75,15 @@ void BindVision(pybind11::module& m) {
.def("__repr__", &vision::FaceRecognitionResult::Str)
.def("__str__", &vision::FaceRecognitionResult::Str);

pybind11::class_<vision::MattingResult>(m, "MattingResult")
.def(pybind11::init())
.def_readwrite("alpha", &vision::MattingResult::alpha)
.def_readwrite("foreground", &vision::MattingResult::foreground)
.def_readwrite("shape", &vision::MattingResult::shape)
.def_readwrite("contain_foreground", &vision::MattingResult::shape)
.def("__repr__", &vision::MattingResult::Str)
.def("__str__", &vision::MattingResult::Str);

BindPPCls(m);
BindPPDet(m);
BindPPSeg(m);
Expand All @@ -87,6 +97,7 @@ void BindVision(pybind11::module& m) {
BindBiubug6(m);
BindPpogg(m);
BindDeepInsight(m);
BindZHKKKe(m);
#ifdef ENABLE_VISION_VISUALIZE
BindVisualize(m);
#endif
Expand Down
123 changes: 123 additions & 0 deletions csrcs/fastdeploy/vision/visualize/matting_alpha.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef ENABLE_VISION_VISUALIZE

#include "fastdeploy/vision/visualize/visualize.h"
#include "opencv2/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"

namespace fastdeploy {
namespace vision {

static void RemoveSmallConnectedArea(cv::Mat* alpha_pred,
float threshold = 0.05f) {
// 移除小的联通区域和噪点 开闭合形态学处理
// 假设输入的是透明度alpha, 值域(0.,1.)
cv::Mat gray, binary;
(*alpha_pred).convertTo(gray, CV_8UC1, 255.f);
// 255 * 0.05 ~ 13
unsigned int binary_threshold = static_cast<unsigned int>(255.f * threshold);
cv::threshold(gray, binary, binary_threshold, 255, cv::THRESH_BINARY);
// morphologyEx with OPEN operation to remove noise first.
auto kernel = cv::getStructuringElement(cv::MORPH_ELLIPSE, cv::Size(3, 3),
cv::Point(-1, -1));
cv::morphologyEx(binary, binary, cv::MORPH_OPEN, kernel);
// Computationally connected domain
cv::Mat labels = cv::Mat::zeros((*alpha_pred).size(), CV_32S);
cv::Mat stats, centroids;
int num_labels =
cv::connectedComponentsWithStats(binary, labels, stats, centroids, 8, 4);
if (num_labels <= 1) {
// no noise, skip.
return;
}
// find max connected area, 0 is background
int max_connected_id = 1; // 1,2,...
int max_connected_area = stats.at<int>(max_connected_id, cv::CC_STAT_AREA);
for (int i = 1; i < num_labels; ++i) {
int tmp_connected_area = stats.at<int>(i, cv::CC_STAT_AREA);
if (tmp_connected_area > max_connected_area) {
max_connected_area = tmp_connected_area;
max_connected_id = i;
}
}
const int h = (*alpha_pred).rows;
const int w = (*alpha_pred).cols;
// remove small connected area.
for (int i = 0; i < h; ++i) {
int* label_row_ptr = labels.ptr<int>(i);
float* alpha_row_ptr = (*alpha_pred).ptr<float>(i);
for (int j = 0; j < w; ++j) {
if (label_row_ptr[j] != max_connected_id) alpha_row_ptr[j] = 0.f;
}
}
}

void Visualize::VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
cv::Mat* vis_img,
bool remove_small_connected_area) {
// 只可视化alpha,fgr(前景)本身就是一张图 不需要可视化
FDASSERT((!im.empty()), "im can't be empty!");
FDASSERT((im.channels() == 3), "Only support 3 channels mat!");
int out_h = static_cast<int>(result.shape[0]);
int out_w = static_cast<int>(result.shape[1]);
int height = im.rows;
int width = im.cols;
// alpha to cv::Mat && 避免resize等操作修改外部数据
std::vector<float> alpha_copy;
alpha_copy.assign(result.alpha.begin(), result.alpha.end());
float* alpha_ptr = static_cast<float*>(alpha_copy.data());
cv::Mat alpha(out_h, out_w, CV_32FC1, alpha_ptr);
if (remove_small_connected_area) {
RemoveSmallConnectedArea(&alpha, 0.05f);
}
if ((out_h != height) || (out_w != width)) {
cv::resize(alpha, alpha, cv::Size(width, height));
}

int vis_h = (*vis_img).rows;
int vis_w = (*vis_img).cols;

if ((vis_h != height) || (vis_w != width)) {
// faster than resize
(*vis_img) = cv::Mat::zeros(height, width, CV_8UC3);
}
if ((*vis_img).type() != CV_8UC3) {
(*vis_img).convertTo((*vis_img), CV_8UC3);
}

uchar* vis_data = static_cast<uchar*>(vis_img->data);
uchar* im_data = static_cast<uchar*>(im.data);
float* alpha_data = reinterpret_cast<float*>(alpha.data);

for (size_t i = 0; i < height; ++i) {
for (size_t j = 0; j < width; ++j) {
float alpha_val = alpha_data[i * width + j];
vis_data[i * width * 3 + j * 3 + 0] = cv::saturate_cast<uchar>(
static_cast<float>(im_data[i * width * 3 + j * 3 + 0]) * alpha_val +
(1.f - alpha_val) * 153.f);
vis_data[i * width * 3 + j * 3 + 1] = cv::saturate_cast<uchar>(
static_cast<float>(im_data[i * width * 3 + j * 3 + 1]) * alpha_val +
(1.f - alpha_val) * 255.f);
vis_data[i * width * 3 + j * 3 + 2] = cv::saturate_cast<uchar>(
static_cast<float>(im_data[i * width * 3 + j * 3 + 2]) * alpha_val +
(1.f - alpha_val) * 120.f);
}
}
}

} // namespace vision
} // namespace fastdeploy
#endif
3 changes: 3 additions & 0 deletions csrcs/fastdeploy/vision/visualize/visualize.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class FASTDEPLOY_DECL Visualize {
static void VisSegmentation(const cv::Mat& im,
const SegmentationResult& result,
cv::Mat* vis_img, const int& num_classes = 1000);
static void VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
cv::Mat* vis_img,
bool remove_small_connected_area = false);
};

} // namespace vision
Expand Down
26 changes: 18 additions & 8 deletions csrcs/fastdeploy/vision/visualize/visualize_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,23 @@ void BindVisualize(pybind11::module& m) {
vision::Visualize::VisFaceDetection(&im, result, line_size,
font_size);
})
.def_static("vis_segmentation", [](pybind11::array& im_data,
vision::SegmentationResult& result,
pybind11::array& vis_im_data,
const int& num_classes) {
cv::Mat im = PyArrayToCvMat(im_data);
cv::Mat vis_im = PyArrayToCvMat(vis_im_data);
vision::Visualize::VisSegmentation(im, result, &vis_im, num_classes);
});
.def_static(
"vis_segmentation",
[](pybind11::array& im_data, vision::SegmentationResult& result,
pybind11::array& vis_im_data, const int& num_classes) {
cv::Mat im = PyArrayToCvMat(im_data);
cv::Mat vis_im = PyArrayToCvMat(vis_im_data);
vision::Visualize::VisSegmentation(im, result, &vis_im,
num_classes);
})
.def_static(
"vis_matting_alpha",
[](pybind11::array& im_data, vision::MattingResult& result,
pybind11::array& vis_im_data, bool remove_small_connected_area) {
cv::Mat im = PyArrayToCvMat(im_data);
cv::Mat vis_im = PyArrayToCvMat(vis_im_data);
vision::Visualize::VisMattingAlpha(im, result, &vis_im,
remove_small_connected_area);
});
}
} // namespace fastdeploy
Loading

0 comments on commit 0e45a7e

Please sign in to comment.