Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
levihsu committed Mar 8, 2024
2 parents 75d89e1 + aaa8bd7 commit c5945e2
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 100 deletions.
22 changes: 0 additions & 22 deletions preprocess/humanparsing/aigc_run_parsing.py

This file was deleted.

102 changes: 26 additions & 76 deletions preprocess/humanparsing/parsing_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import torch
import numpy as np
import cv2
import networks
from collections import OrderedDict
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets.simple_extractor_dataset import SimpleFolderDataset
Expand Down Expand Up @@ -120,45 +118,7 @@ def refine_hole(parsing_result_filled, parsing_result, arm_mask):
cv2.drawContours(refine_hole_mask, contours, i, color=255, thickness=-1)
return refine_hole_mask + arm_mask



def load_atr_model():
    """Build the ATR human-parsing network and load its checkpoint.

    The network is created with ``networks.init_model('resnet101', ...)`` for
    18 classes, populated from
    ``checkpoints/humanparsing/exp-schp-201908301523-atr.pth`` (resolved
    relative to the directory two levels above this file's directory), moved
    to the current CUDA device and switched to eval mode.

    Returns:
        The initialised model, on the GPU and in eval mode.
    """
    # Class indices, in order:
    #   Background, Hat, Hair, Sunglasses, Upper-clothes, Skirt, Pants, Dress,
    #   Belt, Left-shoe, Right-shoe, Face, Left-leg, Right-leg, Left-arm,
    #   Right-arm, Bag, Scarf
    num_classes = 18
    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)

    checkpoint_path = os.path.join(Path(__file__).absolute().parents[2].absolute(),
                                   'checkpoints/humanparsing/exp-schp-201908301523-atr.pth')
    # Deserialise onto the CPU: the tensors are copied into `model` by
    # load_state_dict and only then moved to the GPU by model.cuda().
    state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']

    # Keys are presumably prefixed with `module.` (e.g. saved from a
    # DataParallel wrapper). Strip the prefix only when it is really there, so
    # an un-prefixed checkpoint does not get its key names truncated.
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()
    return model

def load_lip_model():
    """Build the LIP human-parsing network and load its checkpoint.

    The network is created with ``networks.init_model('resnet101', ...)`` for
    20 classes, populated from
    ``checkpoints/humanparsing/exp-schp-201908261155-lip.pth`` (resolved
    relative to the directory two levels above this file's directory), moved
    to the current CUDA device and switched to eval mode.

    Returns:
        The initialised model, on the GPU and in eval mode.
    """
    # Class indices, in order:
    #   Background, Hat, Hair, Glove, Sunglasses, Upper-clothes, Dress, Coat,
    #   Socks, Pants, Jumpsuits, Scarf, Skirt, Face, Left-arm, Right-arm,
    #   Left-leg, Right-leg, Left-shoe, Right-shoe
    num_classes = 20
    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)

    checkpoint_path = os.path.join(Path(__file__).absolute().parents[2].absolute(),
                                   'checkpoints/humanparsing/exp-schp-201908261155-lip.pth')
    # Deserialise onto the CPU: the tensors are copied into `model` by
    # load_state_dict and only then moved to the GPU by model.cuda().
    state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']

    # Keys are presumably prefixed with `module.` (e.g. saved from a
    # DataParallel wrapper). Strip the prefix only when it is really there, so
    # an un-prefixed checkpoint does not get its key names truncated.
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()
    return model

def inference(model, lip_model, input_dir):
# load datasetloader
def onnx_inference(session, lip_session, input_dir):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
Expand All @@ -172,19 +132,14 @@ def inference(model, lip_model, input_dir):
s = meta['scale'].numpy()[0]
w = meta['width'].numpy()[0]
h = meta['height'].numpy()[0]

output = model(image.cuda())

