[Bug Fix] Fix the case when reverse_edge is False for citation graphs (dmlc#3840)

* Update citation_graph.py

* Update

* Update

* Update

Co-authored-by: Minjie Wang <[email protected]>
mufeili and jermainewang authored Jun 22, 2022
1 parent 71157b0 commit 4d3c01d
Showing 3 changed files with 70 additions and 74 deletions.
14 changes: 5 additions & 9 deletions python/dgl/data/citation_graph.py
@@ -61,7 +61,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
}

    def __init__(self, name, raw_dir=None, force_reload=False,
                 verbose=True, reverse_edge=True, transform=None,
                 reorder=False):
        assert name.lower() in ['cora', 'citeseer', 'pubmed']

@@ -122,8 +122,12 @@ def process(self):

        if self.reverse_edge:
            graph = nx.DiGraph(nx.from_dict_of_lists(graph))
+           g = from_networkx(graph)
        else:
            graph = nx.Graph(nx.from_dict_of_lists(graph))
+           edges = list(graph.edges())
+           u, v = map(list, zip(*edges))
+           g = dgl_graph((u, v))

        onehot_labels = np.vstack((ally, ty))
        onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :]
@@ -137,9 +141,6 @@ def process(self):
        val_mask = generate_mask_tensor(_sample_mask(idx_val, labels.shape[0]))
        test_mask = generate_mask_tensor(_sample_mask(idx_test, labels.shape[0]))

-       self._graph = graph
-       g = from_networkx(graph)

        g.ndata['train_mask'] = train_mask
        g.ndata['val_mask'] = val_mask
        g.ndata['test_mask'] = test_mask
@@ -204,7 +205,6 @@ def load(self):
        graph.ndata.pop('feat')
        graph.ndata.pop('label')
        graph = to_networkx(graph)
-       self._graph = nx.DiGraph(graph)

        self._num_classes = info['num_classes']
        self._g.ndata['train_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['train_mask']))
@@ -250,10 +250,6 @@ def num_classes(self):
""" Citation graph is used in many examples
We preserve these properties for compatability.
"""
@property
def graph(self):
deprecate_property('dataset.graph', 'dataset[0]')
return self._graph

@property
def train_mask(self):
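For context, a minimal sketch of how the fixed code path can be exercised (assuming the public CoraGraphDataset wrapper; the printed edge count is illustrative, not exact):

    from dgl.data import CoraGraphDataset

    # With reverse_edge=False, the loader now builds the graph directly from
    # the citation edge list instead of going through from_networkx, so no
    # reversed duplicate edges are added.
    ds = CoraGraphDataset(reverse_edge=False)
    g = ds[0]
    print(g.num_edges())  # roughly half the edges of the default reverse_edge=True graph
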
106 changes: 53 additions & 53 deletions tutorials/models/1_gnn/6_line_graph.py
@@ -17,24 +17,24 @@
"""

###########################################################################################
#
# In this tutorial, you learn how to solve community detection tasks by implementing a line
# graph neural network (LGNN). Community detection, or graph clustering, consists of partitioning
# the vertices in a graph into clusters in which nodes are more similar to
# one another than to nodes in other clusters.
#
# In the :doc:`Graph convolutional network tutorial <1_gcn>`, you learned how to classify the nodes of an input
# graph in a semi-supervised setting. You used a graph convolutional neural network (GCN)
# as an embedding mechanism for graph features.
#
# To generalize a graph neural network (GNN) into supervised community detection, a line-graph based
# variation of GNN is introduced in the research paper
# `Supervised Community Detection with Line Graph Neural Networks <https://arxiv.org/abs/1705.08415>`__.
# One of the highlights of the model is
# to augment the straightforward GNN architecture so that it operates on
# a line graph of edge adjacencies, defined with a non-backtracking operator.
#
# A line graph neural network (LGNN) shows how DGL can implement an advanced graph algorithm by
# mixing basic tensor operations, sparse-matrix multiplication, and message-
# passing APIs.
#
@@ -65,13 +65,13 @@
#
# Cora dataset
# ~~~~~~~~~~~~
# To be consistent with the GCN tutorial,
# you use the `Cora dataset <https://linqs.soe.ucsc.edu/data>`__
# to illustrate a simple community detection task. Cora is a scientific publication dataset,
# with 2708 papers belonging to seven
# different machine learning fields. Here, you formulate Cora as a
# directed graph, with each node being a paper, and each edge being a
# citation link (A->B means A cites B). Here is a visualization of the whole
# Cora dataset.
#
# .. figure:: https://i.imgur.com/X404Byc.png
@@ -96,7 +96,7 @@

data = citegrh.load_cora()

-G = dgl.DGLGraph(data.graph)
+G = data[0]
labels = th.tensor(data.labels)

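# With the dataset object, node labels are also available directly on the
# graph. An equivalent way to get them (assuming the standard ``ndata``
# field names used by DGL's citation datasets):
#
# ::
#
#     labels = G.ndata['label']
#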
# find all the nodes labeled with class 0
@@ -113,7 +113,7 @@
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Without loss of generality, in this tutorial you limit the scope of the
# task to binary community detection.
#
# .. note::
#
# To create a practice binary-community dataset from Cora, first extract
@@ -177,7 +177,7 @@ def visualize(labels, g):
# community assignment :math:`\{A, A, A, B\}`, with each node's label
# :math:`l \in \{0,1\}`, the group of all possible permutations is
# :math:`S_c = \{\{0,0,0,1\}, \{1,1,1,0\}\}`.
#
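# Because labels are only defined up to such a permutation, the training
# loss must be permutation-invariant. A minimal sketch of the idea (our own
# helper, assuming binary communities and PyTorch; not the tutorial's exact
# loss code):
#
# ::
#
#     import itertools
#     import torch as th
#     import torch.nn.functional as F
#
#     def permutation_invariant_loss(logits, labels, n_classes=2):
#         # try every relabeling of the communities; keep the smallest loss
#         losses = [F.cross_entropy(logits, th.tensor(perm)[labels])
#                   for perm in itertools.permutations(range(n_classes))]
#         return th.stack(losses).min()
#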
# Line graph neural network key ideas
# ------------------------------------
# A key innovation in this approach is the use of a line graph.
@@ -193,7 +193,7 @@ def visualize(labels, g):
# Specifically, a line-graph :math:`L(G)` turns an edge of the original graph `G`
# into a node. This is illustrated with the graph below (taken from the
# research paper).
#
# .. figure:: https://i.imgur.com/4WO5jEm.png
# :alt: lg
# :align: center
@@ -206,11 +206,11 @@ def visualize(labels, g):
# connect two edges? Here, we use the following connection rule:
#
# Two nodes :math:`v^{l}_{A}`, :math:`v^{l}_{B}` in `lg` are connected if
# the corresponding two edges :math:`e_{A}, e_{B}` in `g` share one and only
# one node:
# :math:`e_{A}`'s destination node is :math:`e_{B}`'s source node
# (:math:`j`).
#
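# In DGL, this construction is a single call. A small sketch (the directed
# triangle below is our own toy example):
#
# ::
#
#     import dgl
#     import torch as th
#
#     g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))
#     lg = g.line_graph(backtracking=False)  # one lg node per edge of g
#     print(lg.num_nodes(), lg.num_edges())
#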
# .. note::
#
# Mathematically, this definition corresponds to a notion called non-backtracking
@@ -228,7 +228,7 @@ def visualize(labels, g):
# LGNN chains together a series of line graph neural network layers. The graph
# representation :math:`x` and its line graph companion :math:`y` evolve with
# the dataflow as follows.
#
# .. figure:: https://i.imgur.com/bZGGIGp.png
# :alt: alg
# :align: center
@@ -265,9 +265,9 @@ def visualize(labels, g):
#
# Implement LGNN in DGL
# ---------------------
# Even though the equations in the previous section might seem intimidating,
# it helps to understand the following information before you implement the LGNN.
#
# The two equations are symmetric and can be implemented as two instances
# of the same class with different parameters.
# The first equation operates on graph representation :math:`x`,
@@ -295,7 +295,7 @@ def visualize(labels, g):
# Each of the terms is computed again with different
# parameters, and without the nonlinearity after the sum.
# Therefore, :math:`f` could be written as:
#
# .. math::
#
#    \begin{split}
#    f(x^{(k)},y^{(k)}) = {}\rho[&\text{prev}(x^{(k-1)}) + \text{deg}(x^{(k-1)}) +\text{radius}(x^{(k-1)})\\
#    &+\text{fuse}(y^{(k)})]
#    \end{split}
#
# The two equations are chained up in the following order:
#
# .. math::
#
#    \begin{split}
#    x^{(k+1)} = {}& f(x^{(k)}, y^{(k)})\\
#    y^{(k+1)} = {}& f(y^{(k)}, x^{(k+1)})
#    \end{split}
#
# Keep these observations in mind as you proceed to the implementation.
# An important point is that you use different strategies for each of the noted terms.
#
# .. note::
#    You can understand :math:`\{Pm, Pd\}` more thoroughly with this explanation.
# Roughly speaking, there is a relationship between how :math:`g` and
#    :math:`lg` (the line graph) work together with loopy belief propagation.
# Here, you implement :math:`\{Pm, Pd\}` as a SciPy COO sparse matrix in the dataset,
@@ -329,21 +329,21 @@ def visualize(labels, g):
# multiplication. Write them as PyTorch tensor operations.
#
# In ``__init__``, you define the projection variables.
#
# ::
#
#     self.linear_prev = nn.Linear(in_feats, out_feats)
#     self.linear_deg = nn.Linear(in_feats, out_feats)
#
#
# In ``forward()``, :math:`\text{prev}` and :math:`\text{deg}` are the same
# as any other PyTorch tensor operations.
#
# ::
#
#     prev_proj = self.linear_prev(feat_a)
#     deg_proj = self.linear_deg(deg * feat_a)
#
# Implementing :math:`\text{radius}` as message passing in DGL
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# As discussed in GCN tutorial, you can formulate one adjacency operator as
Expand All @@ -355,14 +355,14 @@ def visualize(labels, g):
#
# In ``__init__``, define the projection variables used in each of the
# :math:`2^j` steps of message passing.
#
# ::
#
#     self.linear_radius = nn.ModuleList(
#         [nn.Linear(in_feats, out_feats) for i in range(radius)])
#
# In ``forward()``, use the following function ``aggregate_radius()`` to
# gather data from multiple hops, as shown in the following code.
# Note that ``update_all`` is called multiple times.

# Return a list containing features gathered from multiple radii.
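#
# A sketch of what such a helper can look like (our reconstruction, assuming
# ``import dgl.function as fn``; the exact body may differ):
#
# ::
#
#     def aggregate_radius(radius, g, z):
#         z_list = []
#         g.ndata['z'] = z
#         # pull messages from the 1-hop neighborhood
#         g.update_all(fn.copy_u('z', 'm'), fn.sum('m', 'z'))
#         z_list.append(g.ndata['z'])
#         for i in range(radius - 1):
#             for j in range(2 ** i):
#                 # pull messages from the 2^j-hop neighborhood
#                 g.update_all(fn.copy_u('z', 'm'), fn.sum('m', 'z'))
#             z_list.append(g.ndata['z'])
#         return z_list
#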
@@ -389,32 +389,32 @@ def aggregate_radius(radius, g, z):
# and implement :math:`\text{fuse}` as a sparse matrix multiplication.
#
# In ``forward()``:
#
# ::
#
#     fuse = self.linear_fuse(th.mm(pm_pd, feat_b))
#
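# For intuition, the incidence structure behind :math:`\{Pm, Pd\}` can be
# built from the graph's edge list. A rough sketch (our own helper, assuming
# SciPy; the dataset ships a prebuilt ``pm_pd`` that packs the same
# information, so this is illustrative only):
#
# ::
#
#     import numpy as np
#     import scipy.sparse as sp
#
#     def build_pm_pd(g):
#         src, dst = g.edges()
#         m, n = g.num_edges(), g.num_nodes()
#         cols, ones = np.arange(m), np.ones(m)
#         # Pm[i, e] = 1 if node i is the source of edge e;
#         # Pd[i, e] = 1 if node i is its destination.
#         pm = sp.coo_matrix((ones, (src.numpy(), cols)), shape=(n, m))
#         pd = sp.coo_matrix((ones, (dst.numpy(), cols)), shape=(n, m))
#         return pm, pd
#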
# Completing :math:`f(x, y)`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# Finally, the following shows how to sum all the terms together and pass
# the result through a skip connection and batch norm.
#
# ::
#
#     result = prev_proj + deg_proj + radius_proj + fuse
#
# Pass the result to the skip connection.
#
# ::
#
#     result = th.cat([result[:, :n], F.relu(result[:, n:])], 1)
#
# Then pass the result to batch norm.
#
# ::
#
#     result = self.bn(result)  # batch normalization
#
#
# Here is the complete code for one LGNN layer's abstraction :math:`f(x,y)`
class LGNNCore(nn.Module):
@@ -460,7 +460,7 @@ def forward(self, g, feat_a, feat_b, deg, pm_pd):
# Chain-up LGNN abstractions as an LGNN layer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# To implement:
#
# .. math::
#
#    \begin{split}
#    x^{(k+1)} = {}& f(x^{(k)}, y^{(k)})\\
#    y^{(k+1)} = {}& f(y^{(k)}, x^{(k+1)})
#    \end{split}
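#
# A sketch of such a wrapper (our own simplification; it reuses the
# ``LGNNCore`` signature shown above and transposes :math:`\{Pm, Pd\}` for
# the line-graph direction):
#
# ::
#
#     class LGNNLayer(nn.Module):
#         def __init__(self, in_feats, out_feats, radius):
#             super().__init__()
#             self.g_layer = LGNNCore(in_feats, out_feats, radius)
#             self.lg_layer = LGNNCore(in_feats, out_feats, radius)
#
#         def forward(self, g, lg, x, lg_x, deg_g, deg_lg, pm_pd):
#             next_x = self.g_layer(g, x, lg_x, deg_g, pm_pd)
#             pm_pd_y = th.transpose(pm_pd, 0, 1)
#             next_lg_x = self.lg_layer(lg, lg_x, next_x, deg_lg, pm_pd_y)
#             return next_x, next_lg_x
#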
@@ -518,7 +518,7 @@ def forward(self, g, lg, pm_pd):
# array in ``numpy.ndarray``. Generate the line graph by using this command:
#
# ::
#
#     lg = g.line_graph(backtracking=False)
#
# Note that ``backtracking=False`` is required to correctly simulate non-backtracking
@@ -547,7 +547,7 @@ def sparse2th(mat):
# Create torch tensors
pmpd = sparse2th(pmpd)
label = th.from_numpy(label)

# Forward
z = model(g, lg, pmpd)

@@ -594,7 +594,7 @@
#########################################
# Here is an animation to better understand the process. (40 epochs)
#
# .. figure:: https://i.imgur.com/KDUyE1S.gif
# :alt: lgnn-anim
#
# Batching graphs for parallelism
@@ -619,5 +619,5 @@ def collate_fn(batch):
    return batched_graphs, batched_pmpds, batched_labels

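######################################################################################
# A plausible sketch of the full ``collate_fn`` (our reconstruction, assuming
# ``dgl.batch`` and SciPy's ``block_diag``; the names follow the return
# statement above):
#
# ::
#
#     import dgl
#     import numpy as np
#     import scipy.sparse as sp
#
#     def collate_fn(batch):
#         graphs, pmpds, labels = map(list, zip(*batch))
#         batched_graphs = dgl.batch(graphs)
#         batched_pmpds = sp.block_diag(pmpds)
#         batched_labels = np.concatenate(labels, axis=0)
#         return batched_graphs, batched_pmpds, batched_labels
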
######################################################################################
# You can find the complete code on GitHub at
# `Community Detection with Graph Neural Networks (CDGNN) <https://github.com/dmlc/dgl/tree/master/examples/pytorch/line_graph>`_.
