Hypernym filtering (#52)
* ✨ Add ability to filter hypernyms from CDR

* ✅ Add tests for hypernym filtering

* ⬆️ Upgrade dependencies

* ♻️ Call LRU cache with default arguments
JohnGiorgi authored Feb 15, 2022
1 parent 2127c50 commit 96aa592
Showing 4 changed files with 181 additions and 33 deletions.
56 changes: 28 additions & 28 deletions poetry.lock

Some generated files are not rendered by default.

18 changes: 17 additions & 1 deletion seq2rel_ds/common/schemas.py
@@ -1,7 +1,7 @@
import copy
import json
import random
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

from pydantic import BaseModel
from seq2rel_ds.common import sorting_utils, special_tokens
@@ -56,6 +56,7 @@ class label.
pmid: str
clusters: Dict[str, PubtatorCluster] = {}
relations: List[Tuple[str, ...]] = []
filtered_relations: Optional[List[Tuple[str, ...]]] = None

def insert_hints(self, sort: bool = True) -> None:
"""Inserts entity hints into the beginning of `self.text`. This effectively turns the
@@ -91,6 +92,18 @@ def to_string(self, sort: bool = True) -> str:
relation_strings.append(relation_string)
relation_offsets.append(entity_offsets)

if self.filtered_relations is not None:
if self.filtered_relations:
filtered_relation_strings = []
for rel in self.filtered_relations:
entity_strings = [self.clusters[ent_id].to_string() for ent_id in rel[:-1]]
relation_string = sanitize_text(
f'{" ".join(entity_strings)} @{rel[-1].upper()}@'
)
filtered_relation_strings.append(relation_string)
else:
filtered_relation_strings = ["null"]

# Optionally, sort by order of first appearance.
# This exists mainly for ablation, so we randomly shuffle relations if sort=False.
if relation_strings:
@@ -113,6 +126,9 @@ def to_string(self, sort: bool = True) -> str:
relation_strings = list(dict.fromkeys(relation_strings))
# Create the linearized relation string
relation_string = " ".join(relation_strings).strip()
# Optionally, append the relations to filter to this string, separated by a tab.
if self.filtered_relations is not None:
relation_string = f'{relation_string}\t{" ".join(filtered_relation_strings).strip()}'
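# Illustrative sketch, not part of this commit: with filtering enabled, the returned value
# has two tab-separated fields,
#   "<relation strings>\t<filtered relation strings>"
# where each relation string is "<entity strings> @<REL_LABEL>@" and the second field is the
# literal "null" when there is nothing to filter.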
return relation_string


99 changes: 95 additions & 4 deletions seq2rel_ds/preprocess/bc5cdr.py
@@ -1,20 +1,39 @@
import itertools
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import requests
import typer
from seq2rel_ds import msg
from seq2rel_ds.common import util
from seq2rel_ds.common.schemas import PubtatorAnnotation
from seq2rel_ds.common.util import EntityHinting

app = typer.Typer()

BC5CDR_URL = "https://biocreative.bioinformatics.udel.edu/media/store/files/2016/CDR_Data.zip"
MESH_TREE_URL = (
"https://github.com/fenchri/edge-oriented-graph/raw/master/data_processing/2017MeshTree.txt"
)
PARENT_DIR = "CDR_Data/CDR.Corpus.v010516"
TRAIN_FILENAME = "CDR_TrainingSet.PubTator.txt"
VALID_FILENAME = "CDR_DevelopmentSet.PubTator.txt"
TEST_FILENAME = "CDR_TestSet.PubTator.txt"


@lru_cache()
def _download_mesh_tree() -> Dict[str, List[str]]:
"""Downloads the MeSH tree and returns a dictionary mapping MeSH unique IDs to tree numbers."""
parsed_mesh_tree = defaultdict(list)
raw_mesh_tree = requests.get(MESH_TREE_URL).text.strip().splitlines()[1:]
for line in raw_mesh_tree:
tree_numbers, mesh_unique_id, _ = line.split("\t")
parsed_mesh_tree[mesh_unique_id].append(tree_numbers)
return parsed_mesh_tree
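

# Illustrative sketch, not part of this commit: after skipping the header row, each data line
# of 2017MeshTree.txt is expected to be tab-separated as "<tree number>\t<MeSH unique ID>\t<name>".
# Because one MeSH ID can have several tree numbers, they accumulate in a list, e.g.
# (hypothetical values)
#   mesh_tree = _download_mesh_tree()
#   mesh_tree["D012345"]  # -> ["C06.552.380", "C25.100.500"]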


def _download_corpus() -> Tuple[str, str, str]:
z = util.download_zip(BC5CDR_URL)
train = z.read(str(Path(PARENT_DIR) / TRAIN_FILENAME)).decode()
@@ -24,17 +43,76 @@ def _download_corpus() -> Tuple[str, str, str]:
return train, valid, test


def _filter_hypernyms(pubtator_annotations: List[PubtatorAnnotation]) -> None:
"""For each document in `pubtator_annotations`, determines any possible negative relations
which are hypernyms of the positive relations. If found, these are appended to
`pubtator_annotations.filtered_relations`.
"""
# Download the MeSH tree which allows us to determine hypernyms for disease entities.
mesh_tree = _download_mesh_tree()

# Determine the entity and relation labels by looping until we find a document with relations.
for annotation in pubtator_annotations:
if annotation.relations:
chem_id, diso_id, rel_label = annotation.relations[0]
chem_label = annotation.clusters[chem_id].label
diso_label = annotation.clusters[diso_id].label
break

for annotation in pubtator_annotations:
# We will add this attribute to each annotation, regardless of whether or not it has
# relations to filter. This will mean that all examples in the dataset will be formatted
# the same way, which simplifies data loading.
annotation.filtered_relations = []
# Determine the negative relations by taking the set of the product of all unique chemical
# and disease entities, minus the set of all positive relations.
chemicals = [
ent_id for ent_id, ann in annotation.clusters.items() if ann.label == chem_label
]
diseases = [
ent_id for ent_id, ann in annotation.clusters.items() if ann.label == diso_label
]
all_relations = [
(chem, diso, rel_label) for chem, diso in itertools.product(chemicals, diseases)
]
negative_relations = list(set(all_relations) - set(annotation.relations))
# If any negative relation contains a chemical entity that matches the chemical entity of
# a positive relation AND its disease entity is a hypernym of the positive relation's disease
# entity, this negative relation should be filtered.
for neg_chem, neg_diso, _ in negative_relations:
for pos_chem, pos_diso, _ in annotation.relations:
if neg_chem == pos_chem:
if any(
neg_tree_number in pos_tree_number
for pos_tree_number in mesh_tree[pos_diso]
for neg_tree_number in mesh_tree[neg_diso]
):
filtered_rel = (neg_chem, neg_diso, rel_label)
if filtered_rel not in annotation.filtered_relations:
annotation.filtered_relations.append(filtered_rel)


def _preprocess(
pubtator_content: str,
sort_rels: bool = True,
entity_hinting: Optional[EntityHinting] = None,
filter_hypernyms: bool = False,
) -> List[str]:
kwargs = {"concepts": ["chemical", "disease"], "skip_malformed": True} if entity_hinting else {}

pubtator_annotations = util.parse_pubtator(
pubtator_content=pubtator_content,
text_segment=util.TextSegment.both,
)

# This is unique to the BC5CDR corpus, which contains many negative relations that are
# actually valid but are not annotated, because they contain a disease entity which is a
# hypernym of a disease entity in a positive relation. We need to filter these out before
# evaluation, so this function finds all such cases and adds them to the filtered_relations
# field of the annotations.
if filter_hypernyms:
_filter_hypernyms(pubtator_annotations)
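# Illustrative sketch, not part of this commit: after this call every annotation carries a
# (possibly empty) list of relations to filter, e.g. (made-up IDs and relation label)
#   pubtator_annotations[0].filtered_relations  # -> [] or [("D001", "D003", rel_label), ...]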

seq2rel_annotations = util.pubtator_to_seq2rel(
pubtator_annotations,
sort_rels=sort_rels,
@@ -59,6 +137,9 @@ def main(
),
case_sensitive=False,
),
combine_train_valid: bool = typer.Option(
False, help="Combine the train and validation sets into one train set."
),
) -> None:
"""Download and preprocess the BC5CDR corpus for use with seq2rel."""
msg.divider("Preprocessing BC5CDR")
@@ -75,16 +156,26 @@ def main(
msg.info("Entity hints will be inserted into the source text using the gold annotations.")

with msg.loading("Preprocessing the data..."):
if combine_train_valid:
msg.info("Training and validation sets will be combined into one train set.")
train_raw = f"{train_raw.strip()}\n\n{valid_raw.strip()}"
valid = None
else:
valid = _preprocess(
valid_raw, sort_rels=sort_rels, entity_hinting=entity_hinting, filter_hypernyms=True
)
train = _preprocess(train_raw, sort_rels=sort_rels, entity_hinting=entity_hinting)
valid = _preprocess(valid_raw, sort_rels=sort_rels, entity_hinting=entity_hinting)
test = _preprocess(test_raw, sort_rels=sort_rels, entity_hinting=entity_hinting)
test = _preprocess(
test_raw, sort_rels=sort_rels, entity_hinting=entity_hinting, filter_hypernyms=True
)
msg.good("Preprocessed the data.")

output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

(output_dir / "train.tsv").write_text("\n".join(train))
(output_dir / "valid.tsv").write_text("\n".join(valid))
if valid:
(output_dir / "valid.tsv").write_text("\n".join(valid))
(output_dir / "test.tsv").write_text("\n".join(test))
msg.good(f"Preprocessed data saved to {output_dir.resolve()}.")