output = session.run(None, {"input.1": image.numpy().astype(np.float32)})
upsample = torch.nn.Upsample(size=[512, 512], mode='bilinear', align_corners=True)
upsample_output = upsample(output[0][-1][0].unsqueeze(0))
upsample_output = upsample(torch.from_numpy(output[1][0]).unsqueeze(0))
upsample_output = upsample_output.squeeze()
upsample_output = upsample_output.permute(1, 2, 0) # CHW -> HWC
logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=[512, 512])

# delete irregular classes, e.g. pants/ skirts over clothes
parsing_result = np.argmax(logits_result, axis=2)
parsing_result = np.pad(parsing_result, pad_width=1, mode='constant', constant_values=0)

# try holefilling the clothes part
arm_mask = (parsing_result == 14).astype(np.float32) \
+ (parsing_result == 15).astype(np.float32)
Expand All @@ -193,46 +148,41 @@ def inference(model, lip_model, input_dir):
dst = hole_fill(img.astype(np.uint8))
parsing_result_filled = dst / 255 * 4
parsing_result_woarm = np.where(parsing_result_filled == 4, parsing_result_filled, parsing_result)

# add back arm and refined hole between arm and cloth
refine_hole_mask = refine_hole(parsing_result_filled.astype(np.uint8), parsing_result.astype(np.uint8),
arm_mask.astype(np.uint8))
parsing_result = np.where(refine_hole_mask, parsing_result, parsing_result_woarm)
# remove padding
parsing_result = parsing_result[1:-1, 1:-1]


dataset_lip = SimpleFolderDataset(root=input_dir, input_size=[473, 473], transform=transform)
dataloader_lip = DataLoader(dataset_lip)
with torch.no_grad():
for _, batch in enumerate(tqdm(dataloader_lip)):
image, meta = batch
c = meta['center'].numpy()[0]
s = meta['scale'].numpy()[0]
w = meta['width'].numpy()[0]
h = meta['height'].numpy()[0]

output_lip = lip_model(image.cuda())

upsample = torch.nn.Upsample(size=[473, 473], mode='bilinear', align_corners=True)
upsample_output_lip = upsample(output_lip[0][-1][0].unsqueeze(0))
upsample_output_lip = upsample_output_lip.squeeze()
upsample_output_lip = upsample_output_lip.permute(1, 2, 0) # CHW -> HWC
logits_result_lip = transform_logits(upsample_output_lip.data.cpu().numpy(), c, s, w, h, input_size=[473, 473])
parsing_result_lip = np.argmax(logits_result_lip, axis=2)
dataset_lip = SimpleFolderDataset(root=input_dir, input_size=[473, 473], transform=transform)
dataloader_lip = DataLoader(dataset_lip)
with torch.no_grad():
for _, batch in enumerate(tqdm(dataloader_lip)):
image, meta = batch
c = meta['center'].numpy()[0]
s = meta['scale'].numpy()[0]
w = meta['width'].numpy()[0]
h = meta['height'].numpy()[0]

output_lip = lip_session.run(None, {"input.1": image.numpy().astype(np.float32)})
upsample = torch.nn.Upsample(size=[473, 473], mode='bilinear', align_corners=True)
upsample_output_lip = upsample(torch.from_numpy(output_lip[1][0]).unsqueeze(0))
upsample_output_lip = upsample_output_lip.squeeze()
upsample_output_lip = upsample_output_lip.permute(1, 2, 0) # CHW -> HWC
logits_result_lip = transform_logits(upsample_output_lip.data.cpu().numpy(), c, s, w, h,
input_size=[473, 473])
parsing_result_lip = np.argmax(logits_result_lip, axis=2)
# add neck parsing result
neck_mask = np.logical_and(np.logical_not((parsing_result_lip == 13).astype(np.float32)), (parsing_result == 11).astype(np.float32))
# filter out small part of neck
neck_mask = refine_mask(neck_mask)
# Image.fromarray(((neck_mask > 0) * 127.5 + 127.5).astype(np.uint8)).save("neck_mask.jpg")
neck_mask = np.logical_and(np.logical_not((parsing_result_lip == 13).astype(np.float32)),
(parsing_result == 11).astype(np.float32))
parsing_result = np.where(neck_mask, 18, parsing_result)
palette = get_palette(19)
parsing_result_path = os.path.join('parsed.png')
output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
output_img.putpalette(palette)
# output_img.save(parsing_result_path)
face_mask = torch.from_numpy((parsing_result == 11).astype(np.float32))

