Skip to content

Commit

Permalink
renamed NER to NEC for the relation extraction metric that also consi…
Browse files Browse the repository at this point in the history
…ders entity types, print metric description
  • Loading branch information
markus-eberts committed Jan 22, 2020
1 parent 5385234 commit 38a27a8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 26 deletions.
31 changes: 19 additions & 12 deletions spert/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,24 +99,31 @@ def compute_scores(self):
print("Evaluation")

print("")
print("--- Entities (NER) ---")
print("--- Entities (named entity recognition (NER)) ---")
print("An entity is considered correct if the entity type and span is predicted correctly")
print("")
gt, pred = self._convert_by_setting(self._gt_entities, self._pred_entities, include_entity_types=True)
ner_eval = self._score(gt, pred, print_results=True)

print("")
print("--- Relations ---")
print("")
print("Without NER")
print("Without named entity classification (NEC)")
print("A relation is considered correct if the relation type and the spans of the two "
"related entities are predicted correctly (entity type is not considered)")
print("")
gt, pred = self._convert_by_setting(self._gt_relations, self._pred_relations, include_entity_types=False)
rel_eval = self._score(gt, pred, print_results=True)

print("")
print("With NER")
print("With named entity classification (NEC)")
print("A relation is considered correct if the relation type and the two "
"related entities are predicted correctly (in span and entity type)")
print("")
gt, pred = self._convert_by_setting(self._gt_relations, self._pred_relations, include_entity_types=True)
rel_ner_eval = self._score(gt, pred, print_results=True)
rel_nec_eval = self._score(gt, pred, print_results=True)

return ner_eval, rel_eval, rel_ner_eval
return ner_eval, rel_eval, rel_nec_eval

def store_examples(self):
if jinja2 is None:
Expand All @@ -125,7 +132,7 @@ def store_examples(self):

entity_examples = []
rel_examples = []
rel_examples_ner = []
rel_examples_nec = []

for i, doc in enumerate(self._dataset.documents):
# entities
Expand All @@ -140,9 +147,9 @@ def store_examples(self):
rel_examples.append(rel_example)

# with entity types
rel_example_ner = self._convert_example(doc, self._gt_relations[i], self._pred_relations[i],
rel_example_nec = self._convert_example(doc, self._gt_relations[i], self._pred_relations[i],
include_entity_types=True, to_html=self._rel_to_html)
rel_examples_ner.append(rel_example_ner)
rel_examples_nec.append(rel_example_nec)

label, epoch = self._dataset_label, self._epoch

Expand All @@ -168,13 +175,13 @@ def store_examples(self):
template='relation_examples.html')

