Skip to content

Commit

Permalink
Fix batched graph edge order bug and other fixes (dmlc#50)
Browse files Browse the repository at this point in the history
* fix dgl.batch edge ordering bug

* add graph batching test cases

* fix partial spmv ctx.

* add dataset generating for dgmg
  • Loading branch information
lingfanyu authored Aug 23, 2018
1 parent 6105e44 commit c42eac7
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 13 deletions.
9 changes: 8 additions & 1 deletion examples/pytorch/generative_graph/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.nn.functional as F
import numpy as np
import argparse
from util import DataLoader, elapsed
from util import DataLoader, elapsed, generate_dataset
import time

class MLP(nn.Module):
Expand Down Expand Up @@ -246,9 +246,16 @@ def masked_cross_entropy(x, label, mask=None):
help="number of hidden gcn layers")
parser.add_argument("--dataset", type=str, default='samples.p',
help="dataset pickle file")
parser.add_argument("--gen-dataset", type=str, default=None,
help="parameters to generate B-A graph datasets. Format: <#node>,<#edge>,<#sample>")
parser.add_argument("--batch-size", type=int, default=32,
help="batch size")
args = parser.parse_args()
print(args)

# generate dataset if needed
if args.gen_dataset is not None:
n_node, n_edge, n_sample = map(int, args.gen_dataset.split(','))
generate_dataset(n_node, n_edge, n_sample, args.dataset)

main(args)
13 changes: 7 additions & 6 deletions examples/pytorch/generative_graph/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,13 @@ def convert_graph_to_ordering(g):
ordering.append((m, n))
return ordering

# util.py hunk: the diff rendering lost its +/- markers, so the removed
# and added lines appear interleaved below.
# Removed by this commit (old version with hard-coded B-A parameters):
def generate_dataset():
n = 15
m = 2
n_samples = 1024
# Added by this commit: node count, attachment count, sample count and
# the output pickle filename are now caller-supplied parameters.
def generate_dataset(n, m, n_samples, fname):
samples = []
for _ in range(n_samples):
# Barabasi-Albert random graph: n nodes, m edges per new node.
g = nx.barabasi_albert_graph(n, m)
samples.append(convert_graph_to_ordering(g))

# removed: output filename was hard-coded
with open('samples.p', 'wb') as f:
# added: output filename comes from the fname parameter
with open(fname, 'wb') as f:
pickle.dump(samples, f)

class DataLoader(object):
Expand Down Expand Up @@ -153,4 +150,8 @@ def elapsed(msg, start, end):
print("{}: {} ms".format(msg, int((end-start)*1000)))

# Script entry point (diff rendering: the removed call and its
# replacement both appear below; +/- markers were lost).
if __name__ == '__main__':
# removed: old zero-argument call relying on hard-coded defaults
generate_dataset()
# added: same default values, now passed explicitly to the
# parameterized generate_dataset(n, m, n_samples, fname)
n = 15
m = 2
n_samples = 1024
fname ='samples.p'
generate_dataset(n, m, n_samples, fname)
2 changes: 1 addition & 1 deletion python/dgl/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, graph_list, node_attrs=None, edge_attrs=None, **attr):
self.add_nodes_from(range(self.node_offset[-1]))

# in-order add relabeled edges
self.new_edge_list = [np.array(g.edges) + offset
self.new_edge_list = [np.array(g.edge_list) + offset
for g, offset in zip(self.graph_list, self.node_offset[:-1])]
self.new_edges = np.concatenate(self.new_edge_list)
self.add_edges_from(self.new_edges)
Expand Down
3 changes: 2 additions & 1 deletion python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,11 +710,12 @@ def _batch_update_by_edge(
m = len(new2old)
# TODO(minjie): context
adjmat = F.sparse_tensor(idx, dat, [m, n])
ctx_adjmat = utils.CtxCachedObject(lambda ctx: F.to_context(adjmat, ctx))
# TODO(minjie): use lazy dict for reduced_msgs
reduced_msgs = {}
for key in self._node_frame.schemes:
col = self._node_frame[key]
reduced_msgs[key] = F.spmm(adjmat, col)
reduced_msgs[key] = F.spmm(ctx_adjmat.get(F.get_context(col)), col)
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__]
node_repr = self.get_n_repr(new2old)
Expand Down
31 changes: 27 additions & 4 deletions tests/pytorch/test_graph_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def tree1():
g.add_edge(1, 0)
g.add_edge(2, 0)
g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4]))
g.set_e_repr(torch.randn(4, 10))
return g

