Skip to content

Commit

Permalink
[Bug] Multiple fixes (dmlc#1374)
Browse files Browse the repository at this point in the history
* multiple fixes

* lint

* lint x2
  • Loading branch information
BarclayII authored Mar 19, 2020
1 parent 0a51dc5 commit 4af0202
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 71 deletions.
24 changes: 24 additions & 0 deletions python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,30 @@ def number_of_nodes(self):
"""
return self._graph.number_of_nodes()

def number_of_src_nodes(self):
"""Return the number of nodes in the graph.
For compatibility with heterographs.
Returns
-------
int
The number of nodes
"""
return self._graph.number_of_nodes()

def number_of_dst_nodes(self):
"""Return the number of nodes in the graph.
For compatibility with heterographs.
Returns
-------
int
The number of nodes
"""
return self._graph.number_of_nodes()

def __len__(self):
"""Return the number of nodes in the graph."""
return self.number_of_nodes()
Expand Down
28 changes: 20 additions & 8 deletions python/dgl/heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,12 @@ def ndata(self):
def srcdata(self):
"""Return the data view of all nodes in the SRC category.
**Only works if the graph is uni-bipartite and has one node type in the
SRC category.**
Only works if the graph is either
* Uni-bipartite and has one node type in the SRC category.
* Non-uni-bipartite and has only one node type (in this case identical to
:any:`DGLHeteroGraph.ndata`)
Examples
--------
Expand Down Expand Up @@ -750,8 +754,10 @@ def srcdata(self):
--------
nodes
"""
assert self.is_unibipartite, 'srcdata is only allowed for uni-bipartite graph.'
assert len(self.srctypes) == 1, 'srcdata is only allowed when there is only one SRC type.'
err_msg = (
'srcdata is only allowed when there is only one %s type.' %
('SRC' if self.is_unibipartite else 'node'))
assert len(self.srctypes) == 1, err_msg
ntype = self.srctypes[0]
ntid = self.get_ntype_id_from_src(ntype)
return HeteroNodeDataView(self, ntype, ntid, ALL)
Expand All @@ -760,8 +766,12 @@ def srcdata(self):
def dstdata(self):
"""Return the data view of all destination nodes.
**Only works if the graph is uni-bipartite and has one node type in the
DST category.**
Only works if the graph is either
* Uni-bipartite and has one node type in the SRC category.
* Non-uni-bipartite and has only one node type (in this case identical to
:any:`DGLHeteroGraph.ndata`)
Examples
--------
Expand Down Expand Up @@ -794,8 +804,10 @@ def dstdata(self):
--------
nodes
"""
assert self.is_unibipartite, 'dstdata is only allowed for uni-bipartite graph.'
assert len(self.dsttypes) == 1, 'dstdata is only allowed when there is only one DST type.'
err_msg = (
'dstdata is only allowed when there is only one %s type.' %
('DST' if self.is_unibipartite else 'node'))
assert len(self.dsttypes) == 1, err_msg
ntype = self.dsttypes[0]
ntid = self.get_ntype_id_from_dst(ntype)
return HeteroNodeDataView(self, ntype, ntid, ALL)
Expand Down
55 changes: 40 additions & 15 deletions python/dgl/nn/mxnet/conv/sageconv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""MXNet Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
from numbers import Integral
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn
Expand All @@ -24,6 +25,14 @@ class SAGEConv(nn.Block):
----------
in_feats : int
Input feature size.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
If aggregator type is ``gcn``, the feature size of source and destination nodes
are required to be the same.
out_feats : int
Output feature size.
feat_drop : float
Expand All @@ -47,26 +56,34 @@ def __init__(self,
norm=None,
activation=None):
super(SAGEConv, self).__init__()
self._in_feats = in_feats

if isinstance(in_feats, tuple):
self._in_src_feats = in_feats[0]
self._in_dst_feats = in_feats[1]
elif isinstance(in_feats, Integral):
self._in_src_feats = self._in_dst_feats = in_feats
else:
raise TypeError('in_feats must be either int or pair of ints')

self._out_feats = out_feats
self._aggre_type = aggregator_type
with self.name_scope():
self.norm = norm
self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation
if aggregator_type == 'pool':
self.fc_pool = nn.Dense(in_feats, use_bias=bias,
self.fc_pool = nn.Dense(self._in_src_feats, use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats)
in_units=self._in_src_feats)
if aggregator_type == 'lstm':
raise NotImplementedError
if aggregator_type != 'gcn':
self.fc_self = nn.Dense(out_feats, use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats)
in_units=self._in_dst_feats)
self.fc_neigh = nn.Dense(out_feats, use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats)
in_units=self._in_src_feats)

def forward(self, graph, feat):
r"""Compute GraphSAGE layer.
Expand All @@ -86,23 +103,31 @@ def forward(self, graph, feat):
is size of output feature.
"""
graph = graph.local_var()
feat = self.feat_drop(feat)
h_self = feat

if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)

h_self = feat_dst

if self._aggre_type == 'mean':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # saame as above if homogeneous
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
# divide in degrees
degs = graph.in_degrees().astype(feat.dtype)
degs = degs.as_in_context(feat.context)
h_neigh = (graph.ndata['neigh'] + graph.ndata['h']) / (degs.expand_dims(-1) + 1)
degs = graph.in_degrees().astype(feat_dst.dtype)
degs = degs.as_in_context(feat_dst.context)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.expand_dims(-1) + 1)
elif self._aggre_type == 'pool':
graph.ndata['h'] = nd.relu(self.fc_pool(feat))
graph.srcdata['h'] = nd.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'lstm':
raise NotImplementedError
else:
Expand Down
10 changes: 4 additions & 6 deletions python/dgl/nn/pytorch/conv/sageconv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from numbers import Integral
import torch
from torch import nn
from torch.nn import functional as F

Expand Down Expand Up @@ -124,11 +123,11 @@ def forward(self, graph, feat):
"""
graph = graph.local_var()

if torch.is_tensor(feat):
feat_src = feat_dst = self.feat_drop(feat)
else:
if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)

