forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
toy.py
69 lines (53 loc) · 2.2 KB
/
toy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
###############################################################################
# A toy example
# -------------
#
# Let’s begin with the simplest graph possible with two nodes, and set
# the node representations:
import torch as th
import dgl
# Start from an empty graph, then add two nodes and a single edge given as
# (1, 0) -- the accompanying text describes copying from node#1 to node#0.
g = dgl.DGLGraph()
g.add_nodes(2)
g.add_edge(1, 0)
# A 2x2 float tensor: the first row is all zeros, the second row is [1, 2].
x = th.tensor([[0.0, 0.0], [1.0, 2.0]])
# Store the tensor as the node feature 'x' through the all-nodes slice.
g.nodes[:].data['x'] = x
###############################################################################
# Syntactic sugar for accessing the feature data of all nodes
print(g.ndata['x'])
###############################################################################
# What we want to do is simply copy the representation from node#1 to
# node#0, but through a message passing interface. We do this much as we
# would over a pair of sockets, with a send and a recv interface. The
# two user-defined functions (UDFs) specify the actions: deposit the
# value into an internal key-value store under the key msg, and retrieve
# it. Note that there may be multiple incoming edges to a node, and the
# receiving end aggregates them.
def send_source(edges):  # type is dgl.EdgeBatch
    """Message function: pass along the source-node feature ``'x'``.

    Parameters
    ----------
    edges : dgl.EdgeBatch
        Batch of edges; only ``edges.src['x']`` is read.

    Returns
    -------
    dict
        ``{'msg': edges.src['x']}`` -- the source feature, unchanged, stored
        under the key ``'msg'``.
    """
    return {'msg': edges.src['x']}
def simple_reduce(nodes):  # type is dgl.NodeBatch
    """Reduce function: sum the mailbox messages into a new ``'x'``.

    Parameters
    ----------
    nodes : dgl.NodeBatch
        Batch of nodes; only ``nodes.mailbox['msg']`` is read.

    Returns
    -------
    dict
        ``{'x': th.sum(msgs, dim=1)}`` where ``msgs`` is the ``'msg'``
        mailbox entry.
    """
    msgs = nodes.mailbox['msg']
    # dim=1 is presumably the per-node message axis (the tutorial text says
    # the receiving end aggregates incoming messages) -- TODO confirm.
    return {'x': th.sum(msgs, dim=1)}
# Send along the edge given as the pair (1, 0), using send_source as the
# message function.
g.send((1, 0), message_func=send_source)
# Receive at node 0, using simple_reduce as the reduce function.
g.recv(0, reduce_func=simple_reduce)
# Show the node feature data after the send/recv round.
print(g.ndata)
###############################################################################
# Sometimes the computation may involve representations on the edges.
# Let’s say we want to “amplify” the message:
# A one-element tensor holding 2.0, stored as the edge feature 'w'.
w = th.tensor([2.0])
g.edata['w'] = w
def send_source_with_edge_weight(edges):
    """Message function: source feature ``'x'`` scaled by edge feature ``'w'``.

    Parameters
    ----------
    edges : object
        Edge batch; ``edges.src['x']`` and ``edges.data['w']`` are read.

    Returns
    -------
    dict
        ``{'msg': edges.src['x'] * edges.data['w']}``.
    """
    return {'msg': edges.src['x'] * edges.data['w']}
# Send on the (1, 0) pair again, this time with the edge-weighted message
# function; node 0 still reduces with simple_reduce.
g.send((1, 0), message_func=send_source_with_edge_weight)
g.recv(0, reduce_func=simple_reduce)
# Show the node feature data after this round.
print(g.ndata)
###############################################################################
# Or we may need to involve the destination’s representation, and here
# is one version:
def simple_reduce_addup(nodes):
    """Reduce function: add the summed messages to the node's own ``'x'``.

    Parameters
    ----------
    nodes : object
        Node batch; ``nodes.mailbox['msg']`` and ``nodes.data['x']`` are read.

    Returns
    -------
    dict
        ``{'x': nodes.data['x'] + th.sum(msgs, dim=1)}`` where ``msgs`` is
        the ``'msg'`` mailbox entry.
    """
    msgs = nodes.mailbox['msg']
    # The existing 'x' is kept and the aggregated messages are added to it,
    # rather than overwriting it.
    return {'x': nodes.data['x'] + th.sum(msgs, dim=1)}
# Same edge-weighted send on the (1, 0) pair; node 0 now reduces with
# simple_reduce_addup.
g.send((1, 0), message_func=send_source_with_edge_weight)
g.recv(0, reduce_func=simple_reduce_addup)
# Show the node feature data after this round.
print(g.ndata)
# Remove the 'x' node feature and the 'w' edge feature from the graph.
del g.ndata['x']
del g.edata['w']