Skip to content

Commit 799c17a

Browse files
committed
hotfix : background bias
1 parent 85fa6e7 commit 799c17a

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

lora_diffusion/dataset.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,16 @@ def _shuffle(lis):
8686
return random.sample(lis, len(lis))
8787

8888

89-
def _get_cutout_holes(height, width, min_holes=8, max_holes=32, min_height=16, max_height=128, min_width=16, max_width=128):
89+
def _get_cutout_holes(
90+
height,
91+
width,
92+
min_holes=8,
93+
max_holes=32,
94+
min_height=16,
95+
max_height=128,
96+
min_width=16,
97+
max_width=128,
98+
):
9099
holes = []
91100
for _n in range(random.randint(min_holes, max_holes)):
92101
hole_height = random.randint(min_height, max_height)
@@ -103,12 +112,13 @@ def _generate_random_mask(image):
103112
mask = zeros_like(image[:1])
104113
holes = _get_cutout_holes(mask.shape[1], mask.shape[2])
105114
for (x1, y1, x2, y2) in holes:
106-
mask[:, y1:y2, x1:x2] = 1.
115+
mask[:, y1:y2, x1:x2] = 1.0
107116
if random.uniform(0, 1) < 0.25:
108-
mask.fill_(1.)
117+
mask.fill_(1.0)
109118
masked_image = image * (mask < 0.5)
110119
return mask, masked_image
111120

121+
112122
class PivotalTuningDatasetCapation(Dataset):
113123
"""
114124
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
@@ -274,7 +284,10 @@ def __getitem__(self, index):
274284
example["instance_images"] = self.image_transforms(instance_image)
275285

276286
if self.train_inpainting:
277-
example["instance_masks"], example["instance_masked_images"] = _generate_random_mask(example["instance_images"])
287+
(
288+
example["instance_masks"],
289+
example["instance_masked_images"],
290+
) = _generate_random_mask(example["instance_images"])
278291

279292
if self.use_template:
280293
assert self.token_map is not None
@@ -296,7 +309,7 @@ def __getitem__(self, index):
296309
Image.open(self.mask_path[index % self.num_instance_images])
297310
)
298311
* 0.5
299-
+ 0.5
312+
+ 1.0
300313
)
301314

302315
if self.h_flip and random.random() > 0.5:
@@ -321,7 +334,10 @@ def __getitem__(self, index):
321334
class_image = class_image.convert("RGB")
322335
example["class_images"] = self.image_transforms(class_image)
323336
if self.train_inpainting:
324-
example["class_masks"], example["class_masked_images"] = _generate_random_mask(example["class_images"])
337+
(
338+
example["class_masks"],
339+
example["class_masked_images"],
340+
) = _generate_random_mask(example["class_images"])
325341
example["class_prompt_ids"] = self.tokenizer(
326342
self.class_prompt,
327343
padding="do_not_pad",

0 commit comments

Comments
 (0)