forked from flong998/badouai-ai-special-2024
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1579 from lzc2github/main
伦泽成第15周作业Create maskrcnn
- Loading branch information
Showing
1 changed file
with
65 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import torch | ||
import torchvision | ||
from torchvision.models.detection import maskrcnn_resnet50_fpn | ||
from torchvision.transforms import functional as F | ||
from PIL import Image,ImageDraw | ||
import numpy as np | ||
import cv2 | ||
|
||
# 加载预训练模型 | ||
model = maskrcnn_resnet50_fpn(pretrained=True) | ||
model.eval() | ||
|
||
# 如果你的模型是在GPU上训练的,确保模型也在GPU上进行推理 | ||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') | ||
model = model.to(device) | ||
|
||
# 加载图像并进行预处理 | ||
def preprocess_image(image): | ||
transform = torchvision.transforms.Compose([ | ||
torchvision.transforms.ToTensor(), | ||
]) | ||
return transform(image).unsqueeze(0) # 添加batch维度 | ||
|
||
# 进行推理 | ||
def infer(image_path): | ||
image = Image.open(image_path).convert("RGB") | ||
image_tensor = preprocess_image(image) | ||
image_tensor = image_tensor.to(device) | ||
|
||
with torch.no_grad(): | ||
prediction = model(image_tensor) | ||
|
||
return prediction | ||
|
||
# 显示结果 | ||
def show_result(image, predictions): | ||
image = cv2.imread(image_path) | ||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | ||
color_mapping = { | ||
1: (255, 0, 0), # 人用蓝色表示 | ||
2: (0, 255, 0), # 自行车用绿色表示 | ||
3: (0, 0, 255) # 汽车用红色表示 | ||
} | ||
for pred in predictions: | ||
masks = pred['masks'].cpu().numpy() | ||
labels = pred['labels'].cpu().numpy() | ||
scores = pred['scores'].cpu().numpy() | ||
for mask, label, score in zip(masks, labels, scores): | ||
if score > 0.5: | ||
mask = mask[0] | ||
mask = (mask > 0.5).astype(np.uint8) | ||
color = color_mapping.get(label.item(), (255, 255, 255)) | ||
contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) | ||
cv2.drawContours(image, contours, -1, color, 2) | ||
image=cv2.resize(image,(700,700)) | ||
cv2.imshow('Result', image) | ||
cv2.waitKey(0) | ||
cv2.destroyAllWindows() | ||
|
||
# 使用示例 | ||
image_path = 'street.jpg' # 替换为你的图像路径 | ||
prediction = infer(image_path) | ||
image = Image.open(image_path) | ||
image = show_result(image, prediction) | ||
|