Skip to content

Commit

Permalink
release
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoForte committed Mar 18, 2020
1 parent 759b08f commit d6b86ef
Show file tree
Hide file tree
Showing 15 changed files with 484 additions and 13 deletions.
251 changes: 251 additions & 0 deletions .ipynb_checkpoints/FBA Matting-checkpoint.ipynb

Large diffs are not rendered by default.

203 changes: 203 additions & 0 deletions FBA Matting.ipynb

Large diffs are not rendered by default.

18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ GPU memory >= 11GB for inference on Adobe Composition-1K testing set, more gener
- torch >= 1.4
- numpy
- opencv-python
#### Additional Packages for jupyter notebook
- matplotlib
- gdown (to download model inside notebook)


## Models
Expand All @@ -27,4 +30,17 @@ GPU memory >= 11GB for inference on Adobe Composition-1K testing set, more gener


## Prediction
We provide a script `demo.py` which gives the foreground, background and alpha results of our model.
We provide a script `demo.py` and jupyter notebook which both give the foreground, background and alpha predictions of our model.


## Citation

```
@article{forte2020fbamatting,
title = {F, B, Alpha Matting},
author = {Marco Forte and François Pitié},
journal = {CoRR},
volume = {abs/2003.07711},
year = {2020},
}
```
17 changes: 9 additions & 8 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def predict_fba_folder(model, args):
image_np = item_dict['image']
trimap_np = item_dict['trimap']

with torch.no_grad():
fg, bg, alpha = pred(image_np, trimap_np, model)
fg, bg, alpha = pred(image_np, trimap_np, model)

cv2.imwrite(os.path.join(save_dir, item_dict['name'][:-4] + '_fg.png'), fg[:, :, ::-1] * 255)
cv2.imwrite(os.path.join(save_dir, item_dict['name'][:-4] + '_bg.png'), bg[:, :, ::-1] * 255)
Expand All @@ -59,15 +58,17 @@ def pred(image_np: np.ndarray, trimap_np: np.ndarray, model) -> np.ndarray:
image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)

image_torch = np_to_torch(image_scale_np)
trimap_torch = np_to_torch(trimap_scale_np)
with torch.no_grad():

trimap_transformed_torch = np_to_torch(trimap_transform(trimap_scale_np))
image_transformed_torch = groupnorm_normalise_image(image_torch.clone(), format='nchw')
image_torch = np_to_torch(image_scale_np)
trimap_torch = np_to_torch(trimap_scale_np)

output = model(image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch)
trimap_transformed_torch = np_to_torch(trimap_transform(trimap_scale_np))
image_transformed_torch = groupnorm_normalise_image(image_torch.clone(), format='nchw')

output = cv2.resize(output[0].cpu().numpy().transpose((1, 2, 0)), (w, h), cv2.INTER_LANCZOS4)
output = model(image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch)

output = cv2.resize(output[0].cpu().numpy().transpose((1, 2, 0)), (w, h), cv2.INTER_LANCZOS4)
alpha = output[:, :, 0]
fg = output[:, :, 1:4]
bg = output[:, :, 4:7]
Expand Down
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
8 changes: 4 additions & 4 deletions networks/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,21 +332,21 @@ def forward(self, conv_out, img, indices, two_chan_trimap):
ppm_out.append(nn.functional.interpolate(
pool_scale(conv5),
(input_size[2], input_size[3]),
mode='bilinear'))
mode='bilinear', align_corners=False))
ppm_out = torch.cat(ppm_out, 1)
x = self.conv_up1(ppm_out)

x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear')
x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

x = torch.cat((x, conv_out[-4]), 1)

x = self.conv_up2(x)
x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear')
x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

x = torch.cat((x, conv_out[-5]), 1)
x = self.conv_up3(x)

x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear')
x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)

output = self.conv_up4(x)
Expand Down

0 comments on commit d6b86ef

Please sign in to comment.