Skip to content

Commit

Permalink
[Doc] Patch tutorial (dmlc#1380)
Browse files Browse the repository at this point in the history
* patched 1_first

* done 2_basics

* done 4_batch

* done 1_gcn, 9_gat, 2_capsule

* 4_rgcn.py

* revert

* more fix
  • Loading branch information
jermainewang authored Mar 22, 2020
1 parent 0f40c6e commit 2cdc4d3
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 243 deletions.
127 changes: 45 additions & 82 deletions tutorials/basics/1_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,26 @@
# Create the graph for Zachary's karate club as follows:

import dgl
import numpy as np

def build_karate_club_graph():
g = dgl.DGLGraph()
# add 34 nodes into the graph; nodes are labeled from 0~33
g.add_nodes(34)
# all 78 edges as a list of tuples
edge_list = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2),
(4, 0), (5, 0), (6, 0), (6, 4), (6, 5), (7, 0), (7, 1),
(7, 2), (7, 3), (8, 0), (8, 2), (9, 2), (10, 0), (10, 4),
(10, 5), (11, 0), (12, 0), (12, 3), (13, 0), (13, 1), (13, 2),
(13, 3), (16, 5), (16, 6), (17, 0), (17, 1), (19, 0), (19, 1),
(21, 0), (21, 1), (25, 23), (25, 24), (27, 2), (27, 23),
(27, 24), (28, 2), (29, 23), (29, 26), (30, 1), (30, 8),
(31, 0), (31, 24), (31, 25), (31, 28), (32, 2), (32, 8),
(32, 14), (32, 15), (32, 18), (32, 20), (32, 22), (32, 23),
(32, 29), (32, 30), (32, 31), (33, 8), (33, 9), (33, 13),
(33, 14), (33, 15), (33, 18), (33, 19), (33, 20), (33, 22),
(33, 23), (33, 26), (33, 27), (33, 28), (33, 29), (33, 30),
(33, 31), (33, 32)]
# add edges two lists of nodes: src and dst
src, dst = tuple(zip(*edge_list))
g.add_edges(src, dst)
# edges are directional in DGL; make them bi-directional
g.add_edges(dst, src)

return g
# All 78 edges are stored in two numpy arrays. One for source endpoints
# while the other for destination endpoints.
src = np.array([1, 2, 2, 3, 3, 3, 4, 5, 6, 6, 6, 7, 7, 7, 7, 8, 8, 9, 10, 10,
10, 11, 12, 12, 13, 13, 13, 13, 16, 16, 17, 17, 19, 19, 21, 21,
25, 25, 27, 27, 27, 28, 29, 29, 30, 30, 31, 31, 31, 31, 32, 32,
32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33,
33, 33, 33, 33, 33, 33, 33, 33, 33, 33])
dst = np.array([0, 0, 1, 0, 1, 2, 0, 0, 0, 4, 5, 0, 1, 2, 3, 0, 2, 2, 0, 4,
5, 0, 0, 3, 0, 1, 2, 3, 5, 6, 0, 1, 0, 1, 0, 1, 23, 24, 2, 23,
24, 2, 23, 26, 1, 8, 0, 24, 25, 28, 2, 8, 14, 15, 18, 20, 22, 23,
29, 30, 31, 8, 9, 13, 14, 15, 18, 19, 20, 22, 23, 26, 27, 28, 29, 30,
31, 32])
# Edges are directional in DGL; Make them bi-directional.
u = np.concatenate([src, dst])
v = np.concatenate([dst, src])
# Construct a DGLGraph
return dgl.DGLGraph((u, v))

###############################################################################
# Print out the number of nodes and edges in our newly constructed graph:
Expand All @@ -95,27 +89,28 @@ def build_karate_club_graph():
# Step 2: Assign features to nodes or edges
# --------------------------------------------
# Graph neural networks associate features with nodes and edges for training.
# For our classification example, we assign each node an input feature as a one-hot vector:
# node :math:`v_i`'s feature vector is :math:`[0,\ldots,1,\dots,0]`,
# where the :math:`i^{th}` position is one.
#
# For our classification example, since there is no input feature, we assign each node
# with a learnable embedding vector.

# In DGL, you can add features for all nodes at once, using a feature tensor that
# batches node features along the first dimension. The code below adds the one-hot
# feature for all nodes:
# batches node features along the first dimension. The code below adds the learnable
# embeddings for all nodes:

import torch
import torch.nn as nn
import torch.nn.functional as F

G.ndata['feat'] = torch.eye(34)

embed = nn.Embedding(34, 5) # 34 nodes with embedding dim equal to 5
G.ndata['feat'] = embed.weight

###############################################################################
# Print out the node features to verify:

# print out node 2's input feature
print(G.nodes[2].data['feat'])
print(G.ndata['feat'][2])

