Skip to content

Commit

Permalink
remove redundancies
Browse files Browse the repository at this point in the history
  • Loading branch information
yisol committed Apr 22, 2024
1 parent 20e97e4 commit ff17538
Showing 1 changed file with 58 additions and 188 deletions.
246 changes: 58 additions & 188 deletions inference_dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,16 @@ def pil_to_tensor(images):
return images


class DresscodeTestDataset(data.Dataset):
class VitonHDTestDataset(data.Dataset):
def __init__(
self,
dataroot_path: str,
phase: Literal["train", "test"],
order: Literal["paired", "unpaired"] = "paired",
category = "upper_body",
size: Tuple[int, int] = (512, 384),
):
super(DresscodeTestDataset, self).__init__()
self.dataroot = os.path.join(dataroot_path,category)
super(VitonHDTestDataset, self).__init__()
self.dataroot = dataroot_path
self.phase = phase
self.height = size[0]
self.width = size[1]
Expand All @@ -115,95 +114,87 @@ def __init__(
]
)
self.toTensor = transforms.ToTensor()

with open(
os.path.join(dataroot_path, phase, "vitonhd_" + phase + "_tagged.json"), "r"
) as file1:
data1 = json.load(file1)

annotation_list = [
"sleeveLength",
"neckLine",
"item",
]

self.annotation_pair = {}
for k, v in data1.items():
for elem in v:
annotation_str = ""
for template in annotation_list:
for tag in elem["tag_info"]:
if (
tag["tag_name"] == template
and tag["tag_category"] is not None
):
annotation_str += tag["tag_category"]
annotation_str += " "
self.annotation_pair[elem["file_name"]] = annotation_str

self.order = order
self.radius = 5
self.category = category
self.toTensor = transforms.ToTensor()

im_names = []
c_names = []
dataroot_names = []


if phase == "train":
filename = os.path.join(dataroot_path,category, f"{phase}_pairs.txt")
filename = os.path.join(dataroot_path, f"{phase}_pairs.txt")
else:
filename = os.path.join(dataroot_path,category, f"{phase}_pairs_{order}.txt")
filename = os.path.join(dataroot_path, f"{phase}_pairs.txt")

with open(filename, "r") as f:
for line in f.readlines():
im_name, c_name = line.strip().split()
if phase == "train":
im_name, _ = line.strip().split()
c_name = im_name
else:
if order == "paired":
im_name, _ = line.strip().split()
c_name = im_name
else:
im_name, c_name = line.strip().split()

im_names.append(im_name)
c_names.append(c_name)


file_path = os.path.join(dataroot_path,category,"dc_caption.txt")

self.annotation_pair = {}
with open(file_path, "r") as file:
for line in file:
parts = line.strip().split(" ")
self.annotation_pair[parts[0]] = ' '.join(parts[1:])

dataroot_names.append(dataroot_path)

self.im_names = im_names
self.c_names = c_names
self.dataroot_names = dataroot_names
self.clip_processor = CLIPImageProcessor()
def __getitem__(self, index):
c_name = self.c_names[index]
im_name = self.im_names[index]
if c_name in self.annotation_pair:
cloth_annotation = self.annotation_pair[c_name]
else:
cloth_annotation = self.category
cloth = Image.open(os.path.join(self.dataroot, "images", c_name))
cloth_annotation = "shirts"
cloth = Image.open(os.path.join(self.dataroot, self.phase, "cloth", c_name))

im_pil_big = Image.open(
os.path.join(self.dataroot, "images", im_name)
os.path.join(self.dataroot, self.phase, "image", im_name)
).resize((self.width,self.height))
image = self.transform(im_pil_big)




skeleton = Image.open(os.path.join(self.dataroot, 'skeletons', im_name.replace("_0", "_5")))
skeleton = skeleton.resize((self.width, self.height))
skeleton = self.transform(skeleton)

# Label Map
parse_name = im_name.replace('_0.jpg', '_4.png')
im_parse = Image.open(os.path.join(self.dataroot, 'label_maps', parse_name))
im_parse = im_parse.resize((self.width, self.height), Image.NEAREST)
parse_array = np.array(im_parse)

