[Sampler] fix the API of neighbor sampler. (dmlc#407)
* don't return aux_info.

* fix sampler test.

* fix sse.

* fix.

* add comment.
zheng-da authored Feb 28, 2019
1 parent 7e30382 commit bea07b4
Showing 8 changed files with 128 additions and 156 deletions.
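In short: NeighborSampler now yields bare NodeFlow objects instead of (nodeflow, aux_info) pairs, and per-batch information such as the seed nodes is read off the NodeFlow itself. A minimal sketch of the new iteration pattern, assuming the DGL contrib API at this commit (the toy ring graph and the sampler parameters are illustrative, not taken from the diff):

import dgl
from dgl.contrib.sampling import NeighborSampler

# A toy parent graph: a directed ring of 100 nodes.
g = dgl.DGLGraph()
g.add_nodes(100)
g.add_edges(list(range(100)), [(i + 1) % 100 for i in range(100)])

# Old API: for nf, aux in NeighborSampler(...): seeds = aux['seeds']
# New API: the sampler yields NodeFlows; seeds come from the flow itself.
for nf in NeighborSampler(g, 10, 4, neighbor_type='in', num_hops=2, shuffle=True):
    seeds = nf.layer_parent_nid(-1)  # this batch's seed nodes, as parent-graph IDs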
22 changes: 11 additions & 11 deletions examples/mxnet/sampling/gcn_cv_sc.py
@@ -228,12 +228,12 @@ def main(args):
     # initialize graph
     dur = []
     for epoch in range(args.n_epochs):
-        for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
-                                                            num_neighbors,
-                                                            neighbor_type='in',
-                                                            shuffle=True,
-                                                            num_hops=n_layers,
-                                                            seed_nodes=train_nid):
+        for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
+                                                       num_neighbors,
+                                                       neighbor_type='in',
+                                                       shuffle=True,
+                                                       num_hops=n_layers,
+                                                       seed_nodes=train_nid):
             for i in range(n_layers):
                 agg_history_str = 'agg_h_{}'.format(i)
                 g.pull(nf.layer_parent_nid(i+1), fn.copy_src(src='h_{}'.format(i), out='m'),
@@ -270,11 +270,11 @@ def main(args):

         num_acc = 0.

-        for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
-                                                            g.number_of_nodes(),
-                                                            neighbor_type='in',
-                                                            num_hops=n_layers,
-                                                            seed_nodes=test_nid):
+        for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
+                                                       g.number_of_nodes(),
+                                                       neighbor_type='in',
+                                                       num_hops=n_layers,
+                                                       seed_nodes=test_nid):
             node_embed_names = [['preprocess']]
             for i in range(n_layers):
                 node_embed_names.append(['norm'])
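Both loops above use NodeFlow.layer_parent_nid to map a layer of the sampled flow back to parent-graph node IDs. In the training loop this drives the control-variate bookkeeping: for each layer i, the parent graph recomputes the aggregation history agg_h_i only for the nodes that appear as layer i+1 of the NodeFlow. A condensed restatement of that pattern; the fn.sum reduce step is an assumption, since the diff cuts off after the message function:

import dgl.function as fn

for nf in sampler:
    for i in range(n_layers):
        agg_history_str = 'agg_h_{}'.format(i)
        # Refresh the aggregation history only for the parent nodes that
        # appear in layer i+1 of this NodeFlow.
        g.pull(nf.layer_parent_nid(i + 1),
               fn.copy_src(src='h_{}'.format(i), out='m'),
               fn.sum(msg='m', out=agg_history_str))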
22 changes: 11 additions & 11 deletions examples/mxnet/sampling/gcn_ns_sc.py
@@ -189,12 +189,12 @@ def main(args):
     # initialize graph
     dur = []
     for epoch in range(args.n_epochs):
-        for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
-                                                            args.num_neighbors,
-                                                            neighbor_type='in',
-                                                            shuffle=True,
-                                                            num_hops=args.n_layers+1,
-                                                            seed_nodes=train_nid):
+        for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
+                                                       args.num_neighbors,
+                                                       neighbor_type='in',
+                                                       shuffle=True,
+                                                       num_hops=args.n_layers+1,
+                                                       seed_nodes=train_nid):
             nf.copy_from_parent()
             # forward
             with mx.autograd.record():
@@ -215,11 +215,11 @@ def main(args):

     num_acc = 0.

-    for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
-                                                        g.number_of_nodes(),
-                                                        neighbor_type='in',
-                                                        num_hops=args.n_layers+1,
-                                                        seed_nodes=test_nid):
+    for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
+                                                   g.number_of_nodes(),
+                                                   neighbor_type='in',
+                                                   num_hops=args.n_layers+1,
+                                                   seed_nodes=test_nid):
         nf.copy_from_parent()
         pred = infer_model(nf)
         batch_nids = nf.layer_parent_nid(-1).astype('int64').as_in_context(ctx)
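The inference loop follows the standard NodeFlow pattern: copy features in from the parent graph, run the model on the flow, and use the last layer's parent IDs to map predictions back to labels. A condensed sketch; infer_model, labels, test_nid, and ctx are assumed to be defined as in the example, and the accuracy line is a guess at the code the diff truncates:

num_acc = 0.
for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
                                               g.number_of_nodes(),
                                               neighbor_type='in',
                                               num_hops=args.n_layers+1,
                                               seed_nodes=test_nid):
    nf.copy_from_parent()          # pull node features into the flow
    pred = infer_model(nf)         # forward pass on the NodeFlow
    batch_nids = nf.layer_parent_nid(-1).astype('int64').as_in_context(ctx)
    num_acc += (pred.argmax(axis=1) == labels[batch_nids]).sum().asscalar()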
26 changes: 13 additions & 13 deletions examples/mxnet/sampling/graphsage_cv.py
@@ -272,13 +272,13 @@ def main(args):
     # initialize graph
     dur = []
     for epoch in range(args.n_epochs):
-        for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
-                                                            num_neighbors,
-                                                            neighbor_type='in',
-                                                            shuffle=True,
-                                                            num_hops=n_layers,
-                                                            add_self_loop=True,
-                                                            seed_nodes=train_nid):
+        for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
+                                                       num_neighbors,
+                                                       neighbor_type='in',
+                                                       shuffle=True,
+                                                       num_hops=n_layers,
+                                                       add_self_loop=True,
+                                                       seed_nodes=train_nid):
             for i in range(n_layers):
                 agg_history_str = 'agg_h_{}'.format(i)
                 g.pull(nf.layer_parent_nid(i+1), fn.copy_src(src='h_{}'.format(i), out='m'),
@@ -314,12 +314,12 @@ def main(args):

         num_acc = 0.

-        for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
-                                                            g.number_of_nodes(),
-                                                            neighbor_type='in',
-                                                            num_hops=n_layers,
-                                                            seed_nodes=test_nid,
-                                                            add_self_loop=True):
+        for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
+                                                       g.number_of_nodes(),
+                                                       neighbor_type='in',
+                                                       num_hops=n_layers,
+                                                       seed_nodes=test_nid,
+                                                       add_self_loop=True):
             node_embed_names = [['preprocess', 'features']]
             for i in range(n_layers):
                 node_embed_names.append(['norm', 'subg_norm'])
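The GraphSAGE variant differs from the GCN one only in passing add_self_loop=True, so each node's own representation is included among its sampled in-neighbors. A minimal sketch of building that sampler under the new API (parameter names follow the example):

sampler = dgl.contrib.sampling.NeighborSampler(
    g, args.batch_size, num_neighbors,
    neighbor_type='in',
    shuffle=True,
    num_hops=n_layers,
    add_self_loop=True,  # also sample each node's self-edge
    seed_nodes=train_nid)

for nf in sampler:  # yields NodeFlows only, per this commit
    nf.copy_from_parent()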
144 changes: 65 additions & 79 deletions examples/mxnet/sse/sse_batch.py
@@ -16,7 +16,6 @@
 from dgl.data import register_data_args, load_data

 def gcn_msg(edges):
-    # TODO should we use concat?
     return {'m': mx.nd.concat(edges.src['in'], edges.src['h'], dim=1)}

 def gcn_reduce(nodes):
@@ -26,7 +25,6 @@ class NodeUpdate(gluon.Block):
     def __init__(self, out_feats, activation=None, alpha=0.1, **kwargs):
         super(NodeUpdate, self).__init__(**kwargs)
         self.linear1 = gluon.nn.Dense(out_feats, activation=activation)
-        # TODO what is the dimension here?
         self.linear2 = gluon.nn.Dense(out_feats)
         self.alpha = alpha

@@ -43,56 +41,15 @@ def __init__(self, update):
     def forward(self, node):
         return {'h1': self.update(node.data['in'], node.data['h'], node.data['accum'])}

-class SSEUpdateHidden(gluon.Block):
-    def __init__(self,
-                 n_hidden,
-                 dropout,
-                 activation,
-                 **kwargs):
-        super(SSEUpdateHidden, self).__init__(**kwargs)
-        with self.name_scope():
-            self.layer = NodeUpdate(n_hidden, activation)
-        self.dropout = dropout
-        self.n_hidden = n_hidden
-
-    def forward(self, g, vertices):
-        if vertices is None:
-            deg = mx.nd.expand_dims(g.in_degrees(), 1).astype(np.float32)
-            feat = g.get_n_repr()['in']
-            cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
-            accum = mx.nd.dot(g.adjacency_matrix(), cat) / deg
-            batch_size = 100000
-            num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
-            ret = mx.nd.empty(shape=(feat.shape[0], self.n_hidden), ctx=feat.context)
-            for i in range(num_batches):
-                vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
-                ret[vs] = self.layer(mx.nd.take(feat, vs),
-                                     mx.nd.take(g.ndata['h'], vs),
-                                     mx.nd.take(accum, vs))
-            return ret
-        else:
-            deg = mx.nd.expand_dims(g.in_degrees(vertices), 1).astype(np.float32)
-            # We don't need dropout for inference.
-            if self.dropout:
-                # TODO here we apply dropout on all vertex representation.
-                g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
-            feat = g.get_n_repr()['in']
-            cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
-            slices = mx.nd.take(g.adjacency_matrix(), vertices).as_in_context(cat.context)
-            accum = mx.nd.dot(slices, cat) / deg.as_in_context(cat.context)
-            vertices = vertices.as_in_context(g.ndata['in'].context)
-            return self.layer(mx.nd.take(feat, vertices),
-                              mx.nd.take(g.ndata['h'], vertices), accum)
-
-class DGLSSEUpdateHidden(gluon.Block):
+class DGLSSEUpdateHiddenInfer(gluon.Block):
     def __init__(self,
                  n_hidden,
                  activation,
                  dropout,
                  use_spmv,
                  inference,
                  **kwargs):
-        super(DGLSSEUpdateHidden, self).__init__(**kwargs)
+        super(DGLSSEUpdateHiddenInfer, self).__init__(**kwargs)
         with self.name_scope():
             self.layer = DGLNodeUpdate(NodeUpdate(n_hidden, activation))
         self.dropout = dropout
@@ -125,7 +82,6 @@ def forward(self, g, vertices):
         else:
             # We don't need dropout for inference.
             if self.dropout:
-                # TODO here we apply dropout on all vertex representation.
                 g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
             g.update_all(msg_func, reduce_func, None)
             ctx = g.ndata['accum'].context
@@ -137,6 +93,47 @@ def forward(self, g, vertices):
             g.ndata.pop('accum')
             return mx.nd.take(g.ndata['h1'], vertices.as_in_context(ctx))

+class DGLSSEUpdateHiddenTrain(gluon.Block):
+    def __init__(self,
+                 n_hidden,
+                 activation,
+                 dropout,
+                 use_spmv,
+                 inference,
+                 **kwargs):
+        super(DGLSSEUpdateHiddenTrain, self).__init__(**kwargs)
+        with self.name_scope():
+            self.update = DGLNodeUpdate(NodeUpdate(n_hidden, activation))
+        self.dropout = dropout
+        self.use_spmv = use_spmv
+        self.inference = inference
+
+    def forward(self, subg, vertices):
+        assert vertices is not None
+        if self.use_spmv:
+            feat = subg.layers[0].data['in']
+            subg.layers[0].data['cat'] = mx.nd.concat(feat, subg.layers[0].data['h'],
+                                                      dim=1)
+
+            msg_func = fn.copy_src(src='cat', out='m')
+            reduce_func = fn.sum(msg='m', out='accum')
+        else:
+            msg_func = gcn_msg
+            reduce_func = gcn_reduce
+        deg = mx.nd.expand_dims(subg.layer_in_degree(1), 1).astype(np.float32)
+        # We don't need dropout for inference.
+        if self.dropout:
+            subg.layers[0].data['h'] = mx.nd.Dropout(subg.layers[0].data['h'], p=self.dropout)
+        subg.block_compute(0, msg_func, reduce_func, None)
+        ctx = subg.layers[1].data['accum'].context
+        if self.use_spmv:
+            subg.layers[0].data.pop('cat')
+        deg = deg.as_in_context(ctx)
+        subg.layers[1].data['accum'] = subg.layers[1].data['accum'] / deg
+        subg.apply_layer(1, self.update, inplace=self.inference)
+        subg.layers[1].data.pop('accum')
+        return subg.layers[1].data['h1']
+
 class SSEPredict(gluon.Block):
     def __init__(self, update_hidden, out_feats, dropout, **kwargs):
         super(SSEPredict, self).__init__(**kwargs)
@@ -153,17 +150,10 @@ def forward(self, g, vertices):
         return self.linear2(self.linear1(hidden))

 def copy_to_gpu(subg, ctx):
-    frame = subg.ndata
-    for key in frame:
-        subg.ndata[key] = frame[key].as_in_context(ctx)
-
-class CachedSubgraph(object):
-    def __init__(self, subg, seeds):
-        # We can't cache the input subgraph because it contains node frames
-        # and data frames.
-        self.subg = dgl.DGLSubGraph(subg._parent, subg._parent_nid, subg._parent_eid,
-                                    subg._graph)
-        self.seeds = seeds
+    for i in range(subg.num_layers):
+        frame = subg.layers[i].data
+        for key in frame:
+            subg.layers[i].data[key] = frame[key].as_in_context(ctx)

 class CachedSubgraphLoader(object):
     def __init__(self, loader, shuffle):
@@ -182,14 +172,17 @@ def __iter__(self):

     def __next__(self):
         if len(self._subgraphs) > 0:
-            s = self._subgraphs.pop(0)
-            subg, seeds = s.subg, s.seeds
+            subg = self._subgraphs.pop(0)
         elif self._gen_subgraph:
-            subg, seeds = self._loader.__next__()
+            subg = self._loader.__next__()
         else:
             raise StopIteration
-        self._cached.append(CachedSubgraph(subg, seeds))
-        return subg, seeds
+
+        # We can't cache the input subgraph because it contains node frames
+        # and data frames.
+        subg = dgl.NodeFlow(subg._parent, subg._graph)
+        self._cached.append(subg)
+        return subg

 def main(args, data):
     if isinstance(data.features, mx.nd.NDArray):
@@ -224,17 +217,12 @@ def main(args, data):
     g.ndata['h'] = mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden),
                                        ctx=mx.cpu(0))

-    update_hidden_infer = DGLSSEUpdateHidden(args.n_hidden, 'relu',
-                                             args.update_dropout, args.use_spmv,
-                                             inference=True, prefix='sse')
-    update_hidden_train = DGLSSEUpdateHidden(args.n_hidden, 'relu',
-                                             args.update_dropout, args.use_spmv,
-                                             inference=False, prefix='sse')
-    if not args.dgl:
-        update_hidden_infer = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
-                                              prefix='sse')
-        update_hidden_train = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
-                                              prefix='sse')
+    update_hidden_infer = DGLSSEUpdateHiddenInfer(args.n_hidden, 'relu',
+                                                  args.update_dropout, args.use_spmv,
+                                                  inference=True, prefix='sse')
+    update_hidden_train = DGLSSEUpdateHiddenTrain(args.n_hidden, 'relu',
+                                                  args.update_dropout, args.use_spmv,
+                                                  inference=False, prefix='sse')

     model_train = SSEPredict(update_hidden_train, args.n_hidden, args.predict_dropout, prefix='app')
     model_infer = SSEPredict(update_hidden_infer, args.n_hidden, args.predict_dropout, prefix='app')
@@ -277,9 +265,9 @@ def main(args, data):
             i = 0
             num_batches = len(train_vs) / args.batch_size
             start1 = time.time()
-            for subg, aux_infos in sampler:
-                seeds = aux_infos['seeds']
-                subg_seeds = subg.layer_nid(0)
+            for subg in sampler:
+                seeds = subg.layer_parent_nid(-1)
+                subg_seeds = subg.layer_nid(-1)
                 subg.copy_from_parent()

                 losses = []
@@ -316,8 +304,7 @@ def main(args, data):
                 sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
                                                                neighbor_type='in',
                                                                num_workers=args.num_parallel_subgraphs,
-                                                               seed_nodes=train_vs, shuffle=True,
-                                                               return_seed_id=True)
+                                                               seed_nodes=train_vs, shuffle=True)

     # test set accuracy
     logits = model_infer(g, eval_vs)
@@ -394,7 +381,6 @@ def __init__(self, csr, num_feats):
                         help="the percentage of data used for training")
     parser.add_argument("--use-spmv", action="store_true",
                         help="use SpMV for faster speed.")
-    parser.add_argument("--dgl", action="store_true")
     parser.add_argument("--cache-subgraph", default=False, action="store_false")
     parser.add_argument("--num-parallel-subgraphs", type=int, default=1,
                         help="the number of subgraphs to construct in parallel.")
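Two pieces of this file show the heart of the API change. First, the training loop now recovers the seed nodes from the NodeFlow (seeds = subg.layer_parent_nid(-1)) instead of from the dropped aux_infos dict. Second, DGLSSEUpdateHiddenTrain runs one message-passing step directly on the sampled flow: block_compute(0, ...) aggregates from layer 0 into layer 1, and the result is normalized by layer 1's in-degrees. A stripped-down sketch of that path, assuming use_spmv and the feature layout from the example:

import dgl.function as fn
import mxnet as mx
import numpy as np

for subg in sampler:
    subg.copy_from_parent()
    seeds = subg.layer_parent_nid(-1)  # replaces aux_infos['seeds']
    # Concatenate input features with the current hidden state on layer 0.
    subg.layers[0].data['cat'] = mx.nd.concat(subg.layers[0].data['in'],
                                              subg.layers[0].data['h'], dim=1)
    # One aggregation step over block 0: layer 0 -> layer 1.
    subg.block_compute(0, fn.copy_src(src='cat', out='m'),
                       fn.sum(msg='m', out='accum'))
    # Normalize by the in-degree of each layer-1 node.
    ctx = subg.layers[1].data['accum'].context
    deg = mx.nd.expand_dims(subg.layer_in_degree(1), 1).astype(np.float32)
    subg.layers[1].data['accum'] = subg.layers[1].data['accum'] / deg.as_in_context(ctx)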