face_mask = torch.from_numpy((parsing_result == 11).astype(np.float32))

return output_img, face_mask



29 changes: 29 additions & 0 deletions preprocess/humanparsing/run_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pdb
from pathlib import Path
import sys
import os
import onnxruntime as ort
PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
sys.path.insert(0, str(PROJECT_ROOT))
from parsing_api import onnx_inference
import torch


class Parsing:
    """Callable wrapper around the two ONNX human-parsing sessions.

    Two ``onnxruntime`` inference sessions are created at construction time,
    one for ``parsing_atr.onnx`` and one for ``parsing_lip.onnx``; both files
    are looked up under ``checkpoints/humanparsing`` relative to the directory
    two levels above this file's directory. Both sessions are created with
    ``CPUExecutionProvider`` only.
    """

    def __init__(self, gpu_id: int):
        """Create the ATR and LIP inference sessions.

        Args:
            gpu_id: Index of the CUDA device to select for torch when CUDA is
                available; it is also recorded in the session options.
        """
        self.gpu_id = gpu_id
        self._select_device()

        session_options = ort.SessionOptions()
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        # NOTE(review): whether onnxruntime acts on a 'gpu_id' config entry is
        # not visible from this file; kept as-is for compatibility.
        session_options.add_session_config_entry('gpu_id', str(gpu_id))

        project_root = Path(__file__).absolute().parents[2].absolute()
        self.session = ort.InferenceSession(
            os.path.join(project_root, 'checkpoints/humanparsing/parsing_atr.onnx'),
            sess_options=session_options, providers=['CPUExecutionProvider'])
        self.lip_session = ort.InferenceSession(
            os.path.join(project_root, 'checkpoints/humanparsing/parsing_lip.onnx'),
            sess_options=session_options, providers=['CPUExecutionProvider'])

    def _select_device(self):
        """Make ``self.gpu_id`` the current CUDA device when CUDA is usable."""
        # The sessions run on the CPU provider, so a CUDA device is not needed
        # for parsing itself; skipping the call on hosts without CUDA avoids
        # failing before any inference is attempted.
        if torch.cuda.is_available():
            torch.cuda.set_device(self.gpu_id)

    def __call__(self, input_image):
        """Run human parsing and return ``(parsed_image, face_mask)``.

        Args:
            input_image: Forwarded unchanged to ``onnx_inference`` as its
                third argument.

        Returns:
            The two values returned by ``onnx_inference``.
        """
        self._select_device()
        parsed_image, face_mask = onnx_inference(self.session, self.lip_session, input_image)
        return parsed_image, face_mask
2 changes: 1 addition & 1 deletion run/gradio_ootd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import time
from preprocess.openpose.run_openpose import OpenPose
from preprocess.humanparsing.aigc_run_parsing import Parsing
from preprocess.humanparsing.run_parsing import Parsing
from ootd.inference_ootd_hd import OOTDiffusionHD
from ootd.inference_ootd_dc import OOTDiffusionDC

Expand Down
2 changes: 1 addition & 1 deletion run/run_ootd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
sys.path.insert(0, str(PROJECT_ROOT))

from preprocess.openpose.run_openpose import OpenPose
from preprocess.humanparsing.aigc_run_parsing import Parsing
from preprocess.humanparsing.run_parsing import Parsing
from ootd.inference_ootd_hd import OOTDiffusionHD
from ootd.inference_ootd_dc import OOTDiffusionDC

Expand Down

0 comments on commit c5945e2

Please sign in to comment.