[Example] Dynamic Graph CNN on Point Cloud (dmlc#789)
* initial commit

* second commit

* another commit

* change docstring

* migrating to dgl.nn

* fixes

* docs

* lint

* multiple fixes

* doc
BarclayII authored Aug 28, 2019
1 parent e590fee commit dc19cd5
Showing 13 changed files with 654 additions and 22 deletions.
15 changes: 15 additions & 0 deletions docs/source/api/python/nn.pytorch.rst
@@ -24,6 +24,10 @@ dgl.nn.pytorch.conv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.EdgeConv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.pytorch.conv.SAGEConv
:members: forward
:show-inheritance:
@@ -113,3 +117,14 @@ dgl.nn.pytorch.softmax

.. automodule:: dgl.nn.pytorch.softmax
:members: edge_softmax

dgl.nn.pytorch.factory
----------------------

.. autoclass:: dgl.nn.pytorch.factory.NearestNeighborGraph
    :members:
    :show-inheritance:

.. autoclass:: dgl.nn.pytorch.factory.SegmentedNearestNeighborGraph
    :members:
    :show-inheritance:
2 changes: 2 additions & 0 deletions docs/source/api/python/transform.rst
@@ -15,3 +15,5 @@ Transform -- Graph Transformation
khop_adj
khop_graph
laplacian_lambda_max
nearest_neighbor_graph
segmented_nearest_neighbor_graph
25 changes: 25 additions & 0 deletions examples/pytorch/pointcloud/README.md
@@ -0,0 +1,25 @@
Dynamic EdgeConv
====

This is a reproduction of the paper [Dynamic Graph CNN for Learning on Point
Clouds](https://arxiv.org/pdf/1801.07829.pdf).

The reproduced experiment is the 40-class classification on the ModelNet40
dataset. The sampled point clouds are identical to those of
[PointNet](https://github.com/charlesq34/pointnet).

To train and test the model, simply run

```bash
python main.py
```

The model currently takes about 3 minutes to train one epoch on a Tesla V100,
plus an additional 17 seconds to run validation and 20 seconds to run a test.

The best validation accuracy is 93.5%, with a corresponding test accuracy of 91.8%.

## Dependencies

* `h5py`
* `tqdm`
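
`main.py` also accepts a few optional flags (wired up via `argparse` in the script below); for example:

```bash
# Train for 200 epochs with a larger batch and keep the best weights
# (the checkpoint path here is just an example).
python main.py --num-epochs 200 --batch-size 64 --save-model-path best.pth

# Warm-start from previously saved weights.
python main.py --load-model-path best.pth
```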
131 changes: 131 additions & 0 deletions examples/pytorch/pointcloud/main.py
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from modelnet import ModelNet
from model import Model, compute_loss
from dgl.data.utils import download, get_download_dir

from functools import partial
import tqdm
import urllib
import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=100)
parser.add_argument('--num-workers', type=int, default=0)
parser.add_argument('--batch-size', type=int, default=32)
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size
data_filename = 'modelnet40-sampled-2048.h5'
local_path = args.dataset_path or os.path.join(get_download_dir(), data_filename)

if not os.path.exists(local_path):
download('https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/modelnet40-sampled-2048.h5', local_path)

CustomDataLoader = partial(
DataLoader,
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=True)

def train(model, opt, scheduler, train_loader, dev):
scheduler.step()

model.train()

total_loss = 0
num_batches = 0
total_correct = 0
count = 0
with tqdm.tqdm(train_loader, ascii=True) as tq:
for data, label in tq:
num_examples = label.shape[0]
data, label = data.to(dev), label.to(dev).squeeze().long()
opt.zero_grad()
logits = model(data)
loss = compute_loss(logits, label)
loss.backward()
opt.step()

_, preds = logits.max(1)

num_batches += 1
count += num_examples
loss = loss.item()
correct = (preds == label).sum().item()
total_loss += loss
total_correct += correct

tq.set_postfix({
'Loss': '%.5f' % loss,
'AvgLoss': '%.5f' % (total_loss / num_batches),
'Acc': '%.5f' % (correct / num_examples),
'AvgAcc': '%.5f' % (total_correct / count)})

def evaluate(model, test_loader, dev):
model.eval()

total_correct = 0
count = 0

with torch.no_grad():
with tqdm.tqdm(test_loader, ascii=True) as tq:
for data, label in tq:
num_examples = label.shape[0]
data, label = data.to(dev), label.to(dev).squeeze().long()
logits = model(data)
_, preds = logits.max(1)

correct = (preds == label).sum().item()
total_correct += correct
count += num_examples

tq.set_postfix({
'Acc': '%.5f' % (correct / num_examples),
'AvgAcc': '%.5f' % (total_correct / count)})

return total_correct / count


dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model(20, [64, 64, 128, 256], [512, 512, 256], 40)
model = model.to(dev)
if args.load_model_path:
model.load_state_dict(torch.load(args.load_model_path, map_location=dev))

opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, args.num_epochs, eta_min=0.001)

modelnet = ModelNet(local_path, 1024)

train_loader = CustomDataLoader(modelnet.train())
valid_loader = CustomDataLoader(modelnet.valid())
test_loader = CustomDataLoader(modelnet.test())

best_valid_acc = 0
best_test_acc = 0

for epoch in range(args.num_epochs):
print('Epoch #%d Validating' % epoch)
valid_acc = evaluate(model, valid_loader, dev)
test_acc = evaluate(model, test_loader, dev)
if valid_acc > best_valid_acc:
best_valid_acc = valid_acc
best_test_acc = test_acc
if args.save_model_path:
torch.save(model.state_dict(), args.save_model_path)
print('Current validation acc: %.5f (best: %.5f), test acc: %.5f (best: %.5f)' % (
valid_acc, best_valid_acc, test_acc, best_test_acc))

train(model, opt, scheduler, train_loader, dev)
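
As a side note, the schedule above anneals the learning rate from 0.1 down to `eta_min=0.001` along a cosine curve over the full run. A standalone sketch of that behavior (illustrative values only, not part of the example):

```python
import torch

# One dummy parameter is enough to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1, momentum=0.9)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100, eta_min=0.001)

for epoch in range(100):
    sched.step()  # main.py also steps once per epoch, at the top of train()
    if epoch % 25 == 0:
        print(epoch, opt.param_groups[0]['lr'])  # decays from 0.1 toward 0.001
```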
73 changes: 73 additions & 0 deletions examples/pytorch/pointcloud/model.py
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import NearestNeighborGraph, EdgeConv

class Model(nn.Module):
def __init__(self, k, feature_dims, emb_dims, output_classes, input_dims=3,
dropout_prob=0.5):
super(Model, self).__init__()

self.nng = NearestNeighborGraph(k)
self.conv = nn.ModuleList()

self.num_layers = len(feature_dims)
for i in range(self.num_layers):
self.conv.append(EdgeConv(
feature_dims[i - 1] if i > 0 else input_dims,
feature_dims[i],
batch_norm=True))

self.proj = nn.Linear(sum(feature_dims), emb_dims[0])

self.embs = nn.ModuleList()
self.bn_embs = nn.ModuleList()
self.dropouts = nn.ModuleList()

self.num_embs = len(emb_dims) - 1
for i in range(1, self.num_embs + 1):
self.embs.append(nn.Linear(
# * 2 because of concatenation of max- and mean-pooling
emb_dims[i - 1] if i > 1 else (emb_dims[i - 1] * 2),
emb_dims[i]))
self.bn_embs.append(nn.BatchNorm1d(emb_dims[i]))
self.dropouts.append(nn.Dropout(dropout_prob))

self.proj_output = nn.Linear(emb_dims[-1], output_classes)

def forward(self, x):
hs = []
batch_size, n_points, x_dims = x.shape
h = x

for i in range(self.num_layers):
g = self.nng(h)
h = h.view(batch_size * n_points, -1)
h = self.conv[i](g, h)
h = F.leaky_relu(h, 0.2)
h = h.view(batch_size, n_points, -1)
hs.append(h)

h = torch.cat(hs, 2)
h = self.proj(h)
h_max, _ = torch.max(h, 1)
h_avg = torch.mean(h, 1)
h = torch.cat([h_max, h_avg], 1)

for i in range(self.num_embs):
h = self.embs[i](h)
h = self.bn_embs[i](h)
h = F.leaky_relu(h, 0.2)
h = self.dropouts[i](h)

h = self.proj_output(h)
return h


def compute_loss(logits, y, eps=0.2):
num_classes = logits.shape[1]
one_hot = torch.zeros_like(logits).scatter_(1, y.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (num_classes - 1)
log_prob = F.log_softmax(logits, 1)
loss = -(one_hot * log_prob).sum(1).mean()
return loss
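
`compute_loss` is cross-entropy with label smoothing: the target distribution puts `1 - eps` on the true class and spreads `eps` uniformly over the remaining `num_classes - 1` classes. A standalone sanity check with made-up logits:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 0.1, -1.0]])  # hypothetical 4-class logits
y = torch.tensor([0])
eps = 0.2

one_hot = torch.zeros_like(logits).scatter_(1, y.view(-1, 1), 1)
# True class gets 1 - eps = 0.8; each other class gets eps / 3 ~= 0.0667.
smoothed = one_hot * (1 - eps) + (1 - one_hot) * eps / (logits.shape[1] - 1)
print(smoothed)  # tensor([[0.8000, 0.0667, 0.0667, 0.0667]])
print(-(smoothed * F.log_softmax(logits, 1)).sum(1).mean())  # the loss value
```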
55 changes: 55 additions & 0 deletions examples/pytorch/pointcloud/modelnet.py
@@ -0,0 +1,55 @@
import numpy as np
from torch.utils.data import Dataset

class ModelNet(object):
def __init__(self, path, num_points):
import h5py
        self.f = h5py.File(path, 'r')
self.num_points = num_points

self.n_train = self.f['train/data'].shape[0]
self.n_valid = int(self.n_train / 5)
self.n_train -= self.n_valid
self.n_test = self.f['test/data'].shape[0]

def train(self):
return ModelNetDataset(self, 'train')

def valid(self):
return ModelNetDataset(self, 'valid')

def test(self):
return ModelNetDataset(self, 'test')

class ModelNetDataset(Dataset):
def __init__(self, modelnet, mode):
super(ModelNetDataset, self).__init__()
self.num_points = modelnet.num_points
self.mode = mode

if mode == 'train':
self.data = modelnet.f['train/data'][:modelnet.n_train]
self.label = modelnet.f['train/label'][:modelnet.n_train]
elif mode == 'valid':
self.data = modelnet.f['train/data'][modelnet.n_train:]
self.label = modelnet.f['train/label'][modelnet.n_train:]
elif mode == 'test':
            self.data = modelnet.f['test/data'][()]
            self.label = modelnet.f['test/label'][()]

def translate(self, x, scale=(2/3, 3/2), shift=(-0.2, 0.2)):
xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[3])
xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[3])
x = np.add(np.multiply(x, xyz1), xyz2).astype('float32')
return x

def __len__(self):
return self.data.shape[0]

def __getitem__(self, i):
x = self.data[i][:self.num_points]
y = self.label[i]
if self.mode == 'train':
x = self.translate(x)
np.random.shuffle(x)
return x, y
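
A usage sketch for the wrapper above, assuming the HDF5 layout that `main.py` downloads (`train/data`, `train/label`, `test/data`, `test/label`):

```python
from torch.utils.data import DataLoader
from modelnet import ModelNet

modelnet = ModelNet('modelnet40-sampled-2048.h5', 1024)  # keep 1024 of 2048 points
loader = DataLoader(modelnet.train(), batch_size=32, shuffle=True, drop_last=True)

x, y = next(iter(loader))
print(x.shape)  # expected torch.Size([32, 1024, 3]): batched xyz coordinates
print(y.shape)  # per-example class labels (main.py squeezes them to (32,))
```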
42 changes: 41 additions & 1 deletion python/dgl/backend/backend.py
@@ -291,7 +291,7 @@ def copy_to(input, ctx):
# These functions are performance critical, so it's better to have efficient
# implementation in each framework.

def sum(input, dim):
def sum(input, dim, keepdims=False):
"""Reduce sum the input tensor along the given dim.
Parameters
@@ -300,6 +300,8 @@ def sum(input, dim):
The input tensor.
dim : int
The reduce dim.
keepdims : bool
Whether to keep the summed dimension.
Returns
-------
@@ -447,13 +449,34 @@ def topk(input, k, dim, descending=True):
----------
input : Tensor
The input tensor.
k : int
The number of elements.
dim : int
The dim to sort along.
descending : bool
Controls whether to return largest/smallest elements.
"""
pass

def argtopk(input, k, dim, descending=True):
"""Return the indices of the k largest elements of the given input tensor
along the given dimension.
If descending is False then the k smallest elements are returned.
Parameters
----------
input : Tensor
The input tensor.
k : int
The number of elements.
dim : int
The dimension to sort along.
descending : bool
Controls whether to return largest/smallest elements.
"""
pass

def exp(input):
"""Returns a new tensor with the exponential of the elements of the input tensor `input`.
@@ -723,6 +746,23 @@ def reshape(input, shape):
"""
pass

def swapaxes(input, axis1, axis2):
"""Interchange the two given axes of a tensor.
Parameters
----------
input : Tensor
The input tensor.
axis1, axis2 : int
The two axes.
Returns
-------
Tensor
The transposed tensor.
"""
pass

def zeros(shape, dtype, ctx):
"""Create a zero tensor.
11 changes: 9 additions & 2 deletions python/dgl/backend/mxnet/tensor.py
@@ -114,8 +114,8 @@ def asnumpy(input):
def copy_to(input, ctx):
return input.as_in_context(ctx)

def sum(input, dim):
return nd.sum(input, axis=dim)
def sum(input, dim, keepdims=False):
return nd.sum(input, axis=dim, keepdims=keepdims)

def reduce_sum(input):
return input.sum()
@@ -141,6 +141,10 @@ def reduce_min(input):
def topk(input, k, dim, descending=True):
return nd.topk(input, axis=dim, k=k, ret_typ='value', is_ascend=not descending)

def argtopk(input, k, dim, descending=True):
    idx = nd.argsort(input, dim, is_ascend=not descending)
    return nd.slice_axis(idx, dim, 0, k)

def argsort(input, dim, descending):
idx = nd.argsort(input, dim, is_ascend=not descending)
idx = nd.cast(idx, dtype='int64')
@@ -220,6 +224,9 @@ def reshape(input, shape):
# NOTE: the input cannot be a symbol
return nd.reshape(input ,shape)

def swapaxes(input, axis1, axis2):
return nd.swapaxes(input, axis1, axis2)

def zeros(shape, dtype, ctx):
return nd.zeros(shape, dtype=dtype, ctx=ctx)

19 changes: 9 additions & 10 deletions python/dgl/backend/pytorch/tensor.py
@@ -92,8 +92,8 @@ def copy_to(input, ctx):
else:
raise RuntimeError('Invalid context', ctx)

def sum(input, dim):
return th.sum(input, dim=dim)
def sum(input, dim, keepdims=False):
return th.sum(input, dim=dim, keepdim=keepdims)

def reduce_sum(input):
return input.sum()
@@ -124,6 +124,9 @@ def argsort(input, dim, descending):
def topk(input, k, dim, descending=True):
return th.topk(input, k, dim, largest=descending)[0]

def argtopk(input, k, dim, descending=True):
return th.topk(input, k, dim, largest=descending)[1]

def exp(input):
return th.exp(input)

@@ -149,14 +152,7 @@ def gather_row(data, row_index):
return th.index_select(data, 0, row_index)

def slice_axis(data, axis, begin, end):
dim = data.shape[axis]
if begin < 0:
begin += dim
if end <= 0:
end += dim
if begin >= end:
raise IndexError("Begin index ({}) equals or greater than end index ({})".format(begin, end))
return th.index_select(data, axis, th.arange(begin, end, device=data.device))
return th.narrow(data, axis, begin, end - begin)

def take(data, indices, dim):
new_shape = data.shape[:dim] + indices.shape + data.shape[dim+1:]
@@ -180,6 +176,9 @@ def unsqueeze(input, dim):
def reshape(input, shape):
return th.reshape(input ,shape)

def swapaxes(input, axis1, axis2):
return th.transpose(input, axis1, axis2)

def zeros(shape, dtype, ctx):
return th.zeros(shape, dtype=dtype, device=ctx)

1 change: 1 addition & 0 deletions python/dgl/nn/pytorch/__init__.py
@@ -2,3 +2,4 @@
from .conv import *
from .glob import *
from .softmax import *
from .factory import *
93 changes: 88 additions & 5 deletions python/dgl/nn/pytorch/conv.py
@@ -14,7 +14,7 @@
__all__ = ['GraphConv', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
'DenseChebConv']
'DenseChebConv', 'EdgeConv']

# pylint: disable=W0235
class Identity(nn.Module):
@@ -29,7 +29,6 @@ def forward(self, x):
"""Return input"""
return x


# pylint: enable=W0235
class GraphConv(nn.Module):
r"""Apply graph convolution over an input signal.
@@ -390,8 +389,8 @@ class RelGraphConv(nn.Module):
    .. math::
       h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}}
       \sum_{j\in\mathcal{N}^r(i)}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})
where :math:`\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation
:math:`r`. :math:`c_{i,r}` is the normalizer equal
@@ -402,7 +401,7 @@ class RelGraphConv(nn.Module):
.. math::
       W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}
where :math:`B` is the number of bases.
@@ -564,6 +563,90 @@ def forward(self, g, x, etypes, norm=None):
return node_repr


class EdgeConv(nn.Module):
r"""EdgeConv layer.
Introduced in "`Dynamic Graph CNN for Learning on Point Clouds
<https://arxiv.org/pdf/1801.07829>`__". Can be described as follows:
.. math::
x_i^{(l+1)} = \max_{j \in \mathcal{N}(i)} \mathrm{ReLU}(
\Theta \cdot (x_j^{(l)} - x_i^{(l)}) + \Phi \cdot x_i^{(l)})
    where :math:`\mathcal{N}(i)` is the set of neighbors of :math:`i`.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
batch_norm : bool
Whether to include batch normalization on messages.
"""
def __init__(self, in_feat, out_feat, batch_norm=False):
super(EdgeConv, self).__init__()
self.batch_norm = batch_norm

self.theta = nn.Linear(in_feat, out_feat)
self.phi = nn.Linear(in_feat, out_feat)

if batch_norm:
self.bn = nn.BatchNorm1d(out_feat)

    def message(self, edges):
        """The message computation function.
        """
        # x_i is the destination (aggregating) node and x_j = src is its
        # neighbor, matching the formula in the class docstring.
        theta_x = self.theta(edges.src['x'] - edges.dst['x'])
        phi_x = self.phi(edges.dst['x'])
        return {'e': theta_x + phi_x}

def forward(self, g, h):
"""Forward computation
Parameters
----------
g : DGLGraph
The graph.
h : Tensor
:math:`(N, D)` where :math:`N` is the number of nodes and
:math:`D` is the number of feature dimensions.
Returns
-------
torch.Tensor
New node features.
"""
with g.local_scope():
g.ndata['x'] = h
if not self.batch_norm:
g.update_all(self.message, fn.max('e', 'x'))
else:
g.apply_edges(self.message)
# Although the official implementation includes a per-edge
# batch norm within EdgeConv, I choose to replace it with a
# global batch norm for a number of reasons:
#
# (1) When the point clouds within each batch do not have the
# same number of points, batch norm would not work.
#
# (2) Even if the point clouds always have the same number of
                # points, the points may well be shuffled even within the
                # same (type of) object (and the official implementation
# *does* shuffle the points of the same example for each
# epoch).
#
# For example, the first point of a point cloud of an
# airplane does not always necessarily reside at its nose.
#
                # In this case, the per-position statistics learned by
                # batch norm are not as meaningful as those learned from
                # images.
g.edata['e'] = self.bn(g.edata['e'])
g.update_all(fn.copy_e('e', 'e'), fn.max('e', 'x'))
return g.ndata['x']


class SAGEConv(nn.Module):
r"""GraphSAGE layer from paper `Inductive Representation Learning on
Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__.
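
A minimal usage sketch for the new `EdgeConv` layer, on a toy graph invented for illustration (in the point-cloud example the graph comes from `NearestNeighborGraph` instead):

```python
import dgl
import torch
from dgl.nn.pytorch import EdgeConv

# Directed 4-cycle, so every node has exactly one predecessor.
g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

conv = EdgeConv(3, 8, batch_norm=False)
h = torch.rand(4, 3)  # (N, D) node features
out = conv(g, h)
print(out.shape)      # torch.Size([4, 8])
```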
103 changes: 103 additions & 0 deletions python/dgl/nn/pytorch/factory.py
@@ -0,0 +1,103 @@
"""Modules that transforms between graphs and between graph and tensors."""
import torch.nn as nn
from ...transform import nearest_neighbor_graph, segmented_nearest_neighbor_graph

def pairwise_squared_distance(x):
'''
x : (n_samples, n_points, dims)
return : (n_samples, n_points, n_points)
'''
x2s = (x * x).sum(-1, keepdim=True)
return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2)


class NearestNeighborGraph(nn.Module):
r"""Layer that transforms one point set into a graph, or a batch of
point sets with the same number of points into a union of those graphs.
    If a batch of point sets is provided, then the point :math:`j` in point
set :math:`i` is mapped to graph node ID :math:`i \times M + j`, where
:math:`M` is the number of nodes in each point set.
The predecessors of each node are the k-nearest neighbors of the
corresponding point.
Parameters
----------
k : int
The number of neighbors
"""
def __init__(self, k):
super(NearestNeighborGraph, self).__init__()
self.k = k

#pylint: disable=invalid-name
def forward(self, x):
"""Forward computation.
Parameters
----------
x : Tensor
:math:`(M, D)` or :math:`(N, M, D)` where :math:`N` means the
number of point sets, :math:`M` means the number of points in
each point set, and :math:`D` means the size of features.
Returns
-------
A DGLGraph with no features.
"""
return nearest_neighbor_graph(x, self.k)


class SegmentedNearestNeighborGraph(nn.Module):
r"""Layer that transforms one point set into a graph, or a batch of
point sets with different number of points into a union of those graphs.
    If a batch of point sets is provided, then the point :math:`j` in point
set :math:`i` is mapped to graph node ID
:math:`\sum_{p<i} |V_p| + j`, where :math:`|V_p|` means the number of
points in point set :math:`p`.
The predecessors of each node are the k-nearest neighbors of the
corresponding point.
Parameters
----------
k : int
The number of neighbors
Inputs
------
x : Tensor
:math:`(M, D)` where :math:`M` means the total number of points
in all point sets.
segs : Tensor
:math:`(N)` integer tensors where :math:`N` means the number of
point sets. The elements must sum up to :math:`M`.
Outputs
-------
- A DGLGraph with no features.
"""
def __init__(self, k):
super(SegmentedNearestNeighborGraph, self).__init__()
self.k = k

#pylint: disable=invalid-name
def forward(self, x, segs):
"""Forward computation.
Parameters
----------
x : Tensor
:math:`(M, D)` where :math:`M` means the total number of points
in all point sets.
segs : iterable of int
:math:`(N)` integers where :math:`N` means the number of point
sets. The elements must sum up to :math:`M`.
Returns
-------
A DGLGraph with no features.
"""
return segmented_nearest_neighbor_graph(x, self.k, segs)
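
A quick sketch of both factory modules on random data (shapes chosen for illustration):

```python
import torch
from dgl.nn.pytorch import NearestNeighborGraph, SegmentedNearestNeighborGraph

x = torch.rand(2, 6, 3)            # 2 point sets, 6 points each, 3-D coordinates
g = NearestNeighborGraph(3)(x)
print(g.number_of_nodes())         # 12: the union of two 6-point graphs
print(g.number_of_edges())         # 2 * 6 * 3 = 36; each point is its own
                                   # nearest neighbor (distance 0), so k counts it

y = torch.rand(7, 3)               # 7 points split into point sets of 3 and 4
sg = SegmentedNearestNeighborGraph(2)(y, [3, 4])
print(sg.number_of_nodes())        # 7
```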
106 changes: 102 additions & 4 deletions python/dgl/transform.py
@@ -4,15 +4,113 @@
from scipy import sparse
from ._ffi.function import _init_api
from .graph import DGLGraph
from . import backend as F
from .graph_index import from_coo
from .batched_graph import BatchedDGLGraph, unbatch
from .backend import asnumpy, tensor


__all__ = ['line_graph', 'khop_adj', 'khop_graph', 'reverse', 'to_simple_graph', 'to_bidirected',
'laplacian_lambda_max']
'laplacian_lambda_max', 'nearest_neighbor_graph', 'segmented_nearest_neighbor_graph']


def pairwise_squared_distance(x):
"""
x : (n_samples, n_points, dims)
return : (n_samples, n_points, n_points)
"""
x2s = F.sum(x * x, -1, True)
# assuming that __matmul__ is always implemented (true for PyTorch, MXNet and Chainer)
return x2s + F.swapaxes(x2s, -1, -2) - 2 * x @ F.swapaxes(x, -1, -2)

#pylint: disable=invalid-name
def nearest_neighbor_graph(x, k):
"""Transforms the given point set to a directed graph, whose coordinates
are given as a matrix. The predecessors of each point are its k-nearest
neighbors.
If a 3D tensor is given instead, then each row would be transformed into
a separate graph. The graphs will be unioned.
Parameters
----------
x : Tensor
The input tensor.
        If 2D, each row of ``x`` corresponds to a node.
        If 3D, a k-NN graph is constructed for each 2D slice along the
        first axis, and the graphs are unioned.
k : int
The number of neighbors
Returns
-------
DGLGraph
The graph. The node IDs are in the same order as ``x``.
"""
if F.ndim(x) == 2:
x = F.unsqueeze(x, 0)
n_samples, n_points, _ = F.shape(x)

dist = pairwise_squared_distance(x)
k_indices = F.argtopk(dist, k, 2, descending=False)
dst = F.copy_to(k_indices, F.cpu())

src = F.zeros_like(dst) + F.reshape(F.arange(0, n_points), (1, -1, 1))

per_sample_offset = F.reshape(F.arange(0, n_samples) * n_points, (-1, 1, 1))
dst += per_sample_offset
src += per_sample_offset
dst = F.reshape(dst, (-1,))
src = F.reshape(src, (-1,))
adj = sparse.csr_matrix((F.asnumpy(F.zeros_like(dst) + 1), (F.asnumpy(dst), F.asnumpy(src))))

g = DGLGraph(adj, readonly=True)
return g

#pylint: disable=invalid-name
def segmented_nearest_neighbor_graph(x, k, segs):
"""Transforms the given point set to a directed graph, whose coordinates
are given as a matrix. The predecessors of each point are its k-nearest
neighbors.
The matrices are concatenated along the first axis, and are segmented by
``segs``. Each block would be transformed into a separate graph. The
graphs will be unioned.
Parameters
----------
x : Tensor
The input tensor.
k : int
The number of neighbors
segs : iterable of int
Number of points of each point set.
Must sum up to the number of rows in ``x``.
Returns
-------
DGLGraph
The graph. The node IDs are in the same order as ``x``.
"""
n_total_points, _ = F.shape(x)
offset = np.insert(np.cumsum(segs), 0, 0)

h_list = F.split(x, segs, 0)
dst = [
F.argtopk(pairwise_squared_distance(h_g), k, 1, descending=False) +
offset[i]
for i, h_g in enumerate(h_list)]
dst = F.cat(dst, 0)
    # Use backend-agnostic ops so this works for every supported framework.
    src = F.zeros_like(dst) + F.reshape(F.arange(0, n_total_points), (-1, 1))

dst = F.reshape(dst, (-1,))
src = F.reshape(src, (-1,))
adj = sparse.csr_matrix((F.asnumpy(F.zeros_like(dst) + 1), (F.asnumpy(dst), F.asnumpy(src))))

g = DGLGraph(adj, readonly=True)
return g

def line_graph(g, backtracking=True, shared=False):
"""Return the line graph of this graph.
@@ -71,7 +169,7 @@ def khop_adj(g, k):
[0., 1., 3., 3., 1.]])
"""
adj_k = g.adjacency_matrix_scipy(return_edge_ids=False) ** k
return tensor(adj_k.todense().astype(np.float32))
return F.tensor(adj_k.todense().astype(np.float32))

def khop_graph(g, k):
"""Return the graph that includes all :math:`k`-hop neighbors of the given graph as edges.
@@ -299,7 +397,7 @@ def laplacian_lambda_max(g):
for g_i in g_arr:
n = g_i.number_of_nodes()
adj = g_i.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
norm = sparse.diags(asnumpy(g_i.in_degrees()).clip(1) ** -0.5, dtype=float)
norm = sparse.diags(F.asnumpy(g_i.in_degrees()).clip(1) ** -0.5, dtype=float)
laplacian = sparse.eye(n) - norm * adj * norm
rst.append(sparse.linalg.eigs(laplacian, 1, which='LM',
return_eigenvectors=False)[0].real)