# with entity types
self._store_examples(rel_examples_ner[:self._example_count],
file_path=self._examples_path % ('rel_ner', label, epoch),
self._store_examples(rel_examples_nec[:self._example_count],
file_path=self._examples_path % ('rel_nec', label, epoch),
template='relation_examples.html')

self._store_examples(sorted(rel_examples_ner[:self._example_count],
self._store_examples(sorted(rel_examples_nec[:self._example_count],
key=lambda k: k['length']),
file_path=self._examples_path % ('rel_ner_sorted', label, epoch),
file_path=self._examples_path % ('rel_nec_sorted', label, epoch),
template='relation_examples.html')

def _convert_gt(self, docs: List[Document]):
Expand Down
28 changes: 14 additions & 14 deletions spert/spert_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ def _eval(self, model: torch.nn.Module, dataset: Dataset, input_reader: JsonInpu
evaluator.eval_batch(entity_clf, rel_clf, rels, batch)

global_iteration = epoch * updates_epoch + iteration
ner_eval, rel_eval, rel_ner_eval = evaluator.compute_scores()
self._log_eval(*ner_eval, *rel_eval, *rel_ner_eval,
ner_eval, rel_eval, rel_nec_eval = evaluator.compute_scores()
self._log_eval(*ner_eval, *rel_eval, *rel_nec_eval,
epoch, iteration, global_iteration, dataset.label)

if self.args.store_examples:
Expand Down Expand Up @@ -281,8 +281,8 @@ def _log_eval(self, ner_prec_micro: float, ner_rec_micro: float, ner_f1_micro: f
rel_prec_micro: float, rel_rec_micro: float, rel_f1_micro: float,
rel_prec_macro: float, rel_rec_macro: float, rel_f1_macro: float,

rel_ner_prec_micro: float, rel_ner_rec_micro: float, rel_ner_f1_micro: float,
rel_ner_prec_macro: float, rel_ner_rec_macro: float, rel_ner_f1_macro: float,
rel_nec_prec_micro: float, rel_nec_rec_micro: float, rel_nec_f1_micro: float,
rel_nec_prec_macro: float, rel_nec_rec_macro: float, rel_nec_f1_macro: float,
epoch: int, iteration: int, global_iteration: int, label: str):

# log to tensorboard
Expand All @@ -300,12 +300,12 @@ def _log_eval(self, ner_prec_micro: float, ner_rec_micro: float, ner_f1_micro: f
self._log_tensorboard(label, 'eval/rel_recall_macro', rel_rec_macro, global_iteration)
self._log_tensorboard(label, 'eval/rel_f1_macro', rel_f1_macro, global_iteration)

self._log_tensorboard(label, 'eval/rel_ner_prec_micro', rel_ner_prec_micro, global_iteration)
self._log_tensorboard(label, 'eval/rel_ner_recall_micro', rel_ner_rec_micro, global_iteration)
self._log_tensorboard(label, 'eval/rel_ner_f1_micro', rel_ner_f1_micro, global_iteration)
self._log_tensorboard(label, 'eval/rel_ner_prec_macro', rel_ner_prec_macro, global_iteration)
self._log_tensorboard(label, 'eval/rel_ner_recall_macro', rel_ner_rec_macro, global_iteration)
self._log_tensorboard(label, 'eval/rel_ner_f1_macro', rel_ner_f1_macro, global_iteration)
self._log_tensorboard(label, 'eval/rel_nec_prec_micro', rel_nec_prec_micro, global_iteration)
self._log_tensorboard(label, 'eval/rel_nec_recall_micro', rel_nec_rec_micro, global_iteration)
self._log_tensorboard(label, 'eval/rel_nec_f1_micro', rel_nec_f1_micro, global_iteration)
self._log_tensorboard(label, 'eval/rel_nec_prec_macro', rel_nec_prec_macro, global_iteration)
self._log_tensorboard(label, 'eval/rel_nec_recall_macro', rel_nec_rec_macro, global_iteration)
self._log_tensorboard(label, 'eval/rel_nec_f1_macro', rel_nec_f1_macro, global_iteration)

# log to csv
self._log_csv(label, 'eval', ner_prec_micro, ner_rec_micro, ner_f1_micro,
Expand All @@ -314,8 +314,8 @@ def _log_eval(self, ner_prec_micro: float, ner_rec_micro: float, ner_f1_micro: f
rel_prec_micro, rel_rec_micro, rel_f1_micro,
rel_prec_macro, rel_rec_macro, rel_f1_macro,

rel_ner_prec_micro, rel_ner_rec_micro, rel_ner_f1_micro,
rel_ner_prec_macro, rel_ner_rec_macro, rel_ner_f1_macro,
rel_nec_prec_micro, rel_nec_rec_micro, rel_nec_f1_micro,
rel_nec_prec_macro, rel_nec_rec_macro, rel_nec_f1_macro,
epoch, iteration, global_iteration)

def _log_datasets(self, input_reader):
Expand Down Expand Up @@ -350,6 +350,6 @@ def _init_eval_logging(self, label):
'ner_prec_macro', 'ner_rec_macro', 'ner_f1_macro',
'rel_prec_micro', 'rel_rec_micro', 'rel_f1_micro',
'rel_prec_macro', 'rel_rec_macro', 'rel_f1_macro',
'rel_ner_prec_micro', 'rel_ner_rec_micro', 'rel_ner_f1_micro',
'rel_ner_prec_macro', 'rel_ner_rec_macro', 'rel_ner_f1_macro',
'rel_nec_prec_micro', 'rel_nec_rec_micro', 'rel_nec_f1_micro',
'rel_nec_prec_macro', 'rel_nec_rec_macro', 'rel_nec_f1_macro',
'epoch', 'iteration', 'global_iteration']})

0 comments on commit 38a27a8

Please sign in to comment.