Skip to content

Commit

Permalink
Important update: auto mask inpainting
Browse files Browse the repository at this point in the history
  • Loading branch information
ZPdesu committed Dec 24, 2021
1 parent 9af131b commit 270aa73
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 130 deletions.
17 changes: 14 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,25 @@
## Description
Official Implementation of Barbershop.

**KEEP UPDATING !**

<span style="color:red">**KEEP UPDATING !**</span>


Option1: Produce realistic results:
```
python main.py --im_path1 90.png --im_path2 15.png --im_path3 117.png
python main.py --im_path1 90.png --im_path2 15.png --im_path3 117.png --sign realistic
```

Option2: Produce results faithful to the masks:
```
python main.py --im_path1 90.png --im_path2 15.png --im_path3 117.png --sign fidelity
```




## Updates
## <span style="color:red"> Updates</span>
#### <span style="color:red">24/12/2021 Important Update: Add improved semantic mask inpainting module. Please git pull the newest version.</span>

**18/12/2021** Add a rough version of the project.

Expand Down
2 changes: 1 addition & 1 deletion bash.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env bash

python main.py --im_path1 90.png --im_path2 15.png --im_path3 117.png
python main.py --im_path1 90.png --im_path2 15.png --im_path3 117.png --sign fidelity


50 changes: 12 additions & 38 deletions losses/align_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ def __init__(self, opt):
self.style.eval()


tmp = torch.zeros(16).to(opt.device)
tmp[0] = 1
self.cross_entropy_wo_background = torch.nn.CrossEntropyLoss(weight=1 - tmp)
self.cross_entropy_only_background = torch.nn.CrossEntropyLoss(weight=tmp)



def cross_entropy_loss(self, down_seg, target_mask):
loss = self.opt.ce_lambda * self.cross_entropy(down_seg, target_mask)
Expand All @@ -28,42 +34,10 @@ def style_loss(self, im1, im2, mask1, mask2):
return loss


def cross_entropy_loss_wo_background(self, down_seg, target_mask):
loss = self.opt.ce_lambda * self.cross_entropy_wo_background(down_seg, target_mask)
return loss



#
# def _loss_l2(self, gen_im, ref_im, **kwargs):
# return self.l2(gen_im, ref_im)
#
#
# def _loss_lpips(self, gen_im, ref_im, **kwargs):
#
# return self.percept(gen_im, ref_im).sum()
#



#
# def forward(self, ref_im_H,ref_im_L, gen_im_H, gen_im_L):
#
# loss = 0
# loss_fun_dict = {
# 'l2': self._loss_l2,
# 'percep': self._loss_lpips,
# }
# losses = {}
# for weight, loss_type in self.parsed_loss:
# if loss_type == 'l2':
# var_dict = {
# 'gen_im': gen_im_H,
# 'ref_im': ref_im_H,
# }
# elif loss_type == 'percep':
# var_dict = {
# 'gen_im': gen_im_L,
# 'ref_im': ref_im_L,
# }
# tmp_loss = loss_fun_dict[loss_type](**var_dict)
# losses[loss_type] = tmp_loss
# loss += weight*tmp_loss
# return loss, losses
def cross_entropy_loss_only_background(self, down_seg, target_mask):
loss = self.opt.ce_lambda * self.cross_entropy_only_background(down_seg, target_mask)
return loss
31 changes: 17 additions & 14 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@

def main(args):
ii2s = Embedding(args)

##### Option 1: input folder
# ii2s.invert_images_in_W()
# ii2s.invert_images_in_FS()

##### Option 2: image path
# ii2s.invert_images_in_W('input/face/28.png')
# ii2s.invert_images_in_FS('input/face/28.png')

#
# ##### Option 1: input folder
# # ii2s.invert_images_in_W()
# # ii2s.invert_images_in_FS()

# ##### Option 2: image path
# # ii2s.invert_images_in_W('input/face/28.png')
# # ii2s.invert_images_in_FS('input/face/28.png')
#
##### Option 3: image path list

# im_path1 = 'input/face/90.png'
Expand All @@ -40,20 +40,21 @@ def main(args):
ii2s.invert_images_in_FS([*im_set])

align = Alignment(args)
align.align_images(im_path1, im_path2, align_more_region=False)
align.align_images(im_path1, im_path2, sign=args.sign, align_more_region=False)
if im_path2 != im_path3:
align.align_images(im_path1, im_path3, align_more_region=False, save_intermediate=False)
align.align_images(im_path1, im_path3, sign=args.sign, align_more_region=False, save_intermediate=False)

blend = Blending(args)
blend.blend_images(im_path1, im_path2, im_path3)
blend.blend_images(im_path1, im_path2, im_path3, sign=args.sign)






if __name__ == "__main__":

parser = argparse.ArgumentParser(description='II2S')
parser = argparse.ArgumentParser(description='Barbershop')

# I/O arguments
parser.add_argument('--input_dir', type=str, default='input/face',
Expand All @@ -63,6 +64,8 @@ def main(args):
parser.add_argument('--im_path1', type=str, default='90.png', help='Identity image')
parser.add_argument('--im_path2', type=str, default='15.png', help='Structure image')
parser.add_argument('--im_path3', type=str, default='117.png', help='Appearance image')
parser.add_argument('--sign', type=str, default='realistic', help='realistic or fidelity results')


# StyleGAN2 setting
parser.add_argument('--size', type=int, default=1024)
Expand Down Expand Up @@ -111,4 +114,4 @@ def main(args):


args = parser.parse_args()
main(args)
main(args)
Loading

0 comments on commit 270aa73

Please sign in to comment.