-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 1589df5
Showing
7 changed files
with
1,719 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from __future__ import division | ||
#torch | ||
import torch | ||
from torch.autograd import Variable | ||
from torch.utils import data | ||
|
||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.nn.init as init | ||
import torch.utils.model_zoo as model_zoo | ||
from torchvision import models | ||
|
||
# general libs | ||
import cv2 | ||
import matplotlib.pyplot as plt | ||
from PIL import Image | ||
import numpy as np | ||
import time | ||
import os | ||
import copy | ||
|
||
|
||
class font: | ||
PURPLE = '\033[95m' | ||
CYAN = '\033[96m' | ||
DARKCYAN = '\033[36m' | ||
BLUE = '\033[94m' | ||
GREEN = '\033[92m' | ||
YELLOW = '\033[93m' | ||
RED = '\033[91m' | ||
BOLD = '\033[1m' | ||
UNDERLINE = '\033[4m' | ||
END = '\033[0m' | ||
|
||
def ToCuda(xs): | ||
if torch.cuda.is_available(): | ||
if isinstance(xs, list) or isinstance(xs, tuple): | ||
return [x.cuda() for x in xs] | ||
else: | ||
return xs.cuda() | ||
else: | ||
return xs | ||
|
||
|
||
def pad_divide_by(in_list, d, in_size): | ||
out_list = [] | ||
h, w = in_size | ||
if h % d > 0: | ||
new_h = h + d - h % d | ||
else: | ||
new_h = h | ||
if w % d > 0: | ||
new_w = w + d - w % d | ||
else: | ||
new_w = w | ||
lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2) | ||
lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2) | ||
pad_array = (int(lw), int(uw), int(lh), int(uh)) | ||
for inp in in_list: | ||
out_list.append(F.pad(inp, pad_array)) | ||
return out_list, pad_array | ||
|
||
|
||
|
||
def overlay_davis(image,mask,colors=[255,0,0],cscale=2,alpha=0.4): | ||
""" Overlay segmentation on top of RGB image. from davis official""" | ||
# import skimage | ||
from scipy.ndimage.morphology import binary_erosion, binary_dilation | ||
|
||
colors = np.reshape(colors, (-1, 3)) | ||
colors = np.atleast_2d(colors) * cscale | ||
|
||
im_overlay = image.copy() | ||
object_ids = np.unique(mask) | ||
|
||
for object_id in object_ids[1:]: | ||
# Overlay color on binary mask | ||
foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id]) | ||
binary_mask = mask == object_id | ||
|
||
# Compose image | ||
im_overlay[binary_mask] = foreground[binary_mask] | ||
|
||
# countours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask | ||
countours = binary_dilation(binary_mask) ^ binary_mask | ||
# countours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask | ||
im_overlay[countours,:] = 0 | ||
|
||
return im_overlay.astype(image.dtype) | ||
|
||
def iou(Es, Ms): | ||
#Es: torch.Size([4, 11, 2, 100, 100]) | ||
#Ms: torch.Size([4, 11, 2, 100, 100]) | ||
batch_size, _, num_frames, _, _ = Es.size() | ||
mean_iou = 0.0 | ||
for b in range(batch_size): | ||
for f in range(num_frames): | ||
pred = Es[b,:,f].unsqueeze(0) | ||
#pred: torch.Size([1, 11, 100, 100]) | ||
gt = Ms[b,:,f].cpu().data.numpy() | ||
#gt: (11, 100, 100) | ||
pred = torch.cat((1-pred, pred), dim=0) | ||
#pred: torch.Size([2, 11, 100, 100]) | ||
pred = ToLabel(pred.cpu().data.numpy()) | ||
#pred: (11, 100, 100) | ||
agg = pred + gt | ||
i = float(np.sum(agg == 2)) | ||
u = float(np.sum(agg > 0)) | ||
mean_iou += i/u | ||
|
||
return mean_iou/(batch_size*num_frames) | ||
|
||
def ToLabel(E): | ||
fgs = np.argmax(E, axis=0).astype(np.float32) | ||
return fgs.astype(np.uint8) | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import numpy as np | ||
|
||
class diceloss(nn.Module): | ||
def __init__(self): | ||
super(diceloss, self).__init__() | ||
def __call__(self,pred, target): | ||
smooth = 1. | ||
iflat = pred.contiguous().view(-1) | ||
tflat = target.contiguous().view(-1) | ||
intersection = (iflat * tflat).sum() | ||
A_sum = torch.sum(iflat * iflat) | ||
B_sum = torch.sum(tflat * tflat) | ||
return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth) ) |
Oops, something went wrong.