Commit

[Example] add graphwriter pytorch example (dmlc#1068)
* upd

* fix edgebatch edges

* add test

* trigger

* add graphwriter pytorch example

* fix line break in graphwriter README

* upd

* fix
QipengGuo authored and yzh119 committed Dec 4, 2019
1 parent 35653dd commit fff3dd9
Showing 11 changed files with 913 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -28,6 +28,7 @@ A summary of part of the model accuracy and training speed with the Pytorch back
| [JTNN](https://arxiv.org/abs/1802.04364) | 96.44% | 96.44% | [1826s (Pytorch)](https://github.com/wengong-jin/icml18-jtnn) | 743s | 2.5x |
| [LGNN](https://arxiv.org/abs/1705.08415) | 94% | 94% | n/a | 1.45s | n/a |
| [DGMG](https://arxiv.org/pdf/1803.03324.pdf) | 84% | 90% | n/a | 238s | n/a |
| [GraphWriter](https://www.aclweb.org/anthology/N19-1238.pdf) | 14.3(BLEU) | 14.31(BLEU) | [1970s (PyTorch)](https://github.com/rikdz/GraphWriter) | 1192s | 1.65x |

With the MXNet/Gluon backend, we scaled a graph of 50M nodes and 150M edges on a P3.8xlarge instance,
with 160s per epoch, on SSE ([Stochastic Steady-state Embedding](https://www.cc.gatech.edu/~hdai8/pdf/equilibrium_embedding.pdf)),
1 change: 1 addition & 0 deletions examples/pytorch/README.md
@@ -21,3 +21,4 @@ Here is a summary of the model accuracy and training speed. Our testbed is Amazo
| [JTNN](https://arxiv.org/abs/1802.04364) | 96.44% | 96.44% | [1826s (Pytorch)](https://github.com/wengong-jin/icml18-jtnn) | 743s | 2.5x |
| [LGNN](https://arxiv.org/abs/1705.08415) | 94% | 94% | n/a | 1.45s | n/a |
| [DGMG](https://arxiv.org/pdf/1803.03324.pdf) | 84% | 90% | n/a | 238s | n/a |
| [GraphWriter](https://www.aclweb.org/anthology/N19-1238.pdf) | 14.31(BLEU) | 14.3(BLEU) | 1970s | 1192s | 1.65x |
44 changes: 44 additions & 0 deletions examples/pytorch/graphwriter/README.md
@@ -0,0 +1,44 @@
# GraphWriter-DGL
In this example we implement GraphWriter from [Text Generation from Knowledge Graphs with Graph Transformers](https://arxiv.org/abs/1904.02342) in DGL. The [author's original implementation](https://github.com/rikdz/GraphWriter) is also available.

## Dependencies
- PyTorch >= 1.2
- tqdm
- pycoco (only for testing)
- multi-bleu.perl and other scripts from mosesdecoder (only for testing)

## Usage
```
# download data
sh prepare_data.sh
# training
sh run.sh
# testing
sh test.sh
```

## Result on AGENDA
| | BLEU | METEOR | training time per epoch |
|-|-|-|-|
| Author's implementation | 14.3 ± 1.01 | 18.8 ± 0.28 | 1970s |
| DGL implementation | 14.31 ± 0.34 | 19.74 ± 0.69 | 1192s |

We use the author's code for the speed test; our testbed is a V100 GPU.

| | BLEU | detok BLEU | METEOR |
|-|-|-|-|
| greedy, two layers | 13.97 ± 0.40 | 13.78 ± 0.46 | 18.76 ± 0.36 |
| beam 4, length penalty 1.0, two layers | 14.66 ± 0.65 | 14.53 ± 0.52 | 19.50 ± 0.49 |
| beam 4, length penalty 0.0, two layers | 14.33 ± 0.39 | 14.09 ± 0.39 | 18.63 ± 0.52 |
| greedy, six layers | 14.17 ± 0.46 | 14.01 ± 0.51 | 19.18 ± 0.49 |
| beam 4, length penalty 1.0, six layers | 14.31 ± 0.34 | 14.35 ± 0.36 | 19.74 ± 0.69 |
| beam 4, length penalty 0.0, six layers | 14.40 ± 0.85 | 14.15 ± 0.84 | 18.86 ± 0.78 |

We repeat each experiment five times.
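The length penalty settings in the table correspond to how graphwriter.py normalizes a finished hypothesis's cumulative log-probability during beam search (`tmp_score / LP` with `LP = length ** args.lp`). A minimal sketch of that normalization (the function name is ours, for illustration):

```python
def normalized_score(cum_logprob, length, lp):
    # Divide the cumulative log-probability by length**lp.
    # lp=0.0 disables normalization; lp=1.0 averages per token,
    # which keeps beam search from always preferring short outputs.
    return cum_logprob / (length ** lp)

print(normalized_score(-10.0, 5, 0.0))  # -10.0
print(normalized_score(-10.0, 5, 1.0))  # -2.0
```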

### Examples

We also provide the output of our implementation on the test set, together with the reference text.
- [GraphWriter's output](https://s3.us-east-2.amazonaws.com/dgl.ai/models/graphwriter/tmp_pred.txt)
- [Reference text](https://s3.us-east-2.amazonaws.com/dgl.ai/models/graphwriter/tmp_gold.txt)

182 changes: 182 additions & 0 deletions examples/pytorch/graphwriter/graphwriter.py
@@ -0,0 +1,182 @@
import torch
from modules import MSA, BiLSTM, GraphTrans
from utlis import *
from torch import nn
import dgl


class GraphWriter(nn.Module):
    def __init__(self, args):
        super(GraphWriter, self).__init__()
        self.args = args
        if args.title:
            self.title_emb = nn.Embedding(len(args.title_vocab), args.nhid, padding_idx=0)
            self.title_enc = BiLSTM(args, enc_type='title')
            self.title_attn = MSA(args)
        self.ent_emb = nn.Embedding(len(args.ent_text_vocab), args.nhid, padding_idx=0)
        self.tar_emb = nn.Embedding(len(args.text_vocab), args.nhid, padding_idx=0)
        if args.title:
            nn.init.xavier_normal_(self.title_emb.weight)
        nn.init.xavier_normal_(self.ent_emb.weight)
        self.rel_emb = nn.Embedding(len(args.rel_vocab), args.nhid, padding_idx=0)
        nn.init.xavier_normal_(self.rel_emb.weight)
        self.decode_lstm = nn.LSTMCell(args.dec_ninp, args.nhid)
        self.ent_enc = BiLSTM(args, enc_type='entity')
        self.graph_enc = GraphTrans(args)
        self.ent_attn = MSA(args)
        self.copy_attn = MSA(args, mode='copy')
        self.copy_fc = nn.Linear(args.dec_ninp, 1)
        self.pred_v_fc = nn.Linear(args.dec_ninp, len(args.text_vocab))

    def enc_forward(self, batch, ent_mask, ent_text_mask, ent_len, rel_mask, title_mask):
        title_enc = None
        if self.args.title:
            title_enc = self.title_enc(self.title_emb(batch['title']), title_mask)
        ent_enc = self.ent_enc(self.ent_emb(batch['ent_text']), ent_text_mask, ent_len=batch['ent_len'])
        rel_emb = self.rel_emb(batch['rel'])
        g_ent, g_root = self.graph_enc(ent_enc, ent_mask, ent_len, rel_emb, rel_mask, batch['graph'])
        return g_ent, g_root, title_enc, ent_enc

    def forward(self, batch, beam_size=-1):
        ent_mask = len2mask(batch['ent_len'], self.args.device)
        ent_text_mask = batch['ent_text'] == 0
        rel_mask = batch['rel'] == 0  # 0 means the <PAD>
        title_mask = batch['title'] == 0
        g_ent, g_root, title_enc, ent_enc = self.enc_forward(batch, ent_mask, ent_text_mask, batch['ent_len'], rel_mask, title_mask)

        _h, _c = g_root, g_root.clone().detach()
        ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
        if self.args.title:
            attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
            ctx = torch.cat([ctx, attn], 1)
        if beam_size < 1:
            # training
            outs = []
            tar_inp = self.tar_emb(batch['text'].transpose(0, 1))
            for t, xt in enumerate(tar_inp):
                _xt = torch.cat([ctx, xt], 1)
                _h, _c = self.decode_lstm(_xt, (_h, _c))
                ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
                if self.args.title:
                    attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
                    ctx = torch.cat([ctx, attn], 1)
                outs.append(torch.cat([_h, ctx], 1))
            outs = torch.stack(outs, 1)
            copy_gate = torch.sigmoid(self.copy_fc(outs))
            EPSI = 1e-6
            # copy
            pred_v = torch.log(copy_gate + EPSI) + torch.log_softmax(self.pred_v_fc(outs), -1)
            pred_c = torch.log((1. - copy_gate) + EPSI) + torch.log_softmax(self.copy_attn(outs, ent_enc, mask=ent_mask), -1)
            pred = torch.cat([pred_v, pred_c], -1)
            return pred
        else:
            if beam_size == 1:
                # greedy
                device = g_ent.device
                B = g_ent.shape[0]
                ent_type = batch['ent_type'].view(B, -1)
                seq = (torch.ones(B,).long().to(device) * self.args.text_vocab('<BOS>')).unsqueeze(1)
                for t in range(self.args.beam_max_len):
                    _inp = replace_ent(seq[:, -1], ent_type, len(self.args.text_vocab))
                    xt = self.tar_emb(_inp)
                    _xt = torch.cat([ctx, xt], 1)
                    _h, _c = self.decode_lstm(_xt, (_h, _c))
                    ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
                    if self.args.title:
                        attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
                        ctx = torch.cat([ctx, attn], 1)
                    _y = torch.cat([_h, ctx], 1)
                    copy_gate = torch.sigmoid(self.copy_fc(_y))
                    pred_v = torch.log(copy_gate) + torch.log_softmax(self.pred_v_fc(_y), -1)
                    pred_c = torch.log((1. - copy_gate)) + torch.log_softmax(self.copy_attn(_y.unsqueeze(1), ent_enc, mask=ent_mask).squeeze(1), -1)
                    pred = torch.cat([pred_v, pred_c], -1).view(B, -1)
                    for ban_item in ['<BOS>', '<PAD>', '<UNK>']:
                        pred[:, self.args.text_vocab(ban_item)] = -1e8
                    _, word = pred.max(-1)
                    seq = torch.cat([seq, word.unsqueeze(1)], 1)
                return seq
            else:
                # beam search
                device = g_ent.device
                B = g_ent.shape[0]
                BSZ = B * beam_size
                _h = _h.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                _c = _c.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                ent_mask = ent_mask.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                if self.args.title:
                    title_mask = title_mask.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                    title_enc = title_enc.view(B, 1, title_enc.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, title_enc.size(1), -1)
                ctx = ctx.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                ent_type = batch['ent_type'].view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                g_ent = g_ent.view(B, 1, g_ent.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, g_ent.size(1), -1)
                ent_enc = ent_enc.view(B, 1, ent_enc.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, ent_enc.size(1), -1)

                beam_best = torch.zeros(B).to(device) - 1e9
                beam_best_seq = [None] * B
                beam_seq = (torch.ones(B, beam_size).long().to(device) * self.args.text_vocab('<BOS>')).unsqueeze(-1)
                beam_score = torch.zeros(B, beam_size).to(device)
                done_flag = torch.zeros(B, beam_size)
                for t in range(self.args.beam_max_len):
                    _inp = replace_ent(beam_seq[:, :, -1].view(-1), ent_type, len(self.args.text_vocab))
                    xt = self.tar_emb(_inp)
                    _xt = torch.cat([ctx, xt], 1)
                    _h, _c = self.decode_lstm(_xt, (_h, _c))
                    ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
                    if self.args.title:
                        attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
                        ctx = torch.cat([ctx, attn], 1)
                    _y = torch.cat([_h, ctx], 1)
                    copy_gate = torch.sigmoid(self.copy_fc(_y))
                    pred_v = torch.log(copy_gate) + torch.log_softmax(self.pred_v_fc(_y), -1)
                    pred_c = torch.log((1. - copy_gate)) + torch.log_softmax(self.copy_attn(_y.unsqueeze(1), ent_enc, mask=ent_mask).squeeze(1), -1)
                    pred = torch.cat([pred_v, pred_c], -1).view(B, beam_size, -1)
                    for ban_item in ['<BOS>', '<PAD>', '<UNK>']:
                        pred[:, :, self.args.text_vocab(ban_item)] = -1e8
                    if t == self.args.beam_max_len - 1:  # force ending
                        tt = pred[:, :, self.args.text_vocab('<EOS>')]
                        pred = pred * 0 - 1e8
                        pred[:, :, self.args.text_vocab('<EOS>')] = tt
                    cum_score = beam_score.view(B, beam_size, 1) + pred
                    score, word = cum_score.topk(dim=-1, k=beam_size)  # B, beam_size, beam_size
                    score, word = score.view(B, -1), word.view(B, -1)
                    eos_idx = self.args.text_vocab('<EOS>')
                    if beam_seq.size(2) == 1:
                        new_idx = torch.arange(beam_size).to(word)
                        new_idx = new_idx[None, :].repeat(B, 1)
                    else:
                        _, new_idx = score.topk(dim=-1, k=beam_size)
                    new_src, new_score, new_word, new_done = [], [], [], []
                    LP = beam_seq.size(2) ** self.args.lp
                    for i in range(B):
                        for j in range(beam_size):
                            tmp_score = score[i][new_idx[i][j]]
                            tmp_word = word[i][new_idx[i][j]]
                            src_idx = new_idx[i][j] // beam_size
                            new_src.append(src_idx)
                            if tmp_word == eos_idx:
                                new_score.append(-1e8)
                            else:
                                new_score.append(tmp_score)
                            new_word.append(tmp_word)

                            if tmp_word == eos_idx and done_flag[i][src_idx] == 0 and tmp_score / LP > beam_best[i]:
                                beam_best[i] = tmp_score / LP
                                beam_best_seq[i] = beam_seq[i][src_idx]
                            if tmp_word == eos_idx:
                                new_done.append(1)
                            else:
                                new_done.append(done_flag[i][src_idx])
                    new_score = torch.Tensor(new_score).view(B, beam_size).to(beam_score)
                    new_word = torch.Tensor(new_word).view(B, beam_size).to(beam_seq)
                    new_src = torch.LongTensor(new_src).view(B, beam_size).to(device)
                    new_done = torch.Tensor(new_done).view(B, beam_size).to(done_flag)
                    beam_score = new_score
                    done_flag = new_done
                    beam_seq = beam_seq.view(B, beam_size, -1)[torch.arange(B)[:, None].to(device), new_src]
                    beam_seq = torch.cat([beam_seq, new_word.unsqueeze(2)], 2)
                    _h = _h.view(B, beam_size, -1)[torch.arange(B)[:, None].to(device), new_src].view(BSZ, -1)
                    _c = _c.view(B, beam_size, -1)[torch.arange(B)[:, None].to(device), new_src].view(BSZ, -1)
                    ctx = ctx.view(B, beam_size, -1)[torch.arange(B)[:, None].to(device), new_src].view(BSZ, -1)

                return beam_best_seq
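The decoder above mixes a generate distribution over the vocabulary with a copy distribution over entities through a sigmoid gate (the `copy_gate` pattern in graphwriter.py). A standalone sketch of that gating in log space, with random tensors standing in for the model's real encoder and attention outputs (shapes are illustrative, not the example's API):

```python
import torch

torch.manual_seed(0)
B, V, E = 2, 10, 4                        # batch, vocab size, entity count
vocab_logits = torch.randn(B, V)          # stands in for pred_v_fc output
copy_logits = torch.randn(B, E)           # stands in for copy_attn scores
copy_gate = torch.sigmoid(torch.randn(B, 1))  # p(generate from vocab)

EPSI = 1e-6
# log p(word) = log(gate)   + log_softmax over vocab     for generated words
#               log(1-gate) + log_softmax over entities  for copied words
pred_v = torch.log(copy_gate + EPSI) + torch.log_softmax(vocab_logits, -1)
pred_c = torch.log(1. - copy_gate + EPSI) + torch.log_softmax(copy_logits, -1)
pred = torch.cat([pred_v, pred_c], -1)    # (B, V+E) joint log-probabilities

# the two halves form one distribution: probabilities sum to ~1
print(pred.exp().sum(-1))
```

Because the gate splits a single probability mass, the concatenated vector can be trained with an ordinary NLL loss over the extended vocabulary of words plus entity slots.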