# Load pose points
pose_name = im_name.replace('_0.jpg', '_2.json')
with open(os.path.join(self.dataroot, 'keypoints', pose_name), 'r') as f:
pose_label = json.load(f)
pose_data = pose_label['keypoints']
pose_data = np.array(pose_data)
pose_data = pose_data.reshape((-1, 4))

point_num = pose_data.shape[0]
pose_map = torch.zeros(point_num, self.height, self.width)
r = self.radius * (self.height / 512.0)
for i in range(point_num):
one_map = Image.new('L', (self.width, self.height))
draw = ImageDraw.Draw(one_map)
point_x = np.multiply(pose_data[i, 0], self.width / 384.0)
point_y = np.multiply(pose_data[i, 1], self.height / 512.0)
if point_x > 1 and point_y > 1:
draw.rectangle((point_x - r, point_y - r, point_x + r, point_y + r), 'white', 'white')
one_map = self.toTensor(one_map)
pose_map[i] = one_map[0]

agnostic_mask = self.get_agnostic(parse_array, pose_data, self.category, (self.width,self.height))
# agnostic_mask = transforms.functional.resize(agnostic_mask, (self.height, self.width),
# interpolation=transforms.InterpolationMode.NEAREST)

mask = 1 - agnostic_mask
im_mask = image * agnostic_mask

mask = Image.open(os.path.join(self.dataroot, self.phase, "agnostic-mask", im_name.replace('.jpg','_mask.png'))).resize((self.width,self.height)).convert("L")
mask = self.toTensor(mask)
mask = mask[:1]
mask = 1-mask
im_mask = image * mask

pose_img = Image.open(
os.path.join(self.dataroot, "image-densepose", im_name)
os.path.join(self.dataroot, self.phase, "image-densepose", im_name)
)
pose_img = self.transform(pose_img) # [-1,1]

Expand All @@ -228,126 +219,6 @@ def __len__(self):



