CopyMTL
WindChimeRan committed Nov 23, 2019
1 parent 010c38a commit 0ebe9c4
Showing 3 changed files with 27 additions and 55 deletions.
55 changes: 9 additions & 46 deletions README.md
@@ -1,14 +1,17 @@
# "Extracting Relational Facts by an End-to-End Neural Model with Copy Mechanism"
# CopyMTL: Copy Mechanism for Joint Extraction of Entities and Relations with Multi-Task Learning

PyTorch reimplementation for ACL2018 paper [copy re](http://aclweb.org/anthology/P18-1047)
Accepted by AAAI-2020.

This is a follow-up to the ACL 2018 paper "Extracting Relational Facts by an End-to-End Neural Model with Copy Mechanism" ([CopyRE](http://aclweb.org/anthology/P18-1047)).

This repo only contains the CopyRE' part. The MTL part is very old and messy, so we are not going to release it. We suggest using [pytorch-crf](https://pytorch-crf.readthedocs.io/en/stable/) to implement the sequence labeling module for the encoder; a minimal sketch is given below. The dataset from CopyRE does not support MTL either, because it loses the NER annotation.
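
A minimal sketch of such a CRF tagging head with pytorch-crf (the class and argument names are illustrative, not part of this repo; it assumes the encoder returns outputs of shape `(batch, seq_len, hidden_size)`):

```python
import torch
import torch.nn as nn
from torchcrf import CRF  # pip install pytorch-crf


class SequenceLabeler(nn.Module):
    """CRF-based sequence labeling head on top of encoder outputs."""

    def __init__(self, hidden_size: int, num_tags: int) -> None:
        super().__init__()
        self.emission = nn.Linear(hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def loss(self, encoder_outputs: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # encoder_outputs: (batch, seq_len, hidden); tags, mask: (batch, seq_len)
        emissions = self.emission(encoder_outputs)
        return -self.crf(emissions, tags, mask=mask, reduction='mean')  # negative log-likelihood

    def decode(self, encoder_outputs: torch.Tensor, mask: torch.Tensor) -> list:
        # Viterbi decoding; returns one list of tag ids per sentence
        emissions = self.emission(encoder_outputs)
        return self.crf.decode(emissions, mask=mask)
```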

Official tensorflow version [copy_re_tensorflow](https://github.com/xiangrongzeng/copy_re)

## Environment

python3

pytorch 0.4.0 -- 1.0
pytorch 0.4.0 -- 1.3.1

## Modify the Data path

@@ -27,49 +30,9 @@ NYT dataset:

## Run

`python main.py --gpu 0 --mode train --cell lstm -decoder_type one`

`python main.py --gpu 0 --mode test --cell lstm -decoder_type one`


# Difference

My MultiDecoder does not make a difference in the F1 score. I still cannot figure out the reason.

Official version fixes an [eos bug](https://github.com/xiangrongzeng/copy_re/commit/abe442eaee941ca588b7cd8daec0eec0faa5e8ef).
In this PyTorch reproduction I think I have already bypassed the bug; however, there is no performance boost on WebNLG as they reported.

MultiDecoder + GRU is bad: the training curve shows significant overfitting. I don't know why.

## Result

OneDecoder + GRU

| Dataset | F1 | Precision | Recall |
| ------ | ------ | ------ | ------ |
| webnlg | 0.30 | 0.32 | 0.28 |
| nyt | 0.52 | 0.55 | 0.49 |

OneDecoder + LSTM

| Dataset | F1 | Precision | Recall |
| ------ | ------ | ------ | ------ |
| webnlg | 0.28 | 0.30 | 0.26 |
| nyt | 0.54 | 0.59 | 0.50 |

**MultiDecoder + GRU**

| Dataset | F1 | Precision | Recall |
| ------ | ------ | ------ | ------ |
| webnlg | 0.28 | 0.30 | 0.26 |
| nyt | 0.45 | 0.49 | 0.41 |

MultiDecoder + LSTM
`python main.py --gpu 0 --mode train --cell lstm --decoder_type one`

| Dataset | F1 | Precision | Recall |
| ------ | ------ | ------ | ------ |
| webnlg | 0.29 | 0.31 | 0.27 |
| nyt | 0.56 | 0.60 | 0.52 |
`python main.py --gpu 0 --mode test --cell lstm --decoder_type one`



2 changes: 1 addition & 1 deletion main.py
@@ -187,7 +187,7 @@ def train(self, evaluator: Evaluator=None) -> None:

f1, precision, recall = tester.test()

rel_f1, rel_precision, rel_recall = tester.rel_test()
# rel_f1, rel_precision, rel_recall = tester.rel_test()
# print('_' * 60)
print("triplet \t F1: %f \t P: %f \t R: %f \t" % (f1, precision, recall))
#
25 changes: 17 additions & 8 deletions model.py
@@ -8,6 +8,11 @@
import torch
import torch.nn.functional as F

# torch.bool is not available in older PyTorch versions; fall back to uint8 masks
try:
torch_bool = torch.bool
except:
torch_bool = torch.uint8

class Encoder(nn.Module):
def __init__(self, config: const.Config, embedding: nn.modules.sparse.Embedding) -> None:
@@ -87,7 +92,9 @@ def __init__(self, config: const.Config, embedding: nn.modules.sparse.Embedding,

self.do_eos = nn.Linear(self.hidden_size, 1)
self.do_predict = nn.Linear(self.hidden_size, self.relation_number)
self.do_copy_linear = nn.Linear(self.hidden_size * 2, 1)

self.fuse = nn.Linear(self.hidden_size * 2, 100)
self.do_copy_linear = nn.Linear(100, 1)
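# the copy scorer is now two layers: fuse compresses each (decoder state, encoder output) pair to 100 dims, then do_copy_linear turns that into a scalar score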

def calc_context(self, decoder_state: torch.Tensor, encoder_outputs: torch.Tensor) -> torch.Tensor:

@@ -102,8 +109,9 @@ def calc_context(self, decoder_state: torch.Tensor, encoder_outputs: torch.Tenso
def do_copy(self, output: torch.Tensor, encoder_outputs: torch.Tensor) -> torch.Tensor:

out = torch.cat((output.unsqueeze(1).expand_as(encoder_outputs), encoder_outputs), dim=2)
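# shapes (illustrative): output (B, H) is expanded to (B, T, H) and concatenated with encoder_outputs (B, T, H), so out is (B, T, 2H)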
# out = F.selu(self.do_copy_linear(out).squeeze(2))
out = (self.do_copy_linear(out).squeeze(2))
out = F.selu(self.fuse(F.selu(out)))
out = self.do_copy_linear(out).squeeze(2)
# out = (self.do_copy_linear(out).squeeze(2))
return out

def _decode_step(self, rnn_cell: nn.modules,
@@ -139,7 +147,8 @@ def _decode_step(self, rnn_cell: nn.modules,

# assert copy_logits.size() == first_entity_mask.size()
# original
copy_logits = copy_logits * first_entity_mask
# copy_logits = copy_logits * first_entity_mask
# copy_logits = copy_logits

copy_logits = torch.cat((copy_logits, eos_logits), dim=1)
copy_logits = F.log_softmax(copy_logits, dim=1)
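# copy_logits now holds T + 1 entries per example: one score per source position plus the eos action, normalized with log-softmax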
@@ -223,14 +232,14 @@ def forward(self, sentence: torch.Tensor, decoder_state: torch.Tensor, encoder_o
output = self.relation_embedding(output)

else:
copy_index = torch.zeros_like(sentence).scatter_(1, max_action.unsqueeze(1), 1).to(torch.uint8)
copy_index = torch.zeros_like(sentence).scatter_(1, max_action.unsqueeze(1), 1).to(torch_bool)
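# copy_index is a one-hot boolean mask over source positions; sentence[copy_index] picks the copied token for each batch element, which is then fed back through the word embedding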
output = sentence[copy_index]
output = self.word_embedding(output)

if t % 3 == 1:
first_entity_mask = torch.ones(go.size()[0], self.maxlen + 1).to(self.device)

index = torch.zeros_like(first_entity_mask).scatter_(1, max_action.unsqueeze(1), 1).to(torch.uint8)
index = torch.zeros_like(first_entity_mask).scatter_(1, max_action.unsqueeze(1), 1).to(torch_bool)

first_entity_mask[index] = 0
first_entity_mask = first_entity_mask[:, :-1]
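# when the first entity has just been copied (t % 3 == 1), its position is zeroed in first_entity_mask so a later step could exclude it; note that applying this mask to copy_logits is commented out in _decode_step in this commit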
@@ -283,14 +292,14 @@ def forward(self, sentence: torch.Tensor, decoder_state: torch.Tensor, encoder_o
output = self.relation_embedding(output)

else:
copy_index = torch.zeros_like(sentence).scatter_(1, max_action.unsqueeze(1), 1).to(torch.uint8)
copy_index = torch.zeros_like(sentence).scatter_(1, max_action.unsqueeze(1), 1).to(torch_bool)
output = sentence[copy_index]
output = self.word_embedding(output)

if t % 3 == 1:
first_entity_mask = torch.ones(go.size()[0], self.maxlen + 1).to(self.device)

index = torch.zeros_like(first_entity_mask).scatter_(1, max_action.unsqueeze(1), 1).to(torch.uint8)
index = torch.zeros_like(first_entity_mask).scatter_(1, max_action.unsqueeze(1), 1).to(torch_bool)

first_entity_mask[index] = 0
first_entity_mask = first_entity_mask[:, :-1]
