Skip to content

Commit

Permalink
Update dataset.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hako-mikan authored Feb 4, 2024
1 parent 7118c86 commit 7dee73d
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions trainer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ def __getitem__(self, i):
batch = {}
latent, mask, cond1, cond2 = self.latents_conds[i]
batch["latent"] = latent.squeeze()
batch["mask"] = mask.squeeze() if mask is not None else None
batch["cond1"] = cond1 if isinstance(cond1, str) else cond1.squeeze() if cond1 is not None else None

if self.isxl:
batch["cond2"] = cond2 if isinstance(cond2, str) else cond2.squeeze() if cond2 is not None else None
if mask is not None:
batch["mask"] = mask.squeeze()

return batch

Expand Down Expand Up @@ -237,8 +238,7 @@ def resize_and_crop(ar_error, image, bucket_width, bucket_height, disable_upscal
t.image_buckets_raw[max].append([resized, alpha_mask, load_text_files(txt_path), load_text_files(cap_path), filename])
if t.image_mirroring:
flipped = resized.transpose(Image.FLIP_LEFT_RIGHT)
if alpha_mask is not None:
flipped_mask = torch.flip(alpha_mask, [1]) # 幅に対応する次元(ここでは1)で反転
flipped_mask = torch.flip(alpha_mask, [1]) if alpha_mask is not None else None
t.image_buckets_raw[max].append([flipped, flipped_mask, load_text_files(txt_path), load_text_files(cap_path), filename])

ar_errors = t.image_sub_ratios - ratio
Expand All @@ -253,8 +253,7 @@ def resize_and_crop(ar_error, image, bucket_width, bucket_height, disable_upscal
t.image_buckets_raw[sub].append([resized, alpha_mask, load_text_files(txt_path), load_text_files(cap_path), filename])
if t.image_mirroring:
flipped = resized.transpose(Image.FLIP_LEFT_RIGHT)
if alpha_mask is not None:
flipped_mask = torch.flip(alpha_mask, [1]) # 幅に対応する次元(ここでは1)で反転
flipped_mask = torch.flip(alpha_mask, [1]) if alpha_mask is not None else None
t.image_buckets_raw[sub].append([flipped, flipped_mask, load_text_files(txt_path), load_text_files(cap_path), filename])

ar_errors[indice] = ar_errors[indice] + 1
Expand Down

0 comments on commit 7dee73d

Please sign in to comment.