Skip to content

Commit

Permalink
add option to remove overlapping entities and relations with overlapp…
Browse files Browse the repository at this point in the history
…ing entities during evaluation
  • Loading branch information
markus-eberts committed Apr 7, 2020
1 parent 2538ab4 commit 7ca6278
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 7 deletions.
2 changes: 2 additions & 0 deletions args.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def _add_common_args(arg_parser):
arg_parser.add_argument('--size_embedding', type=int, default=25, help="Dimensionality of size embedding")
arg_parser.add_argument('--prop_drop', type=float, default=0.1, help="Probability of dropout used in SpERT")
arg_parser.add_argument('--freeze_transformer', action='store_true', default=False, help="Freeze BERT weights")
arg_parser.add_argument('--no_overlapping', action='store_true', default=False,
help="Do not evaluate on overlapping entities and relations with overlapping entities")

# Misc
arg_parser.add_argument('--seed', type=int, default=None, help="Seed")
Expand Down
48 changes: 43 additions & 5 deletions spert/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@

class Evaluator:
def __init__(self, dataset: Dataset, input_reader: JsonInputReader, text_encoder: BertTokenizer,
rel_filter_threshold: float, predictions_path: str, examples_path: str, example_count: int,
epoch: int, dataset_label: str):
rel_filter_threshold: float, no_overlapping: bool,
predictions_path: str, examples_path: str, example_count: int, epoch: int, dataset_label: str):
self._text_encoder = text_encoder
self._input_reader = input_reader
self._dataset = dataset
self._rel_filter_threshold = rel_filter_threshold
self._no_overlapping = no_overlapping

self._epoch = epoch
self._dataset_label = dataset_label
Expand Down Expand Up @@ -85,7 +86,6 @@ def eval_batch(self, batch_entity_clf: torch.tensor, batch_rel_clf: torch.tensor
# convert predicted relations for evaluation
sample_pred_relations = self._convert_pred_relations(rel_types, rel_entity_spans,
rel_entity_types, rel_scores)
self._pred_relations.append(sample_pred_relations)

# get entities that are not classified as 'None'
valid_entity_indices = entity_types.nonzero().view(-1)
Expand All @@ -96,7 +96,13 @@ def eval_batch(self, batch_entity_clf: torch.tensor, batch_rel_clf: torch.tensor

sample_pred_entities = self._convert_pred_entities(valid_entity_types, valid_entity_spans,
valid_entity_scores)

if self._no_overlapping:
sample_pred_entities, sample_pred_relations = self._remove_overlapping(sample_pred_entities,
sample_pred_relations)

self._pred_entities.append(sample_pred_entities)
self._pred_relations.append(sample_pred_relations)

def compute_scores(self):
print("Evaluation")
Expand Down Expand Up @@ -242,11 +248,15 @@ def _convert_gt(self, docs: List[Document]):
gt_entities = doc.entities

# convert ground truth relations and entities for precision/recall/f1 evaluation
sample_gt_relations = [rel.as_tuple() for rel in gt_relations]
sample_gt_entities = [entity.as_tuple() for entity in gt_entities]
sample_gt_relations = [rel.as_tuple() for rel in gt_relations]

if self._no_overlapping:
sample_gt_entities, sample_gt_relations = self._remove_overlapping(sample_gt_entities,
sample_gt_relations)

self._gt_relations.append(sample_gt_relations)
self._gt_entities.append(sample_gt_entities)
self._gt_relations.append(sample_gt_relations)

def _convert_pred_entities(self, pred_types: torch.tensor, pred_spans: torch.tensor, pred_scores: torch.tensor):
converted_preds = []
Expand Down Expand Up @@ -290,6 +300,34 @@ def _convert_pred_relations(self, pred_rel_types: torch.tensor, pred_entity_span

return converted_rels

def _remove_overlapping(self, entities, relations):
non_overlapping_entities = []
non_overlapping_relations = []

for entity in entities:
if not self._is_overlapping(entity, entities):
non_overlapping_entities.append(entity)

for rel in relations:
e1, e2 = rel[0], rel[1]
if not self._check_overlap(e1, e2):
non_overlapping_relations.append(rel)

return non_overlapping_entities, non_overlapping_relations

def _is_overlapping(self, e1, entities):
for e2 in entities:
if self._check_overlap(e1, e2):
return True

return False

def _check_overlap(self, e1, e2):
if e1 == e2 or e1[1] <= e2[0] or e2[1] <= e1[0]:
return False
else:
return True

def _adjust_rel(self, rel: Tuple):
adjusted_rel = rel
if rel[-1].symmetric:
Expand Down
4 changes: 2 additions & 2 deletions spert/spert_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _eval(self, model: torch.nn.Module, dataset: Dataset, input_reader: JsonInpu

# create evaluator
evaluator = Evaluator(dataset, input_reader, self._tokenizer,
self.args.rel_filter_threshold, self._predictions_path,
self.args.rel_filter_threshold, self.args.no_overlapping, self._predictions_path,
self._examples_path, self.args.example_count, epoch, dataset.label)

# create data loader
Expand Down Expand Up @@ -250,7 +250,7 @@ def _eval(self, model: torch.nn.Module, dataset: Dataset, input_reader: JsonInpu
self._log_eval(*ner_eval, *rel_eval, *rel_nec_eval,
epoch, iteration, global_iteration, dataset.label)

if self.args.store_predictions:
if self.args.store_predictions and not self.args.no_overlapping:
evaluator.store_predictions()

if self.args.store_examples:
Expand Down

0 comments on commit 7ca6278

Please sign in to comment.