
Commit

pep8
EmanuelaBoros committed Mar 3, 2023
1 parent 080e8af commit 3a886a8
Showing 5 changed files with 86 additions and 47 deletions.
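
The diff below consists of mechanical PEP 8 fixes: long calls and list literals are wrapped to fit the line-length limit (79 characters by default), the space before the colon in slices is dropped, two blank lines are inserted between top-level definitions, and a bare except: becomes except BaseException:. The commit does not say how the edits were produced; as an assumption, an automated pass along the following lines with autopep8 yields very similar output:

# Hypothetical sketch only: autopep8 is an assumption, the commit does not
# name the tool that was actually used.
import autopep8

changed_files = [
    "genre/entity_linking.py",
    "genre/fairseq_model.py",
    "genre/hf_model.py",
    "genre/trie.py",
    "genre/utils.py",
]

for path in changed_files:
    with open(path, encoding="utf-8") as f:
        source = f.read()
    # Aggressive mode also applies non-whitespace fixes such as rewriting
    # a bare `except:` to `except BaseException:` (E722).
    fixed = autopep8.fix_code(
        source,
        options={"aggressive": 1, "max_line_length": 79},
    )
    with open(path, "w", encoding="utf-8") as f:
        f.write(fixed)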
8 changes: 5 additions & 3 deletions genre/entity_linking.py
@@ -207,7 +207,7 @@ def get_trie_mention(sent, sent_orig):

pointer_start, _ = get_pointer_mention(sent)
if pointer_start + 1 < len(sent):
ment_next = mention_trie.get(sent[pointer_start + 1 :])
ment_next = mention_trie.get(sent[pointer_start + 1:])
else:
ment_next = mention_trie.get([])

@@ -217,7 +217,9 @@ def get_trie_mention(sent, sent_orig):
if sent_orig[pointer_end] != codes["EOS"]:
if sent_orig[pointer_end] in ment_next:
if codes["EOS"] in ment_next:
return [sent_orig[pointer_end], codes["end_mention_token"]]
return [
sent_orig[pointer_end],
codes["end_mention_token"]]
else:
return [sent_orig[pointer_end]]
elif codes["EOS"] in ment_next:
@@ -243,7 +245,7 @@ def get_trie_entity(sent, sent_orig):
pointer_start, pointer_end = get_pointer_mention(sent)

if pointer_start + 1 != pointer_end:
mention = decode_fn(sent[pointer_start + 1 : pointer_end]).strip()
mention = decode_fn(sent[pointer_start + 1: pointer_end]).strip()

if candidates_trie is not None:
candidates_trie_tmp = candidates_trie
24 changes: 18 additions & 6 deletions genre/fairseq_model.py
@@ -30,7 +30,11 @@ def sample(
**kwargs,
) -> List[str]:
if isinstance(sentences, str):
return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
return self.sample(
[sentences],
beam=beam,
verbose=verbose,
**kwargs)[0]
tokenized_sentences = [self.encode(sentence) for sentence in sentences]

batched_hypos = self.generate(
@@ -51,9 +55,11 @@ def sample(
]

outputs = post_process_wikidata(
outputs, text_to_id=text_to_id, marginalize=marginalize, batched_hypos=batched_hypos,
marginalize_lenpen=marginalize_lenpen
)
outputs,
text_to_id=text_to_id,
marginalize=marginalize,
batched_hypos=batched_hypos,
marginalize_lenpen=marginalize_lenpen)

return outputs

@@ -72,12 +78,15 @@ def encode(self, sentence) -> torch.LongTensor:
else:
return tokens


class GENREHubInterface(_GENREHubInterface, BARTHubInterface):
pass



class mGENREHubInterface(_GENREHubInterface, BARTHubInterface):
pass


class GENRE(BARTModel):
@classmethod
def from_pretrained(
@@ -101,6 +110,7 @@ def from_pretrained(
)
return GENREHubInterface(x["args"], x["task"], x["models"][0])


class mGENRE(BARTModel):
@classmethod
def from_pretrained(
@@ -122,7 +132,9 @@ def from_pretrained(
archive_map=cls.hub_models(),
bpe=bpe,
load_checkpoint_heads=True,
sentencepiece_model=os.path.join(model_name_or_path, sentencepiece_model),
sentencepiece_model=os.path.join(
model_name_or_path,
sentencepiece_model),
**kwargs,
)
return mGENREHubInterface(x["args"], x["task"], x["models"][0])
5 changes: 3 additions & 2 deletions genre/hf_model.py
@@ -51,7 +51,7 @@ def sample(

outputs = chunk_it(
[
{"text": text, "score": score,}
{"text": text, "score": score, }
for text, score in zip(
self.tokenizer.batch_decode(
outputs.sequences, skip_special_tokens=True
@@ -92,5 +92,6 @@ class mGENRE(MBartForConditionalGeneration):
@classmethod
def from_pretrained(cls, model_name_or_path):
model = mGENREHubInterface.from_pretrained(model_name_or_path)
model.tokenizer = XLMRobertaTokenizer.from_pretrained(model_name_or_path)
model.tokenizer = XLMRobertaTokenizer.from_pretrained(
model_name_or_path)
return model
6 changes: 4 additions & 2 deletions genre/trie.py
@@ -35,8 +35,10 @@ def add(self, sequence: List[int]):

def get(self, prefix_sequence: List[int]):
return Trie._get_from_trie(
prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
)
prefix_sequence,
self.trie_dict,
self.append_trie,
self.bos_token_id)

@staticmethod
def load_from_dict(trie_dict):
90 changes: 56 additions & 34 deletions genre/utils.py
@@ -22,7 +22,8 @@
def chunk_it(seq, num):
assert num > 0
chunk_len = len(seq) // num
chunks = [seq[i * chunk_len : i * chunk_len + chunk_len] for i in range(num)]
chunks = [seq[i * chunk_len: i * chunk_len + chunk_len]
for i in range(num)]

diff = len(seq) - chunk_len * num
for i in range(diff):
@@ -57,31 +58,29 @@ def create_input(doc, max_length, start_delimiter, end_delimiter):
)
elif len(doc["meta"]["left_context"].split(" ")) <= max_length // 2:
input_ = (
doc["meta"]["left_context"]
+ " {} ".format(start_delimiter)
+ doc["meta"]["mention"]
+ " {} ".format(end_delimiter)
+ " ".join(
doc["meta"]["left_context"] +
" {} ".format(start_delimiter) +
doc["meta"]["mention"] +
" {} ".format(end_delimiter) +
" ".join(
doc["meta"]["right_context"].split(" ")[
: max_length - len(doc["meta"]["left_context"].split(" "))
]
)
)
: max_length -
len(
doc["meta"]["left_context"].split(" "))]))
elif len(doc["meta"]["right_context"].split(" ")) <= max_length // 2:
input_ = (
" ".join(
doc["meta"]["left_context"].split(" ")[
len(doc["meta"]["right_context"].split(" ")) - max_length :
]
)
+ " {} ".format(start_delimiter)
+ doc["meta"]["mention"]
+ " {} ".format(end_delimiter)
+ doc["meta"]["right_context"]
)
len(
doc["meta"]["right_context"].split(" ")) -
max_length:]) +
" {} ".format(start_delimiter) +
doc["meta"]["mention"] +
" {} ".format(end_delimiter) +
doc["meta"]["right_context"])
else:
input_ = (
" ".join(doc["meta"]["left_context"].split(" ")[-max_length // 2 :])
" ".join(doc["meta"]["left_context"].split(" ")[-max_length // 2:])
+ " {} ".format(start_delimiter)
+ doc["meta"]["mention"]
+ " {} ".format(end_delimiter)
@@ -120,7 +119,8 @@ def get_entity_spans_post_processing(sentences):
sent = re.sub(r"\. \. \} \[ (.*?) \]", r". } [ \1 ] .", sent)
sent = re.sub(r"\, \} \[ (.*?) \]", r" } [ \1 ] ,", sent)
sent = re.sub(r"\; \} \[ (.*?) \]", r" } [ \1 ] ;", sent)
sent = sent.replace("{ ", "{").replace(" } [ ", "}[").replace(" ]", "]")
sent = sent.replace(
"{ ", "{").replace(" } [ ", "}[").replace(" ]", "]")
outputs.append(sent)

return outputs
@@ -187,7 +187,10 @@ def get_entity_spans_hf(
)


def get_entity_spans_finalize(input_sentences, output_sentences, redirections=None):
def get_entity_spans_finalize(
input_sentences,
output_sentences,
redirections=None):

return_outputs = []
for input_, output_ in zip(input_sentences, output_sentences):
@@ -268,7 +271,7 @@ def get_markdown(sentences, entity_spans):
for begin, length, href in entities:
text += sent[last_end:begin]
text += "[{}](https://en.wikipedia.org/wiki/{})".format(
sent[begin : begin + length], href
sent[begin: begin + length], href
)
last_end = begin + length

@@ -338,9 +341,8 @@ def get_micro_recall(guess_entities, gold_entities, mode="strong"):
def get_micro_f1(guess_entities, gold_entities, mode="strong"):
precision = get_micro_precision(guess_entities, gold_entities, mode)
recall = get_micro_recall(guess_entities, gold_entities, mode)
return (
(2 * (precision * recall) / (precision + recall)) if precision + recall else 0
)
return ((2 * (precision * recall) / (precision + recall))
if precision + recall else 0)


def get_doc_level_guess_gold_entities(guess_entities, gold_entities):
@@ -382,8 +384,10 @@ def get_macro_f1(guess_entities, gold_entities, mode="strong"):
guess_entities, gold_entities
)
all_scores = [
get_micro_f1(guess_entities[k], gold_entities[k], mode) for k in guess_entities
]
get_micro_f1(
guess_entities[k],
gold_entities[k],
mode) for k in guess_entities]
return (sum(all_scores) / len(all_scores)) if len(all_scores) else 0


@@ -400,16 +404,15 @@ def extract_pages(filename):
# CASE 2: end of the document
elif line.startswith("</doc>"):
assert doc["id"] not in docs, "{} ({}) already in dict as {}".format(
doc["id"], doc["title"], docs[doc["id"]]["title"]
)
doc["id"], doc["title"], docs[doc["id"]]["title"])
docs[doc["id"]] = doc

# CASE 3: in the document
else:
doc["paragraphs"].append("")
try:
line = BeautifulSoup(line, "html.parser")
except:
except BaseException:
print("error line `{}`".format(line))
line = [line]

@@ -466,7 +469,11 @@ def search_wikidata(query, label_alias2wikidataID):


def get_wikidata_ids(
anchor, lang, lang_title2wikidataID, lang_redirect2title, label_or_alias2wikidataID,
anchor,
lang,
lang_title2wikidataID,
lang_redirect2title,
label_or_alias2wikidataID,
):
success, result = search_simple(anchor, lang, label_or_alias2wikidataID)
if success:
@@ -478,7 +485,8 @@ def get_wikidata_ids(
if success:
return result, "wikipedia"
else:
return search_wikidata(result, label_or_alias2wikidataID), "wikidata"
return search_wikidata(
result, label_or_alias2wikidataID), "wikidata"


def post_process_wikidata(outputs, text_to_id=False, marginalize=False,
@@ -491,7 +499,9 @@
]

if marginalize:
for (i, hypos), hypos_tok in zip(enumerate(outputs), batched_hypos):
for (
i, hypos), hypos_tok in zip(
enumerate(outputs), batched_hypos):
outputs_dict = defaultdict(list)
for hypo, hypo_tok in zip(hypos, hypos_tok):
outputs_dict[hypo["id"]].append(
@@ -522,7 +532,19 @@ def post_process_wikidata(outputs, text_to_id=False, marginalize=False,
return outputs


tr2016_langs = ["ar", "de", "es", "fr", "he", "it", "ta", "th", "tl", "tr", "ur", "zh"]
tr2016_langs = [
"ar",
"de",
"es",
"fr",
"he",
"it",
"ta",
"th",
"tl",
"tr",
"ur",
"zh"]

news_langs = [
"ar",
