[Example] add graphwriter pytorch example (dmlc#1068)
* upd
* fix edgebatch edges
* add test
* trigger
* add graphwriter pytorch example
* fix line break in graphwriter README
* upd
* fix
Showing 11 changed files with 913 additions and 0 deletions.
# GraphWriter-DGL

This example implements GraphWriter from [Text Generation from Knowledge Graphs with Graph Transformers](https://arxiv.org/abs/1904.02342) in DGL. The [author's original code](https://github.com/rikdz/GraphWriter) is also available for reference.

## Dependencies
- PyTorch >= 1.2
- tqdm
- pycoco (only for testing)
- multi-bleu.perl and other scripts from mosesdecoder (only for testing)

## Usage
```
# download data
sh prepare_data.sh
# training
sh run.sh
# testing
sh test.sh
```

## Result on AGENDA

| | BLEU | METEOR | Training time per epoch |
|-|-|-|-|
| Author's implementation | 14.3 ± 1.01 | 18.8 ± 0.28 | 1970 s |
| DGL implementation | 14.31 ± 0.34 | 19.74 ± 0.69 | 1192 s |

We use the author's code for the speed test; the testbed is an NVIDIA V100 GPU.

| | BLEU | Detokenized BLEU | METEOR |
|-|-|-|-|
| Greedy, two layers | 13.97 ± 0.40 | 13.78 ± 0.46 | 18.76 ± 0.36 |
| Beam 4, length penalty 1.0, two layers | 14.66 ± 0.65 | 14.53 ± 0.52 | 19.50 ± 0.49 |
| Beam 4, length penalty 0.0, two layers | 14.33 ± 0.39 | 14.09 ± 0.39 | 18.63 ± 0.52 |
| Greedy, six layers | 14.17 ± 0.46 | 14.01 ± 0.51 | 19.18 ± 0.49 |
| Beam 4, length penalty 1.0, six layers | 14.31 ± 0.34 | 14.35 ± 0.36 | 19.74 ± 0.69 |
| Beam 4, length penalty 0.0, six layers | 14.40 ± 0.85 | 14.15 ± 0.84 | 18.86 ± 0.78 |

We repeat each experiment five times.
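
The length penalty above rescales a finished hypothesis's cumulative log-probability by `length ** penalty` before beams are compared (the `LP = beam_seq.size(2) ** self.args.lp` and `tmp_score / LP` normalization in the decoder code later in this commit). A minimal sketch with made-up toy scores, just to show how penalty 0.0 versus 1.0 can change which hypothesis wins:

```python
# Toy illustration only: scores and lengths are made up, not model outputs.
hyps = {
    "short": (-4.0, 4),    # (cumulative log-probability, length in tokens)
    "long": (-6.0, 10),
}
for lp in (0.0, 1.0):
    # Normalize each hypothesis score by length ** lp, as in `tmp_score / LP`.
    scored = {name: s / (length ** lp) for name, (s, length) in hyps.items()}
    best = max(scored, key=scored.get)
    print(f"lp={lp}: best hypothesis is '{best}'")  # lp=0.0 -> short, lp=1.0 -> long
```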

### Examples

We also provide the output of our implementation on the test set, together with the reference text.
- [GraphWriter's output](https://s3.us-east-2.amazonaws.com/dgl.ai/models/graphwriter/tmp_pred.txt)
- [Reference text](https://s3.us-east-2.amazonaws.com/dgl.ai/models/graphwriter/tmp_gold.txt)
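
A quick way to inspect these files, sketched with the standard library only (the URLs are the two listed above; the files are assumed to be line-aligned prediction/reference pairs):

```python
import urllib.request

# Fetch the released prediction and reference files and print the first pair.
urls = {
    "prediction": "https://s3.us-east-2.amazonaws.com/dgl.ai/models/graphwriter/tmp_pred.txt",
    "reference": "https://s3.us-east-2.amazonaws.com/dgl.ai/models/graphwriter/tmp_gold.txt",
}
lines = {}
for name, url in urls.items():
    with urllib.request.urlopen(url) as resp:
        lines[name] = resp.read().decode("utf-8").splitlines()

print("prediction:", lines["prediction"][0])
print("reference :", lines["reference"][0])
```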
import torch
from modules import MSA, BiLSTM, GraphTrans
from utlis import *
from torch import nn
import dgl


class GraphWriter(nn.Module):
    def __init__(self, args):
        super(GraphWriter, self).__init__()
        self.args = args
        if args.title:
            self.title_emb = nn.Embedding(len(args.title_vocab), args.nhid, padding_idx=0)
            self.title_enc = BiLSTM(args, enc_type='title')
            self.title_attn = MSA(args)
        self.ent_emb = nn.Embedding(len(args.ent_text_vocab), args.nhid, padding_idx=0)
        self.tar_emb = nn.Embedding(len(args.text_vocab), args.nhid, padding_idx=0)
        if args.title:
            nn.init.xavier_normal_(self.title_emb.weight)
        nn.init.xavier_normal_(self.ent_emb.weight)
        self.rel_emb = nn.Embedding(len(args.rel_vocab), args.nhid, padding_idx=0)
        nn.init.xavier_normal_(self.rel_emb.weight)
        self.decode_lstm = nn.LSTMCell(args.dec_ninp, args.nhid)
        self.ent_enc = BiLSTM(args, enc_type='entity')  # BiLSTM over entity tokens
        self.graph_enc = GraphTrans(args)  # graph transformer over the knowledge graph
        self.ent_attn = MSA(args)
        self.copy_attn = MSA(args, mode='copy')
        self.copy_fc = nn.Linear(args.dec_ninp, 1)  # gate between generating and copying
        self.pred_v_fc = nn.Linear(args.dec_ninp, len(args.text_vocab))  # vocabulary logits

    def enc_forward(self, batch, ent_mask, ent_text_mask, ent_len, rel_mask, title_mask):
        title_enc = None
        if self.args.title:
            title_enc = self.title_enc(self.title_emb(batch['title']), title_mask)
        ent_enc = self.ent_enc(self.ent_emb(batch['ent_text']), ent_text_mask, ent_len=batch['ent_len'])
        rel_emb = self.rel_emb(batch['rel'])
        g_ent, g_root = self.graph_enc(ent_enc, ent_mask, ent_len, rel_emb, rel_mask, batch['graph'])
        return g_ent, g_root, title_enc, ent_enc

    def forward(self, batch, beam_size=-1):
        # beam_size < 1: teacher-forced training pass (returns log-probabilities)
        # beam_size == 1: greedy decoding
        # beam_size > 1: beam search with beam width beam_size
        ent_mask = len2mask(batch['ent_len'], self.args.device)
        ent_text_mask = batch['ent_text'] == 0
        rel_mask = batch['rel'] == 0  # 0 means the <PAD>
        title_mask = batch['title'] == 0
        g_ent, g_root, title_enc, ent_enc = self.enc_forward(batch, ent_mask, ent_text_mask, batch['ent_len'], rel_mask, title_mask)

        _h, _c = g_root, g_root.clone().detach()
        ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
        if self.args.title:
            attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
            ctx = torch.cat([ctx, attn], 1)
        if beam_size < 1:
            # training
            outs = []
            tar_inp = self.tar_emb(batch['text'].transpose(0, 1))
            for t, xt in enumerate(tar_inp):
                _xt = torch.cat([ctx, xt], 1)
                _h, _c = self.decode_lstm(_xt, (_h, _c))
                ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
                if self.args.title:
                    attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
                    ctx = torch.cat([ctx, attn], 1)
                outs.append(torch.cat([_h, ctx], 1))
            outs = torch.stack(outs, 1)
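            # Copy mechanism: copy_gate splits probability mass between generating
            # from the output vocabulary (pred_v) and copying one of the encoded
            # entities (pred_c); both are kept in log space so that concatenating
            # them yields a single log-distribution over vocabulary tokens followed
            # by entity slots.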
            copy_gate = torch.sigmoid(self.copy_fc(outs))
            EPSI = 1e-6
            # copy
            pred_v = torch.log(copy_gate + EPSI) + torch.log_softmax(self.pred_v_fc(outs), -1)
            pred_c = torch.log((1. - copy_gate) + EPSI) + torch.log_softmax(self.copy_attn(outs, ent_enc, mask=ent_mask), -1)
            pred = torch.cat([pred_v, pred_c], -1)
            return pred
        else:
            if beam_size == 1:
                # greedy
                device = g_ent.device
                B = g_ent.shape[0]
                ent_type = batch['ent_type'].view(B, -1)
                seq = (torch.ones(B,).long().to(device) * self.args.text_vocab('<BOS>')).unsqueeze(1)
                for t in range(self.args.beam_max_len):
                    _inp = replace_ent(seq[:, -1], ent_type, len(self.args.text_vocab))
                    xt = self.tar_emb(_inp)
                    _xt = torch.cat([ctx, xt], 1)
                    _h, _c = self.decode_lstm(_xt, (_h, _c))
                    ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
                    if self.args.title:
                        attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
                        ctx = torch.cat([ctx, attn], 1)
                    _y = torch.cat([_h, ctx], 1)
                    copy_gate = torch.sigmoid(self.copy_fc(_y))
                    pred_v = torch.log(copy_gate) + torch.log_softmax(self.pred_v_fc(_y), -1)
                    pred_c = torch.log((1. - copy_gate)) + torch.log_softmax(self.copy_attn(_y.unsqueeze(1), ent_enc, mask=ent_mask).squeeze(1), -1)
                    pred = torch.cat([pred_v, pred_c], -1).view(B, -1)
                    for ban_item in ['<BOS>', '<PAD>', '<UNK>']:
                        pred[:, self.args.text_vocab(ban_item)] = -1e8
                    _, word = pred.max(-1)
                    seq = torch.cat([seq, word.unsqueeze(1)], 1)
                return seq
            else:
                # beam search
                device = g_ent.device
                B = g_ent.shape[0]
                BSZ = B * beam_size
                # Tile the decoder state, masks, and encoder outputs beam_size times
                # so that each beam can be decoded as an independent batch element.
                _h = _h.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                _c = _c.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                ent_mask = ent_mask.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                if self.args.title:
                    title_mask = title_mask.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                    title_enc = title_enc.view(B, 1, title_enc.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, title_enc.size(1), -1)
                ctx = ctx.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                ent_type = batch['ent_type'].view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                g_ent = g_ent.view(B, 1, g_ent.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, g_ent.size(1), -1)
                ent_enc = ent_enc.view(B, 1, ent_enc.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, ent_enc.size(1), -1)

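                # Beam-search bookkeeping:
                #   beam_best      best length-normalized score of a finished hypothesis per example
                #   beam_best_seq  token sequence achieving that score
                #   beam_seq       current partial hypotheses, shape (B, beam_size, t)
                #   beam_score     cumulative log-probability of each live beam
                #                  (finished beams are pushed to -1e8)
                #   done_flag      1 once a beam has emitted <EOS>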
                beam_best = torch.zeros(B).to(device) - 1e9
                beam_best_seq = [None] * B
                beam_seq = (torch.ones(B, beam_size).long().to(device) * self.args.text_vocab('<BOS>')).unsqueeze(-1)
                beam_score = torch.zeros(B, beam_size).to(device)
                done_flag = torch.zeros(B, beam_size)
                for t in range(self.args.beam_max_len):
                    _inp = replace_ent(beam_seq[:, :, -1].view(-1), ent_type, len(self.args.text_vocab))
                    xt = self.tar_emb(_inp)
                    _xt = torch.cat([ctx, xt], 1)
                    _h, _c = self.decode_lstm(_xt, (_h, _c))
                    ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
                    if self.args.title:
                        attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
                        ctx = torch.cat([ctx, attn], 1)
                    _y = torch.cat([_h, ctx], 1)
                    copy_gate = torch.sigmoid(self.copy_fc(_y))
                    pred_v = torch.log(copy_gate) + torch.log_softmax(self.pred_v_fc(_y), -1)
                    pred_c = torch.log((1. - copy_gate)) + torch.log_softmax(self.copy_attn(_y.unsqueeze(1), ent_enc, mask=ent_mask).squeeze(1), -1)
                    pred = torch.cat([pred_v, pred_c], -1).view(B, beam_size, -1)
                    for ban_item in ['<BOS>', '<PAD>', '<UNK>']:
                        pred[:, :, self.args.text_vocab(ban_item)] = -1e8
                    if t == self.args.beam_max_len - 1:  # force ending
                        tt = pred[:, :, self.args.text_vocab('<EOS>')]
                        pred = pred * 0 - 1e8
                        pred[:, :, self.args.text_vocab('<EOS>')] = tt
                    cum_score = beam_score.view(B, beam_size, 1) + pred
                    score, word = cum_score.topk(dim=-1, k=beam_size)  # B, beam_size, beam_size
                    score, word = score.view(B, -1), word.view(B, -1)
                    eos_idx = self.args.text_vocab('<EOS>')
                    if beam_seq.size(2) == 1:
                        new_idx = torch.arange(beam_size).to(word)
                        new_idx = new_idx[None, :].repeat(B, 1)
                    else:
                        _, new_idx = score.topk(dim=-1, k=beam_size)
                    new_src, new_score, new_word, new_done = [], [], [], []
                    LP = beam_seq.size(2) ** self.args.lp
                    for i in range(B):
                        for j in range(beam_size):
                            tmp_score = score[i][new_idx[i][j]]
                            tmp_word = word[i][new_idx[i][j]]
                            src_idx = new_idx[i][j] // beam_size
                            new_src.append(src_idx)
                            if tmp_word == eos_idx:
                                new_score.append(-1e8)
                            else:
                                new_score.append(tmp_score)
                            new_word.append(tmp_word)

                            if tmp_word == eos_idx and done_flag[i][src_idx] == 0 and tmp_score / LP > beam_best[i]:
                                beam_best[i] = tmp_score / LP
                                beam_best_seq[i] = beam_seq[i][src_idx]
                            if tmp_word == eos_idx:
                                new_done.append(1)
                            else:
                                new_done.append(done_flag[i][src_idx])
                    new_score = torch.Tensor(new_score).view(B, beam_size).to(beam_score)
                    new_word = torch.Tensor(new_word).view(B, beam_size).to(beam_seq)
                    new_src = torch.LongTensor(new_src).view(B, beam_size).to(device)
                    new_done = torch.Tensor(new_done).view(B, beam_size).to(done_flag)
                    beam_score = new_score
                    done_flag = new_done
                    beam_seq = beam_seq.view(B, beam_size, -1)[torch.arange(B)[:, None].to(device), new_src]
                    beam_seq = torch.cat([beam_seq, new_word.unsqueeze(2)], 2)
                    _h = _h.view(B, beam_size, -1)[torch.arange(B)[:, None].to(device), new_src].view(BSZ, -1)
                    _c = _c.view(B, beam_size, -1)[torch.arange(B)[:, None].to(device), new_src].view(BSZ, -1)
                    ctx = ctx.view(B, beam_size, -1)[torch.arange(B)[:, None].to(device), new_src].view(BSZ, -1)

                return beam_best_seq