Skip to content

Commit

Permalink
[bugfix]修复 coreference resolution复现代码中参数名字不对应的bug (fastnlp#323)
Browse files Browse the repository at this point in the history
* pipeline

* 修复找不到对应参数的bug

* 增加requirement文件
  • Loading branch information
Xiaoxiong-Liu authored Aug 29, 2020
1 parent b9b688d commit acdebfc
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 7 deletions.
3 changes: 2 additions & 1 deletion reproduction/coreference_resolution/model/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(self):
self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)]

# TODO 改名为evaluate,输入也
def evaluate(self, predicted, mention_to_predicted,clusters):
def evaluate(self, predicted, mention_to_predicted,target):
clusters = target
for e in self.evaluators:
e.update(predicted,mention_to_predicted, clusters)

Expand Down
29 changes: 23 additions & 6 deletions reproduction/coreference_resolution/model/model_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
np.random.seed(0) # numpy
random.seed(0)


class ffnn(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(ffnn, self).__init__()
Expand Down Expand Up @@ -565,19 +564,37 @@ def forward(self, words1 , words2, words3, words4, chars, seq_len):

return ans

def predict(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len):
def predict(self, words1 , words2, words3, words4, chars, seq_len):
"""
实际输入都是tensor
:param sentences: 句子,被fastNLP转化成了numpy,
:param doc_np: 被fastNLP转化成了Tensor
:param speaker_ids_np: 被fastNLP转化成了Tensor
:param genre: 被fastNLP转化成了Tensor
:param char_index: 被fastNLP转化成了Tensor
:param seq_len: 被fastNLP转化成了Tensor
:return:
"""

sentences = words1
doc_np = words2
speaker_ids_np = words3
genre = words4
char_index = chars

# def predict(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len):
ans = self(sentences,
doc_np,
speaker_ids_np,
genre,
char_index,
seq_len)

predicted_antecedents = self.get_predicted_antecedents(ans["antecedents"], ans["antecedent_scores"])
predicted_clusters, mention_to_predicted = self.get_predicted_clusters(ans["mention_start_tensor"],
ans["mention_end_tensor"],
predicted_antecedents = self.get_predicted_antecedents(ans["antecedents"], ans["antecedent_scores"].cpu())
predicted_clusters, mention_to_predicted = self.get_predicted_clusters(ans["mention_start_tensor"].cpu(),
ans["mention_end_tensor"].cpu(),
predicted_antecedents)


return {'predicted':predicted_clusters,"mention_to_predicted":mention_to_predicted}


Expand Down
5 changes: 5 additions & 0 deletions reproduction/coreference_resolution/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
prettytable==0.7.2
allennlp==0.8.2
scikit-learn==0.22.2
pyhocon==0.3.50
torch==1.1
2 changes: 2 additions & 0 deletions reproduction/coreference_resolution/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys
sys.path.append('../..')

import torch
from torch.optim import Adam
Expand Down

0 comments on commit acdebfc

Please sign in to comment.