def tree2():
Expand All @@ -45,19 +46,24 @@ def tree2():
g.add_edge(4, 1)
g.add_edge(3, 1)
g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4]))
g.set_e_repr(torch.randn(4, 10))
return g

# Round-trips two trees through dgl.batch / dgl.unbatch and checks that
# features are restored on the originals. Diff rendering lost its +/-
# markers, so removed and added lines appear interleaved below.
def test_batch_unbatch():
t1 = tree1()
t2 = tree2()
# removed: old snapshot covered node features only
f1 = t1.get_n_repr()
f2 = t2.get_n_repr()
# added: snapshot node AND edge features before batching
n1 = t1.get_n_repr()
n2 = t2.get_n_repr()
e1 = t1.get_e_repr()
e2 = t2.get_e_repr()

bg = dgl.batch([t1, t2])
dgl.unbatch(bg)

# removed: old node-only assertions
assert(f1.equal(t1.get_n_repr()))
assert(f2.equal(t2.get_n_repr()))
# added: both node and edge features must survive the round trip
assert(n1.equal(t1.get_n_repr()))
assert(n2.equal(t2.get_n_repr()))
assert(e1.equal(t1.get_e_repr()))
assert(e2.equal(t2.get_e_repr()))


def test_batch_sendrecv():
Expand Down Expand Up @@ -120,8 +126,25 @@ def test_batch_propagate():
assert t1.get_n_repr()[0] == 9
assert t2.get_n_repr()[1] == 5

def test_batched_edge_ordering():
    """Check that dgl.batch preserves per-graph edge feature ordering.

    Builds two small graphs whose edges are added in a deliberately
    non-sorted order, batches them, and verifies that an edge feature
    looked up through the batched graph matches the same edge's feature
    in the source graph. Node ids of the second graph are offset in the
    batch by the node count of the first graph.
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes_from([0, 1, 2, 3, 4, 5])
    # Edges intentionally NOT in sorted order, to expose any reordering
    # performed by dgl.batch (the bug this test was added for).
    g1.add_edges_from([(4, 5), (4, 3), (2, 3), (2, 1), (0, 1)])
    # NOTE(review): kept from the original — accessing edge_list may
    # materialize/cache the lazy edge list before features are attached;
    # confirm whether it is load-bearing or leftover debugging.
    g1.edge_list
    e1 = torch.randn(5, 10)
    g1.set_e_repr(e1)

    g2 = dgl.DGLGraph()
    g2.add_nodes_from([0, 1, 2, 3, 4, 5])
    g2.add_edges_from([(0, 1), (1, 2), (2, 3), (5, 4), (4, 3), (5, 0)])
    e2 = torch.randn(6, 10)
    g2.set_e_repr(e2)

    g = dgl.batch([g1, g2])
    # g1's nodes keep their ids in the batched graph (offset 0).
    r1 = g.get_e_repr()[g.get_edge_id(4, 5)]
    r2 = g1.get_e_repr()[g1.get_edge_id(4, 5)]
    assert torch.equal(r1, r2)
    # g2's nodes are offset by g1's 6 nodes: its edge (0, 1) maps to
    # (6, 7) in the batched graph. The original test never checked the
    # second graph, which is where an offset/ordering bug would show.
    r3 = g.get_e_repr()[g.get_edge_id(6, 7)]
    r4 = g2.get_e_repr()[g2.get_edge_id(0, 1)]
    assert torch.equal(r3, r4)

# Run the graph-batching test suite directly as a script (order matches
# the order the tests are defined in this file).
if __name__ == '__main__':
test_batch_unbatch()
test_batched_edge_ordering()
test_batch_sendrecv()
test_batch_propagate()

0 comments on commit c42eac7

Please sign in to comment.