Skip to content

Commit

Permalink
fix import
Browse files Browse the repository at this point in the history
  • Loading branch information
fxia22 committed May 8, 2017
1 parent e08c96e commit bf3494c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 42 deletions.
32 changes: 16 additions & 16 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import progressbar
import sys
import torchvision.transforms as transforms
import utils
import argparse
import json

Expand All @@ -22,17 +21,17 @@ def __init__(self, root, npoints = 2500, classification = False, class_choice =
self.root = root
self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
self.cat = {}

self.classification = classification

with open(self.catfile, 'r') as f:
for line in f:
ls = line.strip().split()
self.cat[ls[0]] = ls[1]
#print(self.cat)
if not class_choice is None:
self.cat = {k:v for k,v in self.cat.items() if k in class_choice}

self.meta = {}
for item in self.cat:
#print('category', item)
Expand All @@ -45,35 +44,36 @@ def __init__(self, root, npoints = 2500, classification = False, class_choice =
fns = fns[:int(len(fns) * 0.9)]
else:
fns = fns[int(len(fns) * 0.9):]

#print(os.path.basename(fns))
for fn in fns:
token = (os.path.splitext(os.path.basename(fn))[0])
token = (os.path.splitext(os.path.basename(fn))[0])
self.meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg')))

self.datapath = []
for item in self.cat:
for fn in self.meta[item]:
self.datapath.append((item, fn[0], fn[1]))


self.classes = dict(zip(self.cat, range(len(self.cat))))


self.classes = dict(zip(self.cat, range(len(self.cat))))
print(self.classes)
self.num_seg_classes = 0
if not self.classification:
for i in range(len(self.datapath)/50):
l = len(np.unique(np.loadtxt(self.datapath[i][-1]).astype(np.uint8)))
if l > self.num_seg_classes:
self.num_seg_classes = l
#print(self.num_seg_classes)


def __getitem__(self, index):
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
point_set = np.loadtxt(fn[1]).astype(np.float32)
seg = np.loadtxt(fn[2]).astype(np.int64)
#print(point_set.shape, seg.shape)

choice = np.random.choice(len(seg), self.npoints, replace=True)
#resample
point_set = point_set[choice, :]
Expand All @@ -85,7 +85,7 @@ def __getitem__(self, index):
return point_set, cls
else:
return point_set, seg

def __len__(self):
return len(self.datapath)

Expand All @@ -96,8 +96,8 @@ def __len__(self):
print(len(d))
ps, seg = d[0]
print(ps.size(), ps.type(), seg.size(),seg.type())

d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True)
print(len(d))
ps, cls = d[0]
print(ps.size(), ps.type(), cls.size(),cls.type())
print(ps.size(), ps.type(), cls.size(),cls.type())
27 changes: 13 additions & 14 deletions pointnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import numpy as np
import matplotlib.pyplot as plt
import pdb
import utils
import torch.nn.functional as F


Expand All @@ -31,14 +30,14 @@ def __init__(self, num_points = 2500):
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 9)
self.relu = nn.ReLU()

self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.bn4 = nn.BatchNorm1d(512)
self.bn5 = nn.BatchNorm1d(256)


def forward(self, x):
batchsize = x.size()[0]
x = F.relu(self.bn1(self.conv1(x)))
Expand All @@ -57,12 +56,12 @@ def forward(self, x):
x = x + iden
x = x.view(-1, 3, 3)
return x


class PointNetfeat(nn.Module):
def __init__(self, num_points = 2500, global_feat = True):
super(PointNetfeat, self).__init__()
self.stn = STN3d()
self.stn = STN3d(num_points = num_points)
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
Expand All @@ -89,7 +88,7 @@ def forward(self, x):
else:
x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
return torch.cat([x, pointfeat], 1), trans

class PointNetCls(nn.Module):
def __init__(self, num_points = 2500, k = 2):
super(PointNetCls, self).__init__()
Expand Down Expand Up @@ -121,7 +120,7 @@ def __init__(self, num_points = 2500, k = 2):
self.bn1 = nn.BatchNorm1d(512)
self.bn2 = nn.BatchNorm1d(256)
self.bn3 = nn.BatchNorm1d(128)

def forward(self, x):
batchsize = x.size()[0]
x, trans = self.feat(x)
Expand All @@ -133,26 +132,26 @@ def forward(self, x):
x = F.log_softmax(x.view(-1,self.k))
x = x.view(batchsize, self.num_points, self.k)
return x, trans


if __name__ == '__main__':
sim_data = Variable(torch.rand(32,3,2500))
trans = STN3d()
out = trans(sim_data)
print('stn', out.size())

pointfeat = PointNetfeat(global_feat=True)
out, _ = pointfeat(sim_data)
print('global feat', out.size())

pointfeat = PointNetfeat(global_feat=False)
out, _ = pointfeat(sim_data)
print('point feat', out.size())

cls = PointNetCls(k = 5)
out, _ = cls(sim_data)
print('class', out.size())

seg = PointNetDenseCls(k = 3)
out, _ = seg(sim_data)
print('seg', out.size())
print('seg', out.size())
25 changes: 13 additions & 12 deletions train_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
parser.add_argument('--num_points', type=int, default=2500, help='input batch size')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='cls', help='output folder')
Expand All @@ -36,11 +37,11 @@
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True)
dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, npoints = opt.num_points)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
shuffle=True, num_workers=int(opt.workers))

test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, train = False)
test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, train = False, npoints = opt.num_points)
testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batchSize,
shuffle=True, num_workers=int(opt.workers))

Expand All @@ -54,13 +55,13 @@
pass


classifier = PointNetCls(k = num_classes)
classifier = PointNetCls(k = num_classes, num_points = opt.num_points)


if opt.model != '':
classifier.load_state_dict(torch.load(opt.model))


optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
classifier.cuda()

Expand All @@ -70,8 +71,8 @@
for i, data in enumerate(dataloader, 0):
points, target = data
points, target = Variable(points), Variable(target[:,0])
points = points.transpose(2,1)
points, target = points.cuda(), target.cuda()
points = points.transpose(2,1)
points, target = points.cuda(), target.cuda()
optimizer.zero_grad()
pred, _ = classifier(points)
loss = F.nll_loss(pred, target)
Expand All @@ -80,17 +81,17 @@
pred_choice = pred.data.max(1)[1]
correct = pred_choice.eq(target.data).cpu().sum()
print('[%d: %d/%d] train loss: %f accuracy: %f' %(epoch, i, num_batch, loss.data[0], correct/float(opt.batchSize)))

if i % 10 == 0:
j, data = enumerate(testdataloader, 0).next()
points, target = data
points, target = Variable(points), Variable(target[:,0])
points = points.transpose(2,1)
points, target = points.cuda(), target.cuda()
points = points.transpose(2,1)
points, target = points.cuda(), target.cuda()
pred, _ = classifier(points)
loss = F.nll_loss(pred, target)
pred_choice = pred.data.max(1)[1]
correct = pred_choice.eq(target.data).cpu().sum()
print('[%d: %d/%d] %s loss: %f accuracy: %f' %(epoch, i, num_batch, blue('test'), loss.data[0], correct/float(opt.batchSize)))
torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))

torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))

0 comments on commit bf3494c

Please sign in to comment.