h_self = feat_dst

Expand All @@ -141,8 +140,7 @@ def forward(self, graph, feat):
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
# divide in_degrees
degs = graph.in_degrees().float()
degs = degs.to(feat_dst.device)
degs = graph.in_degrees().to(feat_dst)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'pool':
graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/nn/pytorch/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def forward(ctx, g, score, eids):
if not is_all(eids):
g = g.edge_subgraph(eids.long())

n_nodes = g.number_of_nodes()
n_nodes = g.number_of_dst_nodes()
n_edges = g.number_of_edges()

# TODO(BarclayII): this is a temporary fix of memory leakage in PyTorch
Expand Down
53 changes: 39 additions & 14 deletions python/dgl/nn/tensorflow/conv/sageconv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tensorflow Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from numbers import Integral
import tensorflow as tf
from tensorflow.keras import layers

Expand All @@ -21,8 +22,16 @@ class SAGEConv(layers.Layer):
Parameters
----------
in_feats : int
in_feats : int, or pair of ints
Input feature size.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
If aggregator type is ``gcn``, the feature size of source and destination nodes
are required to be the same.
out_feats : int
Output feature size.
feat_drop : float
Expand All @@ -47,17 +56,25 @@ def __init__(self,
norm=None,
activation=None):
super(SAGEConv, self).__init__()
self._in_feats = in_feats

if isinstance(in_feats, tuple):
self._in_src_feats = in_feats[0]
self._in_dst_feats = in_feats[1]
elif isinstance(in_feats, Integral):
self._in_src_feats = self._in_dst_feats = in_feats
else:
raise TypeError('in_feats must be either int or pair of ints')

self._out_feats = out_feats
self._aggre_type = aggregator_type
self.norm = norm
self.feat_drop = layers.Dropout(feat_drop)
self.activation = activation
# aggregator type: mean/pool/lstm/gcn
if aggregator_type == 'pool':
self.fc_pool = layers.Dense(in_feats)
self.fc_pool = layers.Dense(self._in_src_feats)
if aggregator_type == 'lstm':
self.lstm = layers.LSTM(units=in_feats)
self.lstm = layers.LSTM(units=self._in_src_feats)
if aggregator_type != 'gcn':
self.fc_self = layers.Dense(out_feats, use_bias=bias)
self.fc_neigh = layers.Dense(out_feats, use_bias=bias)
Expand Down Expand Up @@ -89,27 +106,35 @@ def call(self, graph, feat):
is size of output feature.
"""
graph = graph.local_var()
feat = self.feat_drop(feat)
h_self = feat

if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)

h_self = feat_dst

if self._aggre_type == 'mean':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
# divide in_degrees
degs = tf.cast(graph.in_degrees(), tf.float32)
h_neigh = (graph.ndata['neigh'] + graph.ndata['h']
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']
) / (tf.expand_dims(degs, -1) + 1)
elif self._aggre_type == 'pool':
graph.ndata['h'] = tf.nn.relu(self.fc_pool(feat))
graph.srcdata['h'] = tf.nn.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'lstm':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
else:
raise KeyError(
'Aggregator type {} not recognized.'.format(self._aggre_type))
Expand Down
33 changes: 23 additions & 10 deletions tests/mxnet/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,30 @@ def test_gat_conv():
assert h1.shape == (20, 5, 20)

def test_sage_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx()

graphsage = nn.SAGEConv(10, 20)
graphsage.initialize(ctx=ctx)
print(graphsage)
for aggre_type in ['mean', 'pool', 'gcn']:
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
sage = nn.SAGEConv(5, 10, aggre_type)
feat = F.randn((100, 5))
sage.initialize(ctx=ctx)
h = sage(g, feat)
assert h.shape[-1] == 10

# test#1: basic
h0 = F.randn((20, 10))
h1 = graphsage(g, h0)
assert h1.shape == (20, 20)
g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
sage = nn.SAGEConv(5, 10, aggre_type)
feat = F.randn((100, 5))
sage.initialize(ctx=ctx)
h = sage(g, feat)
assert h.shape[-1] == 10

g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
dst_dim = 5 if aggre_type != 'gcn' else 10
sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
sage.initialize(ctx=ctx)
h = sage(g, feat)
assert h.shape[-1] == 2
assert h.shape[0] == 200

def test_gg_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
Expand Down
Loading

0 comments on commit 4af0202

Please sign in to comment.