# print out node 10 and 11's input features
print(G.nodes[[10, 11]].data['feat'])
print(G.ndata['feat'][[10, 11]])

###############################################################################
# Step 3: Define a Graph Convolutional Network (GCN)
Expand All @@ -139,74 +134,41 @@ def build_karate_club_graph():
# :alt: mailbox
# :align: center
#
# Now, we show that the GCN layer can be easily implemented in DGL.

import torch.nn as nn
import torch.nn.functional as F

# Define the message and reduce function
# NOTE: We ignore the GCN's normalization constant c_ij for this tutorial.
def gcn_message(edges):
# The argument is a batch of edges.
# This computes a (batch of) message called 'msg' using the source node's feature 'h'.
return {'msg' : edges.src['h']}
# In DGL, we provide implementations of popular Graph Neural Network layers under
# the `dgl.<backend>.nn` subpackage. The :class:`~dgl.nn.pytorch.GraphConv` module
# implements one Graph Convolutional layer.

def gcn_reduce(nodes):
# The argument is a batch of nodes.
# This computes the new 'h' features by summing received 'msg' in each node's mailbox.
return {'h' : torch.sum(nodes.mailbox['msg'], dim=1)}

# Define the GCNLayer module
class GCNLayer(nn.Module):
def __init__(self, in_feats, out_feats):
super(GCNLayer, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)

def forward(self, g, inputs):
# g is the graph and the inputs is the input node features
# first set the node features
g.ndata['h'] = inputs
# trigger message passing on all edges
g.send(g.edges(), gcn_message)
# trigger aggregation at all nodes
g.recv(g.nodes(), gcn_reduce)
# get the result node features
h = g.ndata.pop('h')
# perform linear transformation
return self.linear(h)
from dgl.nn.pytorch import GraphConv

###############################################################################
# In general, the nodes send information computed via the *message functions*,
# and aggregate incoming information with the *reduce functions*.
#
# Define a deeper GCN model that contains two GCN layers:

# Define a 2-layer GCN model
class GCN(nn.Module):
def __init__(self, in_feats, hidden_size, num_classes):
super(GCN, self).__init__()
self.gcn1 = GCNLayer(in_feats, hidden_size)
self.gcn2 = GCNLayer(hidden_size, num_classes)
self.conv1 = GraphConv(in_feats, hidden_size)
self.conv2 = GraphConv(hidden_size, num_classes)

def forward(self, g, inputs):
h = self.gcn1(g, inputs)
h = self.conv1(g, inputs)
h = torch.relu(h)
h = self.gcn2(g, h)
h = self.conv2(g, h)
return h
# The first layer transforms input features of size of 34 to a hidden size of 5.

# The first layer transforms input features of size of 5 to a hidden size of 5.
# The second layer transforms the hidden layer and produces output features of
# size 2, corresponding to the two groups of the karate club.
net = GCN(34, 5, 2)
net = GCN(5, 5, 2)

###############################################################################
# Step 4: Data preparation and initialization
# -------------------------------------------
#
# We use one-hot vectors to initialize the node features. Since this is a
# We use learnable embeddings to initialize the node features. Since this is a
# semi-supervised setting, only the instructor (node 0) and the club president
# (node 33) are assigned labels. The implementation is available as follow.

inputs = torch.eye(34)
inputs = embed.weight
labeled_nodes = torch.tensor([0, 33]) # only the instructor and the president nodes are labeled
labels = torch.tensor([0, 1]) # their labels are different

Expand All @@ -216,10 +178,11 @@ def forward(self, g, inputs):
# The training loop is exactly the same as other PyTorch models.
# We (1) create an optimizer, (2) feed the inputs to the model,
# (3) calculate the loss and (4) use autograd to optimize the model.
import itertools

optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
optimizer = torch.optim.Adam(itertools.chain(net.parameters(), embed.parameters()), lr=0.01)
all_logits = []
for epoch in range(30):
for epoch in range(50):
logits = net(G, inputs)
# we save the logits for visualization later
all_logits.append(logits.detach())
Expand Down
85 changes: 55 additions & 30 deletions tutorials/basics/2_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,50 @@


###############################################################################
# The examples here show the same graph, except that :class:`DGLGraph` is always directional.
# There are many ways to construct a :class:`DGLGraph`. Below are the allowed
# data types ordered by our recommendataion.
#
# You can also create a graph by calling the DGL interface.
#
# In the next example, you build a star graph. :class:`DGLGraph` nodes are a consecutive range of
# integers between 0 and :func:`number_of_nodes() <DGLGraph.number_of_nodes>`
# and can grow by calling :func:`add_nodes <DGLGraph.add_nodes>`.
# * A pair of arrays ``(u, v)`` storing the source and destination nodes respectively.
# They can be numpy arrays or tensor objects from the backend framework.
# * ``scipy`` sparse matrix representing the adjacency matrix of the graph to be
# constructed.
# * ``networkx`` graph object.
# * A list of edges in the form of integer pairs.
#
# The examples below construct the same star graph via different methods.
#
# :class:`DGLGraph` nodes are a consecutive range of integers between 0 and
# :func:`number_of_nodes() <DGLGraph.number_of_nodes>`.
# :class:`DGLGraph` edges are in order of their additions. Note that
# edges are accessed in much the same way as nodes, with one extra feature: *edge broadcasting*.
# edges are accessed in much the same way as nodes, with one extra feature:
# *edge broadcasting*.

