@@ -86,7 +86,16 @@ def _shuffle(lis):
86
86
return random .sample (lis , len (lis ))
87
87
88
88
89
- def _get_cutout_holes (height , width , min_holes = 8 , max_holes = 32 , min_height = 16 , max_height = 128 , min_width = 16 , max_width = 128 ):
89
+ def _get_cutout_holes (
90
+ height ,
91
+ width ,
92
+ min_holes = 8 ,
93
+ max_holes = 32 ,
94
+ min_height = 16 ,
95
+ max_height = 128 ,
96
+ min_width = 16 ,
97
+ max_width = 128 ,
98
+ ):
90
99
holes = []
91
100
for _n in range (random .randint (min_holes , max_holes )):
92
101
hole_height = random .randint (min_height , max_height )
@@ -103,12 +112,13 @@ def _generate_random_mask(image):
103
112
mask = zeros_like (image [:1 ])
104
113
holes = _get_cutout_holes (mask .shape [1 ], mask .shape [2 ])
105
114
for (x1 , y1 , x2 , y2 ) in holes :
106
- mask [:, y1 :y2 , x1 :x2 ] = 1.
115
+ mask [:, y1 :y2 , x1 :x2 ] = 1.0
107
116
if random .uniform (0 , 1 ) < 0.25 :
108
- mask .fill_ (1. )
117
+ mask .fill_ (1.0 )
109
118
masked_image = image * (mask < 0.5 )
110
119
return mask , masked_image
111
120
121
+
112
122
class PivotalTuningDatasetCapation (Dataset ):
113
123
"""
114
124
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
@@ -274,7 +284,10 @@ def __getitem__(self, index):
274
284
example ["instance_images" ] = self .image_transforms (instance_image )
275
285
276
286
if self .train_inpainting :
277
- example ["instance_masks" ], example ["instance_masked_images" ] = _generate_random_mask (example ["instance_images" ])
287
+ (
288
+ example ["instance_masks" ],
289
+ example ["instance_masked_images" ],
290
+ ) = _generate_random_mask (example ["instance_images" ])
278
291
279
292
if self .use_template :
280
293
assert self .token_map is not None
@@ -296,7 +309,7 @@ def __getitem__(self, index):
296
309
Image .open (self .mask_path [index % self .num_instance_images ])
297
310
)
298
311
* 0.5
299
- + 0.5
312
+ + 1.0
300
313
)
301
314
302
315
if self .h_flip and random .random () > 0.5 :
@@ -321,7 +334,10 @@ def __getitem__(self, index):
321
334
class_image = class_image .convert ("RGB" )
322
335
example ["class_images" ] = self .image_transforms (class_image )
323
336
if self .train_inpainting :
324
- example ["class_masks" ], example ["class_masked_images" ] = _generate_random_mask (example ["class_images" ])
337
+ (
338
+ example ["class_masks" ],
339
+ example ["class_masked_images" ],
340
+ ) = _generate_random_mask (example ["class_images" ])
325
341
example ["class_prompt_ids" ] = self .tokenizer (
326
342
self .class_prompt ,
327
343
padding = "do_not_pad" ,
0 commit comments