
Commit 88aaa21

small fix on memory leak and segmentation training
j96w committed Apr 7, 2019
1 parent dd8ff1a commit 88aaa21
Showing 4 changed files with 15 additions and 12 deletions.
lib/loss.py (2 additions, 2 deletions)

@@ -8,9 +8,9 @@
 import torch.backends.cudnn as cudnn
 from lib.knn.__init__ import KNearestNeighbor
 
-knn = KNearestNeighbor(1)
 
 def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list):
+    knn = KNearestNeighbor(1)
     bs, num_p, _ = pred_c.size()
 
     pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))
@@ -67,7 +67,7 @@ def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points,
     new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
 
     # print('------------> ', dis[0][which_max[0]].item(), pred_c[0][which_max[0]].item(), idx[0].item())
-
+    del knn
     return loss, dis[0][which_max[0]], new_points.detach(), new_target.detach()
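
Both loss files get the same leak fix: KNearestNeighbor was previously instantiated once at import time, so the instance (and whatever GPU state it holds) stayed alive for the entire run; it is now constructed inside loss_calculation and explicitly dropped with del before the return. A minimal sketch of that pattern, assuming a k-NN helper with per-instance state: the KNearestNeighbor below is a pure-PyTorch stand-in, not the repository's CUDA op, and chamfer_like_distance is an invented name.

import torch

class KNearestNeighbor:
    """Stand-in for lib.knn's CUDA k-NN op (the real one holds GPU buffers)."""
    def __init__(self, k):
        self.k = k

    def __call__(self, ref, query):
        # ref, query: (bs, 3, n) point sets; returns 1-NN indices into ref.
        dist = torch.cdist(query.transpose(1, 2), ref.transpose(1, 2))
        return dist.topk(self.k, dim=2, largest=False).indices

def chamfer_like_distance(target, pred):
    knn = KNearestNeighbor(1)   # built per call instead of at module scope
    inds = knn(target, pred)    # (bs, n, 1): nearest target point per prediction
    matched = torch.gather(target, 2, inds.transpose(1, 2).expand(-1, 3, -1))
    dis = torch.norm(matched - pred, dim=1).mean()
    del knn                     # drop the last reference before returning
    return dis

print(chamfer_like_distance(torch.rand(1, 3, 500), torch.rand(1, 3, 500)).item())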

lib/loss_refiner.py (2 additions, 2 deletions)

@@ -8,9 +8,9 @@
 import torch.backends.cudnn as cudnn
 from lib.knn.__init__ import KNearestNeighbor
 
-knn = KNearestNeighbor(1)
 
 def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_point_mesh, sym_list):
+    knn = KNearestNeighbor(1)
     pred_r = pred_r.view(1, 1, -1)
     pred_t = pred_t.view(1, 1, -1)
     bs, num_p, _ = pred_r.size()
@@ -60,7 +60,7 @@ def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_poin
     new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
 
     # print('------------> ', dis.item(), idx[0].item())
-
+    del knn
     return dis, new_points.detach(), new_target.detach()
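
loss_refiner.py receives the identical treatment. As a side note (this harness is my own assumption, not part of the commit), one way to sanity-check such a fix on a CUDA build is to watch allocated device memory across repeated loss calls; with per-call construction plus del, the figure should stay flat instead of climbing:

import torch

def report(tag):
    # Only meaningful on a CUDA build; prints nothing on CPU-only machines.
    if torch.cuda.is_available():
        print(tag, torch.cuda.memory_allocated() // 1024, 'KiB allocated')

report('before')
# for _ in range(100):
#     dis, new_points, new_target = loss_calculation(...)  # knn created and deleted inside
report('after')   # should match 'before' once the loop's tensors go out of scope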

vanilla_segmentation/data_controller.py (8 additions, 5 deletions)

@@ -15,11 +15,10 @@
 from PIL import ImageFilter
 
 class SegDataset(data.Dataset):
-    def __init__(self, root_dir, txtlist, use_noise, num=1000):
+    def __init__(self, root_dir, txtlist, use_noise, length):
         self.path = []
         self.real_path = []
         self.use_noise = use_noise
-        self.num = num
         self.root = root_dir
         input_file = open(txtlist)
         while 1:
@@ -33,13 +32,17 @@ def __init__(self, root_dir, txtlist, use_noise, num=1000):
             self.real_path.append(copy.deepcopy(input_line))
         input_file.close()
 
-        self.length = len(self.path)
+        self.length = length
+        self.data_len = len(self.path)
+        self.back_len = len(self.real_path)
 
         self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
         self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         self.back_front = np.array([[1 for i in range(640)] for j in range(480)])
 
-    def __getitem__(self, index):
+    def __getitem__(self, idx):
+        index = random.randint(0, self.data_len - 10)
+
         label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[index])))
         meta = scio.loadmat('{0}/{1}-meta.mat'.format(self.root, self.path[index]))
         if not self.use_noise:
@@ -51,7 +54,7 @@ def __getitem__(self, index):
         rgb = Image.open('{0}/{1}-color.png'.format(self.root, self.path[index])).convert("RGB")
         rgb = ImageEnhance.Brightness(rgb).enhance(1.5).filter(ImageFilter.GaussianBlur(radius=0.8))
         rgb = np.array(self.trancolor(rgb))
-        seed = random.randint(10, self.back_len - 10)
+        seed = random.randint(0, self.back_len - 10)
 back = np.array(self.trancolor(Image.open('{0}/{1}-color.png'.format(self.root, self.path[seed])).convert("RGB")))
 back_label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[seed])))
 mask = ma.getmaskarray(ma.masked_equal(label, 0))
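
The dataset change decouples epoch length from the number of files on disk: the constructor now takes an explicit length (presumably returned by __len__, which this diff does not show), and __getitem__ ignores the sampler's index and draws a random frame instead. A stripped-down sketch of the pattern, with hypothetical names (FixedLengthDataset, items) and an exact upper bound where the repository keeps a margin of 10:

import random
from torch.utils.data import Dataset

class FixedLengthDataset(Dataset):
    def __init__(self, items, length):
        self.items = items        # e.g. frame paths read from the txt list
        self.length = length      # number of samples that make up one epoch
        self.data_len = len(self.items)

    def __getitem__(self, idx):
        # idx from the sampler is ignored; every access is a fresh random draw.
        index = random.randint(0, self.data_len - 1)
        return self.items[index]

    def __len__(self):
        return self.length        # epoch size, decoupled from data_len

ds = FixedLengthDataset(list(range(100)), length=5000)
print(len(ds))                    # 5000, even though only 100 items exist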

vanilla_segmentation/train.py (3 additions, 3 deletions)

@@ -38,10 +38,10 @@
 random.seed(opt.manualSeed)
 torch.manual_seed(opt.manualSeed)
 
-dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True)
+dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True, 5000)
 dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers))
-test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False)
-test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=int(opt.workers))
+test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False, 1000)
+test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=int(opt.workers))
 
 print(len(dataset), len(test_dataset))
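
With the explicit lengths wired through (5000 training samples and 1000 test samples per epoch), the number of iterations per epoch follows directly from the dataset length and the batch size. A self-contained check, where TensorDataset stands in for SegDataset and the batch size of 3 is an arbitrary assumption (opt.batch_size is supplied on the command line):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(5000))        # stand-in with len(ds) == 5000
loader = DataLoader(ds, batch_size=3, shuffle=True)
print(len(ds), len(loader))                   # 5000 samples -> 1667 batches per epoch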