import dgl
import torch as th
import numpy as np
import scipy.sparse as spp

# Create a star graph from a pair of arrays (using ``numpy.array`` works too).
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
star1 = dgl.DGLGraph((u, v))

# Create the same graph in one go! Essentially, if one of the arrays is a scalar,
# the value is automatically broadcasted to match the length of the other array
# -- a feature called *edge broadcasting*.
start2 = dgl.DGLGraph((0, v))

# Create the same graph from a scipy sparse matrix (using ``scipy.sparse.csr_matrix`` works too).
adj = spp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
star3 = dgl.DGLGraph(adj)

# Create the same graph from a list of integer pairs.
elist = [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5)]
star4 = dgl.DGLGraph(elist)

###############################################################################
# You can also create a graph by progressively adding more nodes and edges.
# Although it is not as efficient as the above constructors, it is suitable
# for applications where the graph cannot be constructed in one shot.

g = dgl.DGLGraph()
g.add_nodes(10)
Expand All @@ -63,12 +95,10 @@
src = th.tensor(list(range(1, 10)));
g.add_edges(src, 0)

import networkx as nx
import matplotlib.pyplot as plt
# Visualize the graph.
nx.draw(g.to_networkx(), with_labels=True)
plt.show()


###############################################################################
# Assigning a feature
# -------------------
Expand All @@ -89,19 +119,14 @@
x = th.randn(10, 3)
g.ndata['x'] = x


###############################################################################
# :func:`ndata <DGLGraph.ndata>` is a syntax sugar to access the state of all nodes.
# States are stored
# in a container ``data`` that hosts a user-defined dictionary.

print(g.ndata['x'] == g.nodes[:].data['x'])

# Access node set with integer, list, or integer tensor
g.nodes[0].data['x'] = th.zeros(1, 3)
g.nodes[[0, 1, 2]].data['x'] = th.zeros(3, 3)
g.nodes[th.tensor([0, 1, 2])].data['x'] = th.zeros(3, 3)
# :func:`ndata <DGLGraph.ndata>` is a syntax sugar to access the feature
# data of all nodes. To get the features of some particular nodes, slice out
# the corresponding rows.

g.ndata['x'][0] = th.zeros(1, 3)
g.ndata['x'][[0, 1, 2]] = th.zeros(3, 3)
g.ndata['x'][th.tensor([0, 1, 2])] = th.randn((3, 3))

###############################################################################
# Assigning edge features is similar to that of node features,
Expand All @@ -110,14 +135,15 @@
g.edata['w'] = th.randn(9, 2)

# Access edge set with IDs in integer, list, or integer tensor
g.edges[1].data['w'] = th.randn(1, 2)
g.edges[[0, 1, 2]].data['w'] = th.zeros(3, 2)
g.edges[th.tensor([0, 1, 2])].data['w'] = th.zeros(3, 2)

# You can also access the edges by giving endpoints
g.edges[1, 0].data['w'] = th.ones(1, 2) # edge 1 -> 0
g.edges[[1, 2, 3], [0, 0, 0]].data['w'] = th.ones(3, 2) # edges [1, 2, 3] -> 0
g.edata['w'][1] = th.randn(1, 2)
g.edata['w'][[0, 1, 2]] = th.zeros(3, 2)
g.edata['w'][th.tensor([0, 1, 2])] = th.zeros(3, 2)

# You can get the edge ids by giving endpoints, which are useful for accessing the features.
g.edata['w'][g.edge_id(1, 0)] = th.ones(1, 2) # edge 1 -> 0
g.edata['w'][g.edge_ids([1, 2, 3], [0, 0, 0])] = th.ones(3, 2) # edges [1, 2, 3] -> 0
# Use edge broadcasting whenever applicable.
g.edata['w'][g.edge_ids([1, 2, 3], 0)] = th.ones(3, 2) # edges [1, 2, 3] -> 0

###############################################################################
# After assignments, each node or edge field will be associated with a scheme
Expand Down Expand Up @@ -170,7 +196,6 @@
# * Updating a feature of different schemes raises the risk of error on individual nodes (or
# node subset).


###############################################################################
# Next steps
# ----------
Expand Down
Loading

0 comments on commit 2cdc4d3

Please sign in to comment.