Skip to content

Commit

Permalink
Merge branch 'develop' into acclerate
Browse files Browse the repository at this point in the history
# Conflicts:
#	train_no_accelerator.py
  • Loading branch information
idonahum1 committed Apr 26, 2024
2 parents 6e34472 + f554635 commit accaabb
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 200 deletions.
13 changes: 12 additions & 1 deletion datasets/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.utils.data import Dataset
import torch
import numpy as np
from utils import arcface_utils

imagenet_templates_small = [
"a photo of a {}",
Expand Down Expand Up @@ -57,6 +58,7 @@ def __init__(
self,
data_root,
tokenizer,
face_embedding_func=None,
img_subfolder='images',
size=512,
interpolation="bicubic",
Expand All @@ -67,6 +69,7 @@ def __init__(
self.tokenizer = tokenizer
self.size = size
self.placeholder_token = placeholder_token
self.face_embedding_func = face_embedding_func
img_dir = os.path.join(data_root, img_subfolder)
self.image_paths = []
self.image_paths += [os.path.join(img_dir, file_path) for file_path in os.listdir(img_dir) if is_image(file_path)]
Expand Down Expand Up @@ -118,6 +121,9 @@ def _prepare_image(self, example: dict, idx: int):
pixel_values = self._preprocess(raw_image)
example["pixel_values"] = pixel_values
example["pixel_values_clip"] = pixel_values_clip
face_analysis = self.face_embedding_func(raw_image)
face_analysis = arcface_utils.get_largest_bbox_face_analysis(face_analysis)
example["face_embedding"] = face_analysis['embedding']
return example

def _find_placeholder_index(self, text: str):
Expand All @@ -139,12 +145,13 @@ def __init__(
self,
data_root,
tokenizer,
face_embedding_func=None,
img_subfolder='images',
mask_subfolder='masks',
size=512,
interpolation="bicubic",
placeholder_token="*",
template="a photo of a {}",
template="a photo of a {}"
):
super().__init__(data_root=data_root, tokenizer=tokenizer,img_subfolder=img_subfolder,
size=size, interpolation=interpolation,
Expand All @@ -154,6 +161,7 @@ def __init__(
mask_dir = os.path.join(data_root, mask_subfolder)
self.masks_paths += [os.path.join(mask_dir, file_path) for file_path in os.listdir(mask_dir) if is_image(file_path)]
self.masks_paths = sorted(self.masks_paths)
self.face_embedding_func = face_embedding_func

def _prepare_image(self, example: dict, idx: int):
image_path = self.image_paths[idx]
Expand All @@ -177,6 +185,9 @@ def _prepare_image(self, example: dict, idx: int):
pixel_values = self._preprocess(raw_image)
example["pixel_values"] = pixel_values
example["pixel_values_clip"] = pixel_values_clip
face_analysis = self.face_embedding_func(reshaped_img)
face_analysis = arcface_utils.get_largest_bbox_face_analysis(face_analysis)
example["face_embedding"] = face_analysis['embedding']
return example


Expand Down
6 changes: 3 additions & 3 deletions datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np


NUM_OF_IMAGES_IN_CELEBAHQ = 30000
NUM_OF_IMAGES_IN_CELEBAHQ = 200
MASKS_LABEL_LIST_CELEBAHQ = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']


Expand All @@ -30,6 +30,6 @@ def create_celebahq_masks(masks_path, save_path):


if __name__ == "__main__":
folder_base = '/Users/ido.nahum/Downloads/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
folder_save = '/Users/ido.nahum/dev/photoVerse/CelebaHQMask/masks'
folder_base = r'c:\Users\HAIMZIS\Downloads\CelebAMask-HQ\CelebAMask-HQ\CelebAMask-HQ-mask-anno'
folder_save = 'data\celeb\masks'
create_celebahq_masks(folder_base, folder_save)
Loading

0 comments on commit accaabb

Please sign in to comment.