diff --git a/datasets.py b/datasets.py
index 553a2cf9b..9b62ded5d 100644
--- a/datasets.py
+++ b/datasets.py
@@ -11,7 +11,6 @@
 import progressbar
 import sys
 import torchvision.transforms as transforms
-import utils
 import argparse
 import json
 
@@ -22,9 +21,9 @@ def __init__(self, root, npoints = 2500, classification = False, class_choice =
         self.root = root
         self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
         self.cat = {}
-        
+
         self.classification = classification
-        
+
         with open(self.catfile, 'r') as f:
             for line in f:
                 ls = line.strip().split()
@@ -32,7 +31,7 @@ def __init__(self, root, npoints = 2500, classification = False, class_choice =
         #print(self.cat)
         if not class_choice is None:
             self.cat = {k:v for k,v in self.cat.items() if k in class_choice}
-        
+
         self.meta = {}
         for item in self.cat:
             #print('category', item)
@@ -45,19 +44,20 @@ def __init__(self, root, npoints = 2500, classification = False, class_choice =
                 fns = fns[:int(len(fns) * 0.9)]
             else:
                 fns = fns[int(len(fns) * 0.9):]
-            
+
             #print(os.path.basename(fns))
             for fn in fns:
-                token = (os.path.splitext(os.path.basename(fn))[0]) 
+                token = (os.path.splitext(os.path.basename(fn))[0])
                 self.meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg')))
-        
+
         self.datapath = []
         for item in self.cat:
             for fn in self.meta[item]:
                 self.datapath.append((item, fn[0], fn[1]))
-        
-        
-        self.classes = dict(zip(self.cat, range(len(self.cat))))        
+
+
+        self.classes = dict(zip(self.cat, range(len(self.cat))))
+        print(self.classes)
         self.num_seg_classes = 0
         if not self.classification:
             for i in range(len(self.datapath)/50):
@@ -65,15 +65,15 @@ def __init__(self, root, npoints = 2500, classification = False, class_choice =
                 if l > self.num_seg_classes:
                     self.num_seg_classes = l
         #print(self.num_seg_classes)
-        
-        
+
+
     def __getitem__(self, index):
         fn = self.datapath[index]
         cls = self.classes[self.datapath[index][0]]
         point_set = np.loadtxt(fn[1]).astype(np.float32)
         seg = np.loadtxt(fn[2]).astype(np.int64)
         #print(point_set.shape, seg.shape)
-        
+
         choice = np.random.choice(len(seg), self.npoints, replace=True)
         #resample
         point_set = point_set[choice, :]
@@ -85,7 +85,7 @@ def __getitem__(self, index):
             return point_set, cls
         else:
             return point_set, seg
-        
+
     def __len__(self):
         return len(self.datapath)
 
@@ -96,8 +96,8 @@ def __len__(self):
     print(len(d))
     ps, seg = d[0]
     print(ps.size(), ps.type(), seg.size(),seg.type())
-    
+
     d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True)
     print(len(d))
     ps, cls = d[0]
-    print(ps.size(), ps.type(), cls.size(),cls.type())
\ No newline at end of file
+    print(ps.size(), ps.type(), cls.size(),cls.type())
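
Note: the datasets.py changes drop an unused "import utils", strip trailing whitespace, add a missing final newline, and print self.classes so the category-to-index mapping is visible at load time. For reference, a minimal usage sketch of the dataset class (hypothetical point count; it assumes the ShapeNet part-annotation benchmark folder sits next to the script, as in the __main__ block above):

    from datasets import PartDataset

    # classification=True yields (points, class_label) pairs;
    # npoints controls how many points get resampled per shape
    d = PartDataset(root='shapenetcore_partanno_segmentation_benchmark_v0',
                    classification=True, npoints=1024)
    ps, cls = d[0]
    # per the __main__ smoke test, both come back as torch tensors;
    # cls keeps an extra dim, hence the target[:,0] in the train script
    print(ps.size(), cls.size())
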
diff --git a/pointnet.py b/pointnet.py
index 490a3cb6f..f13c7dc2c 100644
--- a/pointnet.py
+++ b/pointnet.py
@@ -15,7 +15,6 @@
 import numpy as np
 import matplotlib.pyplot as plt
 import pdb
-import utils
 import torch.nn.functional as F
 
 
@@ -31,14 +30,14 @@ def __init__(self, num_points = 2500):
         self.fc2 = nn.Linear(512, 256)
         self.fc3 = nn.Linear(256, 9)
         self.relu = nn.ReLU()
-        
+
         self.bn1 = nn.BatchNorm1d(64)
         self.bn2 = nn.BatchNorm1d(128)
         self.bn3 = nn.BatchNorm1d(1024)
         self.bn4 = nn.BatchNorm1d(512)
         self.bn5 = nn.BatchNorm1d(256)
-        
-        
+
+
     def forward(self, x):
         batchsize = x.size()[0]
         x = F.relu(self.bn1(self.conv1(x)))
@@ -57,12 +56,12 @@ def forward(self, x):
         x = x + iden
         x = x.view(-1, 3, 3)
         return x
-        
-        
+
+
 class PointNetfeat(nn.Module):
     def __init__(self, num_points = 2500, global_feat = True):
         super(PointNetfeat, self).__init__()
-        self.stn = STN3d()
+        self.stn = STN3d(num_points = num_points)
         self.conv1 = torch.nn.Conv1d(3, 64, 1)
         self.conv2 = torch.nn.Conv1d(64, 128, 1)
         self.conv3 = torch.nn.Conv1d(128, 1024, 1)
@@ -89,7 +88,7 @@ def forward(self, x):
         else:
             x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
         return torch.cat([x, pointfeat], 1), trans
-    
+
 class PointNetCls(nn.Module):
     def __init__(self, num_points = 2500, k = 2):
         super(PointNetCls, self).__init__()
@@ -121,7 +120,7 @@ def __init__(self, num_points = 2500, k = 2):
         self.bn1 = nn.BatchNorm1d(512)
         self.bn2 = nn.BatchNorm1d(256)
         self.bn3 = nn.BatchNorm1d(128)
-        
+
     def forward(self, x):
         batchsize = x.size()[0]
         x, trans = self.feat(x)
@@ -133,14 +132,14 @@ def forward(self, x):
         x = F.log_softmax(x.view(-1,self.k))
         x = x.view(batchsize, self.num_points, self.k)
         return x, trans
-    
+
 
 if __name__ == '__main__':
     sim_data = Variable(torch.rand(32,3,2500))
     trans = STN3d()
     out = trans(sim_data)
     print('stn', out.size())
-    
+
     pointfeat = PointNetfeat(global_feat=True)
     out, _ = pointfeat(sim_data)
     print('global feat', out.size())
@@ -148,11 +147,11 @@ def forward(self, x):
     pointfeat = PointNetfeat(global_feat=False)
     out, _ = pointfeat(sim_data)
     print('point feat', out.size())
-    
+
     cls = PointNetCls(k = 5)
     out, _ = cls(sim_data)
     print('class', out.size())
-    
+
     seg = PointNetDenseCls(k = 3)
     out, _ = seg(sim_data)
-    print('seg', out.size())
\ No newline at end of file
+    print('seg', out.size())
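
Note: the one substantive pointnet.py change is threading num_points into STN3d, which previously fell back to its default of 2500 regardless of how PointNetfeat was configured; the pooling over the point dimension is sized by num_points, so the spatial transformer must match its input. A quick shape check under that assumption (hypothetical sizes, mirroring the __main__ block but with a non-default point count, which only works once num_points is threaded through):

    import torch
    from torch.autograd import Variable
    from pointnet import PointNetCls

    # a batch of 8 clouds with 1024 points each;
    # num_points must match the last dimension of the input
    sim_data = Variable(torch.rand(8, 3, 1024))
    classifier = PointNetCls(k=5, num_points=1024)
    out, trans = classifier(sim_data)
    print(out.size())  # expected: (8, 5) class log-probabilities
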
diff --git a/train_classification.py b/train_classification.py
index 726ccf868..b830bf7ce 100644
--- a/train_classification.py
+++ b/train_classification.py
@@ -21,6 +21,7 @@
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
+parser.add_argument('--num_points', type=int, default=2500, help='number of points per sample')
 parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
 parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for')
 parser.add_argument('--outf', type=str, default='cls', help='output folder')
@@ -36,11 +37,11 @@
 random.seed(opt.manualSeed)
 torch.manual_seed(opt.manualSeed)
 
-dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True)
+dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, npoints = opt.num_points)
 dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                           shuffle=True, num_workers=int(opt.workers))
 
-test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, train = False)
+test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, train = False, npoints = opt.num_points)
 testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batchSize,
                                           shuffle=True, num_workers=int(opt.workers))
 
@@ -54,13 +55,13 @@
     pass
 
 
-classifier = PointNetCls(k = num_classes)
+classifier = PointNetCls(k = num_classes, num_points = opt.num_points)
 
 
 if opt.model != '':
     classifier.load_state_dict(torch.load(opt.model))
-    
-    
+
+
 optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
 classifier.cuda()
 
@@ -70,8 +71,8 @@
     for i, data in enumerate(dataloader, 0):
         points, target = data
         points, target = Variable(points), Variable(target[:,0])
-        points = points.transpose(2,1) 
-        points, target = points.cuda(), target.cuda()   
+        points = points.transpose(2,1)
+        points, target = points.cuda(), target.cuda()
         optimizer.zero_grad()
         pred, _ = classifier(points)
         loss = F.nll_loss(pred, target)
@@ -80,17 +81,17 @@
         pred_choice = pred.data.max(1)[1]
         correct = pred_choice.eq(target.data).cpu().sum()
         print('[%d: %d/%d] train loss: %f accuracy: %f' %(epoch, i, num_batch, loss.data[0], correct/float(opt.batchSize)))
-        
+
         if i % 10 == 0:
             j, data = enumerate(testdataloader, 0).next()
             points, target = data
             points, target = Variable(points), Variable(target[:,0])
-            points = points.transpose(2,1) 
-            points, target = points.cuda(), target.cuda()   
+            points = points.transpose(2,1)
+            points, target = points.cuda(), target.cuda()
             pred, _ = classifier(points)
             loss = F.nll_loss(pred, target)
             pred_choice = pred.data.max(1)[1]
             correct = pred_choice.eq(target.data).cpu().sum()
             print('[%d: %d/%d] %s loss: %f accuracy: %f' %(epoch, i, num_batch, blue('test'), loss.data[0], correct/float(opt.batchSize)))
-    
-    torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))
\ No newline at end of file
+
+    torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))
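
Note: the added --num_points flag is wired through both PartDataset (as npoints) and PointNetCls (as num_points), so training at a different point-cloud resolution becomes a single invocation (hypothetical values; other flags keep the defaults shown above):

    python train_classification.py --num_points 1024 --batchSize 32 --nepoch 25
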