-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Example] Dynamic Graph CNN on Point Cloud (dmlc#789)
* initial commit * second commit * another commit * change docstring * migrating to dgl.nn * fixes * docs * lint * multiple fixes * doc
Showing
13 changed files
with
654 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
Dynamic EdgeConv | ||
==== | ||
|
||
This is a reproduction of the paper [Dynamic Graph CNN for Learning on Point | ||
Clouds](https://arxiv.org/pdf/1801.07829.pdf). | ||
|
||
The reproduced experiment is the 40-class classification on the ModelNet40 | ||
dataset. The sampled point clouds are identical to that of | ||
[PointNet](https://github.com/charlesq34/pointnet). | ||
|
||
To train and test the model, simply run | ||
|
||
```python | ||
python main.py | ||
``` | ||
|
||
The model currently takes 3 minutes to train an epoch on Tesla V100, and an | ||
additional 17 seconds to run a validation and 20 seconds to run a test. | ||
|
||
The best validation performance is 93.5% with a test performance of 91.8%. | ||
|
||
## Dependencies | ||
|
||
* `h5py` | ||
* `tqdm` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
from torch.utils.data import DataLoader | ||
from modelnet import ModelNet | ||
from model import Model, compute_loss | ||
from dgl.data.utils import download, get_download_dir | ||
|
||
from functools import partial | ||
import tqdm | ||
import urllib | ||
import os | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--dataset-path', type=str, default='') | ||
parser.add_argument('--load-model-path', type=str, default='') | ||
parser.add_argument('--save-model-path', type=str, default='') | ||
parser.add_argument('--num-epochs', type=int, default=100) | ||
parser.add_argument('--num-workers', type=int, default=0) | ||
parser.add_argument('--batch-size', type=int, default=32) | ||
args = parser.parse_args() | ||
|
||
num_workers = args.num_workers | ||
batch_size = args.batch_size | ||
data_filename = 'modelnet40-sampled-2048.h5' | ||
local_path = args.dataset_path or os.path.join(get_download_dir(), data_filename) | ||
|
||
if not os.path.exists(local_path): | ||
download('https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/modelnet40-sampled-2048.h5', local_path) | ||
|
||
CustomDataLoader = partial( | ||
DataLoader, | ||
num_workers=num_workers, | ||
batch_size=batch_size, | ||
shuffle=True, | ||
drop_last=True) | ||
|
||
def train(model, opt, scheduler, train_loader, dev): | ||
scheduler.step() | ||
|
||
model.train() | ||
|
||
total_loss = 0 | ||
num_batches = 0 | ||
total_correct = 0 | ||
count = 0 | ||
with tqdm.tqdm(train_loader, ascii=True) as tq: | ||
for data, label in tq: | ||
num_examples = label.shape[0] | ||
data, label = data.to(dev), label.to(dev).squeeze().long() | ||
opt.zero_grad() | ||
logits = model(data) | ||
loss = compute_loss(logits, label) | ||
loss.backward() | ||
opt.step() | ||
|
||
_, preds = logits.max(1) | ||
|
||
num_batches += 1 | ||
count += num_examples | ||
loss = loss.item() | ||
correct = (preds == label).sum().item() | ||
total_loss += loss | ||
total_correct += correct | ||
|
||
tq.set_postfix({ | ||
'Loss': '%.5f' % loss, | ||
'AvgLoss': '%.5f' % (total_loss / num_batches), | ||
'Acc': '%.5f' % (correct / num_examples), | ||
'AvgAcc': '%.5f' % (total_correct / count)}) | ||
|
||
def evaluate(model, test_loader, dev): | ||
model.eval() | ||
|
||
total_correct = 0 | ||
count = 0 | ||
|
||
with torch.no_grad(): | ||
with tqdm.tqdm(test_loader, ascii=True) as tq: | ||
for data, label in tq: | ||
num_examples = label.shape[0] | ||
data, label = data.to(dev), label.to(dev).squeeze().long() | ||
logits = model(data) | ||
_, preds = logits.max(1) | ||
|
||
correct = (preds == label).sum().item() | ||
total_correct += correct | ||
count += num_examples | ||
|
||
tq.set_postfix({ | ||
'Acc': '%.5f' % (correct / num_examples), | ||
'AvgAcc': '%.5f' % (total_correct / count)}) | ||
|
||
return total_correct / count | ||
|
||
|
||
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
model = Model(20, [64, 64, 128, 256], [512, 512, 256], 40) | ||
model = model.to(dev) | ||
if args.load_model_path: | ||
model.load_state_dict(torch.load(args.load_model_path, map_location=dev)) | ||
|
||
opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) | ||
|
||
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, args.num_epochs, eta_min=0.001) | ||
|
||
modelnet = ModelNet(local_path, 1024) | ||
|
||
train_loader = CustomDataLoader(modelnet.train()) | ||
valid_loader = CustomDataLoader(modelnet.valid()) | ||
test_loader = CustomDataLoader(modelnet.test()) | ||
|
||
best_valid_acc = 0 | ||
best_test_acc = 0 | ||
|
||
for epoch in range(args.num_epochs): | ||
print('Epoch #%d Validating' % epoch) | ||
valid_acc = evaluate(model, valid_loader, dev) | ||
test_acc = evaluate(model, test_loader, dev) | ||
if valid_acc > best_valid_acc: | ||
best_valid_acc = valid_acc | ||
best_test_acc = test_acc | ||
if args.save_model_path: | ||
torch.save(model.state_dict(), args.save_model_path) | ||
print('Current validation acc: %.5f (best: %.5f), test acc: %.5f (best: %.5f)' % ( | ||
valid_acc, best_valid_acc, test_acc, best_test_acc)) | ||
|
||
train(model, opt, scheduler, train_loader, dev) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from dgl.nn.pytorch import NearestNeighborGraph, EdgeConv | ||
|
||
class Model(nn.Module): | ||
def __init__(self, k, feature_dims, emb_dims, output_classes, input_dims=3, | ||
dropout_prob=0.5): | ||
super(Model, self).__init__() | ||
|
||
self.nng = NearestNeighborGraph(k) | ||
self.conv = nn.ModuleList() | ||
|
||
self.num_layers = len(feature_dims) | ||
for i in range(self.num_layers): | ||
self.conv.append(EdgeConv( | ||
feature_dims[i - 1] if i > 0 else input_dims, | ||
feature_dims[i], | ||
batch_norm=True)) | ||
|
||
self.proj = nn.Linear(sum(feature_dims), emb_dims[0]) | ||
|
||
self.embs = nn.ModuleList() | ||
self.bn_embs = nn.ModuleList() | ||
self.dropouts = nn.ModuleList() | ||
|
||
self.num_embs = len(emb_dims) - 1 | ||
for i in range(1, self.num_embs + 1): | ||
self.embs.append(nn.Linear( | ||
# * 2 because of concatenation of max- and mean-pooling | ||
emb_dims[i - 1] if i > 1 else (emb_dims[i - 1] * 2), | ||
emb_dims[i])) | ||
self.bn_embs.append(nn.BatchNorm1d(emb_dims[i])) | ||
self.dropouts.append(nn.Dropout(dropout_prob)) | ||
|
||
self.proj_output = nn.Linear(emb_dims[-1], output_classes) | ||
|
||
def forward(self, x): | ||
hs = [] | ||
batch_size, n_points, x_dims = x.shape | ||
h = x | ||
|
||
for i in range(self.num_layers): | ||
g = self.nng(h) | ||
h = h.view(batch_size * n_points, -1) | ||
h = self.conv[i](g, h) | ||
h = F.leaky_relu(h, 0.2) | ||
h = h.view(batch_size, n_points, -1) | ||
hs.append(h) | ||
|
||
h = torch.cat(hs, 2) | ||
h = self.proj(h) | ||
h_max, _ = torch.max(h, 1) | ||
h_avg = torch.mean(h, 1) | ||
h = torch.cat([h_max, h_avg], 1) | ||
|
||
for i in range(self.num_embs): | ||
h = self.embs[i](h) | ||
h = self.bn_embs[i](h) | ||
h = F.leaky_relu(h, 0.2) | ||
h = self.dropouts[i](h) | ||
|
||
h = self.proj_output(h) | ||
return h | ||
|
||
|
||
def compute_loss(logits, y, eps=0.2): | ||
num_classes = logits.shape[1] | ||
one_hot = torch.zeros_like(logits).scatter_(1, y.view(-1, 1), 1) | ||
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (num_classes - 1) | ||
log_prob = F.log_softmax(logits, 1) | ||
loss = -(one_hot * log_prob).sum(1).mean() | ||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import numpy as np | ||
from torch.utils.data import Dataset | ||
|
||
class ModelNet(object): | ||
def __init__(self, path, num_points): | ||
import h5py | ||
self.f = h5py.File(path) | ||
self.num_points = num_points | ||
|
||
self.n_train = self.f['train/data'].shape[0] | ||
self.n_valid = int(self.n_train / 5) | ||
self.n_train -= self.n_valid | ||
self.n_test = self.f['test/data'].shape[0] | ||
|
||
def train(self): | ||
return ModelNetDataset(self, 'train') | ||
|
||
def valid(self): | ||
return ModelNetDataset(self, 'valid') | ||
|
||
def test(self): | ||
return ModelNetDataset(self, 'test') | ||
|
||
class ModelNetDataset(Dataset): | ||
def __init__(self, modelnet, mode): | ||
super(ModelNetDataset, self).__init__() | ||
self.num_points = modelnet.num_points | ||
self.mode = mode | ||
|
||
if mode == 'train': | ||
self.data = modelnet.f['train/data'][:modelnet.n_train] | ||
self.label = modelnet.f['train/label'][:modelnet.n_train] | ||
elif mode == 'valid': | ||
self.data = modelnet.f['train/data'][modelnet.n_train:] | ||
self.label = modelnet.f['train/label'][modelnet.n_train:] | ||
elif mode == 'test': | ||
self.data = modelnet.f['test/data'].value | ||
self.label = modelnet.f['test/label'].value | ||
|
||
def translate(self, x, scale=(2/3, 3/2), shift=(-0.2, 0.2)): | ||
xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[3]) | ||
xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[3]) | ||
x = np.add(np.multiply(x, xyz1), xyz2).astype('float32') | ||
return x | ||
|
||
def __len__(self): | ||
return self.data.shape[0] | ||
|
||
def __getitem__(self, i): | ||
x = self.data[i][:self.num_points] | ||
y = self.label[i] | ||
if self.mode == 'train': | ||
x = self.translate(x) | ||
np.random.shuffle(x) | ||
return x, y |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
from .conv import * | ||
from .glob import * | ||
from .softmax import * | ||
from .factory import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
"""Modules that transforms between graphs and between graph and tensors.""" | ||
import torch.nn as nn | ||
from ...transform import nearest_neighbor_graph, segmented_nearest_neighbor_graph | ||
|
||
def pairwise_squared_distance(x): | ||
''' | ||
x : (n_samples, n_points, dims) | ||
return : (n_samples, n_points, n_points) | ||
''' | ||
x2s = (x * x).sum(-1, keepdim=True) | ||
return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2) | ||
|
||
|
||
class NearestNeighborGraph(nn.Module): | ||
r"""Layer that transforms one point set into a graph, or a batch of | ||
point sets with the same number of points into a union of those graphs. | ||
If a batch of point set is provided, then the point :math:`j` in point | ||
set :math:`i` is mapped to graph node ID :math:`i \times M + j`, where | ||
:math:`M` is the number of nodes in each point set. | ||
The predecessors of each node are the k-nearest neighbors of the | ||
corresponding point. | ||
Parameters | ||
---------- | ||
k : int | ||
The number of neighbors | ||
""" | ||
def __init__(self, k): | ||
super(NearestNeighborGraph, self).__init__() | ||
self.k = k | ||
|
||
#pylint: disable=invalid-name | ||
def forward(self, x): | ||
"""Forward computation. | ||
Parameters | ||
---------- | ||
x : Tensor | ||
:math:`(M, D)` or :math:`(N, M, D)` where :math:`N` means the | ||
number of point sets, :math:`M` means the number of points in | ||
each point set, and :math:`D` means the size of features. | ||
Returns | ||
------- | ||
A DGLGraph with no features. | ||
""" | ||
return nearest_neighbor_graph(x, self.k) | ||
|
||
|
||
class SegmentedNearestNeighborGraph(nn.Module): | ||
r"""Layer that transforms one point set into a graph, or a batch of | ||
point sets with different number of points into a union of those graphs. | ||
If a batch of point set is provided, then the point :math:`j` in point | ||
set :math:`i` is mapped to graph node ID | ||
:math:`\sum_{p<i} |V_p| + j`, where :math:`|V_p|` means the number of | ||
points in point set :math:`p`. | ||
The predecessors of each node are the k-nearest neighbors of the | ||
corresponding point. | ||
Parameters | ||
---------- | ||
k : int | ||
The number of neighbors | ||
Inputs | ||
------ | ||
x : Tensor | ||
:math:`(M, D)` where :math:`M` means the total number of points | ||
in all point sets. | ||
segs : Tensor | ||
:math:`(N)` integer tensors where :math:`N` means the number of | ||
point sets. The elements must sum up to :math:`M`. | ||
Outputs | ||
------- | ||
- A DGLGraph with no features. | ||
""" | ||
def __init__(self, k): | ||
super(SegmentedNearestNeighborGraph, self).__init__() | ||
self.k = k | ||
|
||
#pylint: disable=invalid-name | ||
def forward(self, x, segs): | ||
"""Forward computation. | ||
Parameters | ||
---------- | ||
x : Tensor | ||
:math:`(M, D)` where :math:`M` means the total number of points | ||
in all point sets. | ||
segs : iterable of int | ||
:math:`(N)` integers where :math:`N` means the number of point | ||
sets. The elements must sum up to :math:`M`. | ||
Returns | ||
------- | ||
A DGLGraph with no features. | ||
""" | ||
return segmented_nearest_neighbor_graph(x, self.k, segs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters