Skip to content

Commit

Permalink
Removed dense_crf and small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
milesial committed Dec 21, 2019
1 parent 4dcb7b8 commit de7507f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 24 deletions.
12 changes: 3 additions & 9 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,20 @@

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F

from unet import UNet
from utils.data_vis import plot_img_and_mask
from utils.dataset import BasicDataset
from utils.crf import dense_crf


def predict_img(net,
full_img,
device,
scale_factor=1,
out_threshold=0.5,
use_dense_crf=False):
out_threshold=0.5):
net.eval()

img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
Expand All @@ -40,17 +38,14 @@ def predict_img(net,
tf = transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize(full_img.shape[1]),
transforms.Resize(full_img.size[1]),
transforms.ToTensor()
]
)

probs = tf(probs.cpu())
full_mask = probs.squeeze().cpu().numpy()

if use_dense_crf:
full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)

return full_mask > out_threshold


Expand Down Expand Up @@ -127,7 +122,6 @@ def mask_to_image(mask):
full_img=img,
scale_factor=args.scale,
out_threshold=args.mask_threshold,
use_dense_crf=False,
device=device)

if not args.no_save:
Expand Down
28 changes: 14 additions & 14 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,18 @@ def get_args():
# faster convolutions, but more memory
# cudnn.benchmark = True

try:
train_net(net=net,
epochs=args.epochs,
batch_size=args.batchsize,
lr=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
try:
sys.exit(0)
except SystemExit:
os._exit(0)
train_net(net=net,
epochs=args.epochs,
batch_size=args.batchsize,
lr=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
try:
sys.exit(0)
except SystemExit:
os._exit(0)
2 changes: 1 addition & 1 deletion utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, imgs_dir, masks_dir, scale=1):
self.scale = scale
assert 0 < scale <= 1, 'Scale must be between 0 and 1'

self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
if not file.startswith('.')]
logging.info(f'Creating dataset with {len(self.ids)} examples')

Expand Down

0 comments on commit de7507f

Please sign in to comment.