quickly integrating with tree-lstm example
jermainewang committed Sep 25, 2018
1 parent 882e2a7 commit a8a4fcb
Showing 6 changed files with 39 additions and 194 deletions.
38 changes: 18 additions & 20 deletions examples/pytorch/tree_lstm/train.py
@@ -8,23 +8,17 @@
 
 import dgl
 import dgl.data as data
+import dgl.ndarray as nd
 
 from tree_lstm import TreeLSTM
 
-def _batch_to_cuda(batch):
-    return data.SSTBatch(graph=batch.graph,
-                         nid_with_word = batch.nid_with_word.cuda(),
-                         wordid = batch.wordid.cuda(),
-                         label = batch.label.cuda())
-
-import dgl.context as ctx
 def tensor_topo_traverse(g, cuda, args):
     n = g.number_of_nodes()
     if cuda:
-        adjmat = g.cached_graph.adjmat().get(ctx.gpu(args.gpu))
+        adjmat = g._graph.adjacency_matrix().get(nd.gpu(args.gpu))
         mask = th.ones((n, 1)).cuda()
     else:
-        adjmat = g.cached_graph.adjmat().get(ctx.cpu())
+        adjmat = g._graph.adjacency_matrix().get(nd.cpu())
         mask = th.ones((n, 1))
     degree = th.spmm(adjmat, mask)
     while th.sum(mask) != 0.:
@@ -39,10 +33,17 @@ def main(args):
     cuda = args.gpu >= 0
     if cuda:
        th.cuda.set_device(args.gpu)
+    def _batcher(trees):
+        bg = dgl.batch(trees)
+        if cuda:
+            reprs = bg.get_n_repr()
+            reprs = {key : reprs[key].cuda() for key in reprs}
+            bg.set_n_repr(reprs)
+        return bg
     trainset = data.SST()
     train_loader = DataLoader(dataset=trainset,
                               batch_size=args.batch_size,
-                              collate_fn=data.SST.batcher,
+                              collate_fn=_batcher,
                               shuffle=False,
                               num_workers=0)
     #testset = data.SST(mode='test')
@@ -69,18 +70,15 @@ def main(args):
     dur = []
     for epoch in range(args.epochs):
         t_epoch = time.time()
-        for step, batch in enumerate(train_loader):
-            g = batch.graph
-            if cuda:
-                batch = _batch_to_cuda(batch)
-
+        for step, graph in enumerate(train_loader):
             if step >= 3:
                 t0 = time.time()
+            label = graph.pop_n_repr('y')
             # traverse graph
-            giter = list(tensor_topo_traverse(g, False, args))
-            logits = model(batch, zero_initializer, iterator=giter, train=True)
+            giter = list(tensor_topo_traverse(graph, False, args))
+            logits = model(graph, zero_initializer, iterator=giter, train=True)
             logp = F.log_softmax(logits, 1)
-            loss = F.nll_loss(logp, batch.label)
+            loss = F.nll_loss(logp, label)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
@@ -89,11 +87,11 @@ def main(args):
 
             if step > 0 and step % args.log_every == 0:
                 pred = th.argmax(logits, 1)
-                acc = th.sum(th.eq(batch.label, pred))
+                acc = th.sum(th.eq(label, pred))
                 mean_dur = np.mean(dur)
                 print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                       "Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
-                          epoch, step, loss.item(), acc.item()/len(batch.label),
+                          epoch, step, loss.item(), acc.item() / args.batch_size,
                           mean_dur, args.batch_size / mean_dur))
         print("Epoch time(s):", time.time() - t_epoch)

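For context on the hunks above: tensor_topo_traverse generates topological frontiers with sparse matrix products rather than Python-level graph walks, so each step is a single spmm. A minimal standalone sketch of the same technique; the function name and toy graph are illustrative, not from this commit:

import torch as th

def frontier_traverse(adjmat, n):
    # adjmat[u, v] = 1 means u depends on v, i.e. v must be processed
    # before u; row sums therefore count pending dependencies.
    mask = th.ones((n, 1))                    # 1 = node not yet processed
    degree = th.spmm(adjmat, mask)            # pending dependencies per node
    while th.sum(mask) != 0.:
        v = (degree == 0.).float() * mask     # ready: unvisited, no pending deps
        mask = mask - v                       # mark this frontier as done
        yield th.nonzero(v.squeeze(1)).squeeze(1)
        degree = degree - th.spmm(adjmat, v)  # retire finished dependencies

# Chain 2 -> 1 -> 0: frontiers come out leaves-first.
idx = th.tensor([[2, 1], [1, 0]])
adj = th.sparse_coo_tensor(idx, th.ones(2), (3, 3))
print([f.tolist() for f in frontier_traverse(adj, 3)])  # [[0], [1], [2]]
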
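The _batcher closure works because DataLoader hands collate_fn the raw list of dataset items and passes its return value straight through; nothing requires the batch to be a tensor, which is what lets _batcher return a batched graph. A tiny illustration of that contract with plain ints, no DGL involved:

from torch.utils.data import DataLoader

items = list(range(10))

def collate(batch):
    # batch is the raw list of items; the return type is up to us.
    return {'sum': sum(batch), 'n': len(batch)}

loader = DataLoader(items, batch_size=4, collate_fn=collate)
print(list(loader))
# [{'sum': 6, 'n': 4}, {'sum': 22, 'n': 4}, {'sum': 17, 'n': 2}]
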
34 changes: 10 additions & 24 deletions examples/pytorch/tree_lstm/tree_lstm.py
@@ -10,23 +10,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-def topological_traverse(G):
-    indegree_map = {v: d for v, d in G.in_degree() if d > 0}
-    # These nodes have zero indegree and ready to be returned.
-    zero_indegree = [v for v, d in G.in_degree() if d == 0]
-    while True:
-        yield zero_indegree
-        next_zero_indegree = []
-        while zero_indegree:
-            node = zero_indegree.pop()
-            for _, child in G.edges(node):
-                indegree_map[child] -= 1
-                if indegree_map[child] == 0:
-                    next_zero_indegree.append(child)
-                    del indegree_map[child]
-        if len(next_zero_indegree) == 0:
-            break
-        zero_indegree = next_zero_indegree
+import dgl
 
 class ChildSumTreeLSTMCell(nn.Module):
     def __init__(self, x_size, h_size):
@@ -83,13 +67,13 @@ def __init__(self,
         else:
             raise RuntimeError('Unknown cell type:', cell_type)
 
-    def forward(self, batch, zero_initializer, h=None, c=None, iterator=None, train=True):
+    def forward(self, graph, zero_initializer, h=None, c=None, iterator=None, train=True):
         """Compute tree-lstm prediction given a batch.
         Parameters
         ----------
-        batch : dgl.data.SSTBatch
-            The data batch.
+        graph : dgl.DGLGraph
+            The batched trees.
         zero_initializer : callable
             Function to return zero value tensor.
         h : Tensor, optional
@@ -104,15 +88,17 @@ def forward(self, batch, zero_initializer, h=None, c=None, iterator=None, train=True):
         logits : Tensor
             The prediction of each node.
         """
-        g = batch.graph
+        g = graph
         n = g.number_of_nodes()
         g.register_message_func(self.cell.message_func, batchable=True)
         g.register_reduce_func(self.cell.reduce_func, batchable=True)
         g.register_apply_node_func(self.cell.apply_func, batchable=True)
         # feed embedding
-        embeds = self.embedding(batch.wordid)
-        x = zero_initializer((n, self.x_size))
-        x = x.index_copy(0, batch.nid_with_word, embeds)
+        wordid = g.pop_n_repr('x')
+        mask = (wordid != dgl.data.SST.PAD_WORD)
+        wordid = wordid * mask.long()
+        embeds = self.embedding(wordid)
+        x = embeds * th.unsqueeze(mask, 1).float()
         if h is None:
             h = zero_initializer((n, self.h_size))
             h_tild = zero_initializer((n, self.h_size))
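
The rewritten embedding step above drops index_copy in favor of embedding every node and zeroing the padded ones: word ids are first clamped to a legal index so the lookup cannot go out of range, then the mask wipes those rows. A small sketch of the clamp-then-zero trick; the PAD_WORD value here is an assumption, the real code checks dgl.data.SST.PAD_WORD:

import torch as th
import torch.nn as nn

PAD_WORD = -1  # assumed sentinel for nodes that carry no word

emb = nn.Embedding(10, 4)
wordid = th.tensor([3, PAD_WORD, 7])       # one id per node; internal nodes padded
mask = (wordid != PAD_WORD)
safe = wordid * mask.long()                # PAD_WORD becomes index 0, a legal row
x = emb(safe) * mask.unsqueeze(1).float()  # then zero those rows entirely
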
6 changes: 6 additions & 0 deletions include/dgl/graph_op.h
@@ -8,6 +8,12 @@ namespace dgl {
 
 class GraphOp {
  public:
+  /*!
+   * \brief Return the line graph.
+   * \param graph The input graph.
+   * \return the line graph
+   */
+  static Graph LineGraph(const Graph* graph);
   /*!
    * \brief Return a disjoint union of the input graphs.
    *
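
The header only declares the new operator. As a concept check: the line graph L(G) has one node per edge of G, with an edge from (u, v) to (v, w) wherever two edges of G are consecutive. A quick illustration with networkx rather than the DGL implementation:

import networkx as nx

g = nx.DiGraph([(0, 1), (1, 2), (1, 3)])
lg = nx.line_graph(g)                # nodes of lg are edges of g
print(sorted(lg.nodes()))            # [(0, 1), (1, 2), (1, 3)]
print(sorted(lg.edges()))            # [((0, 1), (1, 2)), ((0, 1), (1, 3))]
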
147 changes: 0 additions & 147 deletions python/dgl/cached_graph.py

This file was deleted.

6 changes: 4 additions & 2 deletions python/dgl/graph.py
@@ -48,6 +48,7 @@ def __init__(self,
         # msg graph & frame
         self._msg_graph = create_graph_index()
         self._msg_frame = FrameRef()
+        self.reset_messages()
         # registered functions
         self._message_func = (None, None)
         self._reduce_func = (None, None)
@@ -112,7 +113,7 @@ def clear(self):
         self._msg_graph.clear()
         self._msg_frame.clear()
 
-    def clear_messages(self):
+    def reset_messages(self):
         """Clear all messages."""
         self._msg_graph.clear()
         self._msg_frame.clear()
@@ -447,6 +448,7 @@ def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
         self.clear()
         self._graph.from_networkx(nx_graph)
         self._msg_graph.add_nodes(self._graph.number_of_nodes())
+        # copy attributes
         def _batcher(lst):
             if isinstance(lst[0], Tensor):
                 return F.pack([F.unsqueeze(x, 0) for x in lst])
@@ -1078,7 +1080,7 @@ def _reshape_fn(msg):
             new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
 
         # TODO: clear partial messages
-        self.clear_messages()
+        self.reset_messages()
 
         # Pack all reducer results together
         reordered_v = F.pack(reordered_v)
2 changes: 1 addition & 1 deletion src/graph/graph.cc
@@ -20,7 +20,7 @@ void Graph::AddVertices(uint64_t num_vertices) {
 void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
   CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
   CHECK(HasVertex(src) && HasVertex(dst))
-    << "In valid vertices: " << src << " " << dst;
+    << "Invalid vertices: src=" << src << " dst=" << dst;
   dgl_id_t eid = num_edges_++;
   adjlist_[src].succ.push_back(dst);
   adjlist_[src].edge_id.push_back(eid);