def get_agnostic(self, parse_array, pose_data, category, size):
    """Build the clothing-agnostic keep-mask for one person image.

    Composes a binary mask that marks the pixels to KEEP (body parts,
    accessories, head) while blanking out the garment region selected by
    `category`, so the try-on model can in-paint new clothing there.

    Parameters
    ----------
    parse_array : np.ndarray
        2-D human-parsing label map; integer class ids per pixel
        (indexed both literally, e.g. 14/15 = arms, and via the
        module-level `label_map` dict).
    pose_data : np.ndarray
        Pose keypoints, shape (N, 4); only columns 0-1 (x, y) are read.
        Coordinates are assumed to be in a 384x512 reference frame and
        are rescaled by height/512 below — TODO confirm against caller.
    category : str
        One of 'dresses', 'upper_body', 'lower_body'; selects which
        parse labels count as the replaceable garment.
    size : tuple
        (width, height) of the output mask.

    Returns
    -------
    Mask with a leading channel dimension (unsqueeze(0)).
        NOTE(review): `parse_mask_total` comes out of `np.logical_or`,
        which normally yields an ndarray — ndarrays have no
        `.unsqueeze`; verify this line actually runs, or whether torch
        interop keeps it a tensor.
    """
    # Foreground (any labeled pixel) — only used for parse_without_cloth below.
    parse_shape = (parse_array > 0).astype(np.float32)

    # Head region: labels 1, 2, 3, 11 (literal ids; meaning not visible here —
    # presumably hat/hair/face-adjacent classes; verify against the label map).
    parse_head = (parse_array == 1).astype(np.float32) + \
                 (parse_array == 2).astype(np.float32) + \
                 (parse_array == 3).astype(np.float32) + \
                 (parse_array == 11).astype(np.float32)

    # Regions that must never be overwritten by the generated garment.
    parser_mask_fixed = (parse_array == label_map["hair"]).astype(np.float32) + \
                        (parse_array == label_map["left_shoe"]).astype(np.float32) + \
                        (parse_array == label_map["right_shoe"]).astype(np.float32) + \
                        (parse_array == label_map["hat"]).astype(np.float32) + \
                        (parse_array == label_map["sunglasses"]).astype(np.float32) + \
                        (parse_array == label_map["scarf"]).astype(np.float32) + \
                        (parse_array == label_map["bag"]).astype(np.float32)

    # Regions the generator is free to repaint; starts as background only.
    parser_mask_changeable = (parse_array == label_map["background"]).astype(np.float32)

    # Arm labels (14 = one arm, 15 = the other; literal ids).
    arms = (parse_array == 14).astype(np.float32) + (parse_array == 15).astype(np.float32)

    # Per-category garment selection. `label_cat` is assigned but not used
    # afterwards in this function (dead local).
    if category == 'dresses':
        label_cat = 7
        parse_mask = (parse_array == 7).astype(np.float32) + \
                     (parse_array == 12).astype(np.float32) + \
                     (parse_array == 13).astype(np.float32)
        # Everything labeled and not fixed becomes changeable.
        parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))

    elif category == 'upper_body':
        label_cat = 4
        parse_mask = (parse_array == 4).astype(np.float32)

        # When swapping the top, keep the bottom garments fixed.
        parser_mask_fixed += (parse_array == label_map["skirt"]).astype(np.float32) + \
                             (parse_array == label_map["pants"]).astype(np.float32)

        parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
    elif category == 'lower_body':
        label_cat = 6
        parse_mask = (parse_array == 6).astype(np.float32) + \
                     (parse_array == 12).astype(np.float32) + \
                     (parse_array == 13).astype(np.float32)

        # When swapping the bottom, keep the top and the arms fixed.
        parser_mask_fixed += (parse_array == label_map["upper_clothes"]).astype(np.float32) + \
                             (parse_array == 14).astype(np.float32) + \
                             (parse_array == 15).astype(np.float32)
        parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
    # NOTE(review): no else branch — an unknown category leaves `parse_mask`
    # undefined and raises NameError further down.

    parse_head = torch.from_numpy(parse_head)  # [0,1]
    parse_mask = torch.from_numpy(parse_mask)  # [0,1]
    parser_mask_fixed = torch.from_numpy(parser_mask_fixed)
    parser_mask_changeable = torch.from_numpy(parser_mask_changeable)

    # dilation
    # NOTE(review): `parse_without_cloth` is computed but never used below.
    parse_without_cloth = np.logical_and(parse_shape, np.logical_not(parse_mask))
    # Round-trip back to numpy immediately after from_numpy above — the
    # tensor conversion of parse_mask is effectively undone here.
    parse_mask = parse_mask.cpu().numpy()

    width = size[0]
    height = size[1]

    # Rasterize the arms as thick polylines between shoulder/elbow/wrist
    # keypoints so the garment mask can exclude (or include) them.
    im_arms = Image.new('L', (width, height))
    arms_draw = ImageDraw.Draw(im_arms)
    if category == 'dresses' or category == 'upper_body':
        # Keypoint indices 2-7: right/left shoulder, elbow, wrist
        # (OpenPose-style ordering — TODO confirm).
        # NOTE(review): both x and y are scaled by height/512; x is NOT
        # scaled by width/384 as in __getitem__'s pose_map — verify.
        shoulder_right = tuple(np.multiply(pose_data[2, :2], height / 512.0))
        shoulder_left = tuple(np.multiply(pose_data[5, :2], height / 512.0))
        elbow_right = tuple(np.multiply(pose_data[3, :2], height / 512.0))
        elbow_left = tuple(np.multiply(pose_data[6, :2], height / 512.0))
        wrist_right = tuple(np.multiply(pose_data[4, :2], height / 512.0))
        wrist_left = tuple(np.multiply(pose_data[7, :2], height / 512.0))
        # Coordinates <= 1 mean "keypoint not detected": drop the missing
        # arm segment(s) from the drawn polyline.
        if wrist_right[0] <= 1. and wrist_right[1] <= 1.:
            if elbow_right[0] <= 1. and elbow_right[1] <= 1.:
                arms_draw.line([wrist_left, elbow_left, shoulder_left, shoulder_right], 'white', 30, 'curve')
            else:
                arms_draw.line([wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right], 'white', 30,
                               'curve')
        elif wrist_left[0] <= 1. and wrist_left[1] <= 1.:
            if elbow_left[0] <= 1. and elbow_left[1] <= 1.:
                arms_draw.line([shoulder_left, shoulder_right, elbow_right, wrist_right], 'white', 30, 'curve')
            else:
                arms_draw.line([elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right], 'white', 30,
                               'curve')
        else:
            arms_draw.line([wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right], 'white',
                           30, 'curve')

        # Thicken the drawn arms at higher resolutions so line width keeps up.
        if height > 512:
            im_arms = cv2.dilate(np.float32(im_arms), np.ones((10, 10), np.uint16), iterations=5)
        elif height > 256:
            im_arms = cv2.dilate(np.float32(im_arms), np.ones((5, 5), np.uint16), iterations=5)
        # Hands = parsed-arm pixels not covered by the drawn arm lines;
        # keep them fixed so fingers survive the in-painting.
        hands = np.logical_and(np.logical_not(im_arms), arms)
        parse_mask += im_arms
        parser_mask_fixed += hands

    # delete neck
    parse_head_2 = torch.clone(parse_head)
    if category == 'dresses' or category == 'upper_body':
        # Fit a line through the two shoulder keypoints and zero out the
        # head mask below it (minus a 20px margin at 512 scale) so the
        # neck area becomes paintable.
        points = []
        points.append(np.multiply(pose_data[2, :2], height / 512.0))
        points.append(np.multiply(pose_data[5, :2], height / 512.0))
        x_coords, y_coords = zip(*points)
        A = np.vstack([x_coords, np.ones(len(x_coords))]).T
        # Least-squares slope/intercept of the shoulder line.
        # NOTE(review): `lstsq` is a module-level import not visible here;
        # `rcond=None` suggests np.linalg.lstsq.
        m, c = lstsq(A, y_coords, rcond=None)[0]
        for i in range(parse_array.shape[1]):
            y = i * m + c
            parse_head_2[int(y - 20 * (height / 512.0)):, i] = 0

    # From here the torch tensors flow through numpy logical ops, so the
    # names silently become ndarrays again.
    parser_mask_fixed = np.logical_or(parser_mask_fixed, np.array(parse_head_2, dtype=np.uint16))
    parse_mask += np.logical_or(parse_mask, np.logical_and(np.array(parse_head, dtype=np.uint16),
                                                           np.logical_not(np.array(parse_head_2, dtype=np.uint16))))

    # Grow the garment mask so the in-painted region safely covers garment
    # boundaries; kernel scales with resolution.
    if height > 512:
        parse_mask = cv2.dilate(parse_mask, np.ones((20, 20), np.uint16), iterations=5)
    elif height > 256:
        parse_mask = cv2.dilate(parse_mask, np.ones((10, 10), np.uint16), iterations=5)
    else:
        parse_mask = cv2.dilate(parse_mask, np.ones((5, 5), np.uint16), iterations=5)
    # Final keep-mask: changeable pixels outside the dilated garment,
    # plus everything pinned as fixed.
    parse_mask = np.logical_and(parser_mask_changeable, np.logical_not(parse_mask))
    parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed)
    # NOTE(review): see docstring — `.unsqueeze(0)` is a torch method being
    # called on the result of np.logical_or; confirm this is a tensor here.
    agnostic_mask = parse_mask_total.unsqueeze(0)
    return agnostic_mask




def main():
Expand Down Expand Up @@ -388,8 +259,8 @@ def main():
torch_dtype=torch.float16,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet_dc",
"yisol/IDM-VTON-DC",
subfolder="unet",
torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
Expand Down Expand Up @@ -452,11 +323,10 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

test_dataset = DresscodeTestDataset(
test_dataset = VitonHDTestDataset(
dataroot_path=args.data_dir,
phase="test",
order="unpaired" if args.unpaired else "paired",
category = args.category,
size=(args.height, args.width),
)
test_dataloader = torch.utils.data.DataLoader(
Expand Down

0 comments on commit ff17538

Please sign in to comment.