[Doc] Evaluation Tutorial for Link Prediction (dmlc#3463)
* link prediction tutorial

* add performance tip

* Update L2_large_link_prediction.py
BarclayII authored Nov 5, 2021
1 parent b717c8b commit c5ae54b
Showing 1 changed file with 106 additions and 7 deletions.
113 changes: 106 additions & 7 deletions tutorials/large/L2_large_link_prediction.py
@@ -112,7 +112,7 @@

sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.EdgeDataLoader(
# The following arguments are specific to NodeDataLoader.
# The following arguments are specific to EdgeDataLoader.
graph, # The graph
torch.arange(graph.number_of_edges()), # The edges to iterate over
sampler, # The neighbor sampler
@@ -218,15 +218,16 @@ def forward(self, g, h):


######################################################################
# Evaluating Performance (Optional)
# ---------------------------------
# Evaluating Performance with Unsupervised Learning (Optional)
# ------------------------------------------------------------
#
# There are various ways to evaluate the performance of link prediction.
# This tutorial follows the practice of `GraphSAGE
# paper <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__,
# where it treats the node embeddings learned by link prediction via
# training and evaluating a linear classifier on top of the learned node
# embeddings.
# paper <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__.
# Basically, it first trains a GNN via link prediction and obtains an
# embedding for each node. Then it trains a downstream classifier on top
# of these embeddings and computes its accuracy as an assessment of the
# embedding quality.
#
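
######################################################################
# The following is a rough, self-contained sketch of that recipe, not
# the tutorial's own code (which is elided in this diff). ``node_emb``
# and ``node_labels`` are hypothetical placeholders for the embeddings
# produced by the trained GNN and the ground-truth node labels; a
# scikit-learn logistic regression stands in for the linear classifier.
#

import numpy as np
import sklearn.linear_model
import sklearn.metrics

# Hypothetical stand-ins for the GNN's node embeddings and node labels.
rng = np.random.default_rng(0)
node_emb = rng.normal(size=(1000, 64))        # N x d node embeddings
node_labels = rng.integers(0, 10, size=1000)  # one class label per node

# Split the nodes into training and test sets for the downstream task.
train_idx, test_idx = np.arange(800), np.arange(800, 1000)

# Train a linear classifier on the embeddings and report its accuracy.
clf = sklearn.linear_model.LogisticRegression(max_iter=1000)
clf.fit(node_emb[train_idx], node_labels[train_idx])
acc = sklearn.metrics.accuracy_score(
    node_labels[test_idx], clf.predict(node_emb[test_idx]))
print('Downstream classification accuracy:', acc)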


@@ -356,6 +357,104 @@ def closure():
break


######################################################################
# Evaluating Performance with Link Prediction (Optional)
# ------------------------------------------------------
#
# In practice, it is more common to evaluate the link prediction
# model to see whether it can predict new edges. There are different
# evaluation metrics such as
# `AUC <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`__
# or `various metrics from information retrieval <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)>`__.
# Ultimately, they all require the model to predict one scalar score for
# each node pair in a given set of node pairs.
#
# ``dgl.dataloading.EdgeDataLoader`` allows you to iterate over
# the edges of a new graph with the same nodes, while performing
# neighbor sampling on the original graph with the ``g_sampling`` argument.
# This functionality enables convenient evaluation of a link prediction
# model.
#
# Assume that you have the following test set, where ``test_pos_src``
# and ``test_pos_dst`` are ground-truth node pairs with edges in between
# (or *positive* pairs), and ``test_neg_src`` and ``test_neg_dst`` are
# ground-truth node pairs without edges in between (or *negative* pairs).
#

# Positive pairs
test_pos_src, test_pos_dst = graph.edges()
# Negative pairs
test_neg_src = test_pos_src
test_neg_dst = torch.randint(0, graph.num_nodes(), (graph.num_edges(),))


######################################################################
# First you need to construct a graph for ``dgl.dataloading.EdgeDataLoader``
# to iterate over, i.e. one with the test node pairs as edges.
# You also need to label the edges: 1 if positive and 0 if negative.
#

test_src = torch.cat([test_pos_src, test_neg_src])
test_dst = torch.cat([test_pos_dst, test_neg_dst])
test_graph = dgl.graph((test_src, test_dst), num_nodes=graph.num_nodes())
test_graph.edata['label'] = torch.cat(
    [torch.ones_like(test_pos_src), torch.zeros_like(test_neg_src)])


######################################################################
# Then you can create a new ``EdgeDataLoader`` instance that
# iterates over the new ``test_graph`` but uses the original ``graph``
# for neighbor sampling.
#
# Note that you do not need negative sampling in this dataloader: the
# negative pairs are already in the new test graph.
#
test_dataloader = dgl.dataloading.EdgeDataLoader(
    # The following arguments are specific to EdgeDataLoader.
    test_graph,                                   # The graph to iterate edges over
    torch.arange(test_graph.number_of_edges()),   # The edges to iterate over
    sampler,                                      # The neighbor sampler
    device=device,                                # Put the MFGs on CPU or GPU
    g_sampling=graph,                             # The graph to sample neighbors from
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=1024,    # Batch size
    shuffle=True,       # Whether to shuffle the edges for every epoch
    drop_last=False,    # Whether to drop the last incomplete batch
    num_workers=0       # Number of sampler processes
)


######################################################################
# The rest is similar to training except that you no longer compute
# the gradients, and you collect all the scores and ground truth
# labels for final metric calculation.
#
# .. note::
#
#    If the graph does not change, you can also precompute all the
#    node representations beforehand with the ``inference`` function.
#    You can then feed the precomputed results directly into the
#    predictor without passing the MFGs into the model.
#
test_preds = []
test_labels = []

with tqdm.tqdm(test_dataloader) as tq, torch.no_grad():
    for step, (input_nodes, pair_graph, mfgs) in enumerate(tq):
        # feature copy from CPU to GPU takes place here
        inputs = mfgs[0].srcdata['feat']

        outputs = model(mfgs, inputs)
        test_preds.append(predictor(pair_graph, outputs))
        test_labels.append(pair_graph.edata['label'])

test_preds = torch.cat(test_preds).cpu().numpy()
test_labels = torch.cat(test_labels).cpu().numpy()
auc = sklearn.metrics.roc_auc_score(test_labels, test_preds)
print('Link Prediction AUC:', auc)
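
######################################################################
# As mentioned earlier, metrics from information retrieval can be
# computed from the same scores and labels. The snippet below is an
# illustrative addition (not part of the original tutorial) that reports
# precision@k, the fraction of the k highest-scored node pairs that are
# true edges; ``k`` here is an arbitrary cutoff chosen for illustration.
#

import numpy as np

k = 100
topk = np.argsort(-test_preds)[:k]          # indices of the k highest scores
precision_at_k = test_labels[topk].mean()   # fraction of true edges among them
print('Link Prediction precision@%d:' % k, precision_at_k)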

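######################################################################
# The note above mentions precomputing all node representations. Since
# the ``inference`` function itself is not shown in this diff, the
# following sketch precomputes the representations with a
# ``NodeDataLoader`` and a full-neighbor sampler instead, which yields
# the same result for this two-layer model (a layer-wise ``inference``
# function is more memory-friendly on very large graphs). The predictor
# is called as ``predictor(graph, node_features)``, just as in the loop
# above.
#

full_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
node_dataloader = dgl.dataloading.NodeDataLoader(
    graph, torch.arange(graph.num_nodes()), full_sampler,
    device=device, batch_size=1024, shuffle=False, drop_last=False,
    num_workers=0)

node_reprs = []
with torch.no_grad():
    # Compute the final-layer representation of every node once.
    for input_nodes, output_nodes, mfgs in node_dataloader:
        inputs = mfgs[0].srcdata['feat']
        node_reprs.append(model(mfgs, inputs).cpu())
    node_reprs = torch.cat(node_reprs)

    # Score every test edge directly from the cached representations.
    tg = test_graph.to(device)
    full_preds = predictor(tg, node_reprs.to(device)).cpu().numpy()
    full_labels = tg.edata['label'].cpu().numpy()

auc = sklearn.metrics.roc_auc_score(full_labels, full_preds)
print('Link Prediction AUC (precomputed):', auc)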

######################################################################
# Conclusion
# ----------
