Skip to content

Commit

Permalink
fixing reasoning perturbation issues (facebookresearch#4883)
Browse files Browse the repository at this point in the history
* fixing reasoning perturbation issues

* update looping and comments
  • Loading branch information
Golovneva authored Nov 17, 2022
1 parent 0f15b89 commit 4ffb29c
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 39 deletions.
32 changes: 26 additions & 6 deletions parlai/tasks/math_dataset/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.utils.io import PathManager
from typing import Optional
from typing import List, Optional

from parlai.tasks.reasoning.agents import MWPStepsReasoningTeacher

Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(self, opt, shared=None):
self.math_random = random.Random(42)
super().__init__(opt, shared)

def load_data(self, domains):
def load_data(self, domains) -> List[str]:
data = []
data_path = self.opt['datafile']
for domain in domains:
Expand Down Expand Up @@ -106,10 +106,7 @@ def get_data_for_fold(self, fold):
answer_blob = self._clean_steps(answer_blob)
steps = answer_blob.split(". ")
if extrinsic_step:
rand_steps = self._clean_steps(
self.math_random.choice(data)["solution"]
).split(". ")
random_step = self.math_random.choice(rand_steps)
random_step = self._find_nonempty_random_step(data)
if convert:
question = self._latex_conversion(question)
final_answer = self._latex_conversion(final_answer)
Expand Down Expand Up @@ -225,6 +222,29 @@ def _latex_conversion(self, final_answer: str) -> str:

return final_answer

def _find_nonempty_random_step(self, dataset: List[str]) -> str:
'''Here we *ASSUME* that the whole dataset contains at least one non-empty step
Otherwise it will go into infinite loop looking for the one
'''
# what we call an empty step
empty_steps = ["", " "]
# first find chain with at least one non-empty step
rand_steps = self._clean_steps(
self.math_random.choice(dataset)["solution"]
).split(". ")
# make sure this chain has at least one non-empty step
i = 0
while i < len(rand_steps) and rand_steps[i] in empty_steps:
i += 1
# if it doesn't, try again
if i == len(rand_steps):
return self._find_nonempty_random_step(dataset)
random_step = empty_steps[0]
# find non-empty random step (and we know it exists in this chain)
while random_step in empty_steps:
random_step = self.math_random.choice(rand_steps)
return random_step

def get_boxed_answer(self, answer):
boxed_idx = answer.find("boxed{")
final_answer = answer[boxed_idx:]
Expand Down
10 changes: 8 additions & 2 deletions parlai/tasks/proof_writer/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,14 @@ def get_data_for_fold(self, fold):

for m in messages:
if extrinsic_step:
rand_steps = self.proofwriter_random.choice(messages)["steps"]
random_step = self.proofwriter_random.choice(rand_steps)
random_step = None
# make sure new step is from a different context
# here we aasume that there is at least one step in the set
# with different context, otherwise it will go in the
# infinite loop
while not random_step or random_step in m["question"]:
rand_steps = self.proofwriter_random.choice(messages)["steps"]
random_step = self.proofwriter_random.choice(rand_steps)
m["extrinsic_step"] = random_step
yield m
else:
Expand Down
46 changes: 16 additions & 30 deletions parlai/tasks/reasoning/reason_types/step_by_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,29 +301,15 @@ def __init__(self, opt: Opt, cache: Optional[Dict[str, List[List[str]]]] = None)
super().__init__(opt, cache)
self.lemmatizer = WordNetLemmatizer()

def lemmatize_step(self, step):
try:
words = nltk.word_tokenize(str(step))
except IndexError:
print(
f"WARNING: could not lemmatize step {str(step)}. Proceeding to the next perturbation."
)
return str(step)
def lemmatize_step(self, words):
lemmatized_output = ' '.join([self.lemmatizer.lemmatize(w, 'v') for w in words])
# remove extraneous spaces after joining strings back
clean_lemmatized_output = re.sub(
r'\s([?.!"](?:\s|$))', r'\1', lemmatized_output
)
return clean_lemmatized_output

def drop_verb(self, step):
try:
words = nltk.word_tokenize(str(step))
except IndexError:
print(
f"WARNING: could not lemmatize step {str(step)}. Proceeding to the next perturbation."
)
return str(step)
def drop_verb(self, words):
tags = nltk.pos_tag(words)
verb_indices = []
for i, tag in enumerate(tags):
Expand All @@ -338,14 +324,7 @@ def drop_verb(self, step):
clean_result = re.sub(r'\s([?.!"](?:\s|$))', r'\1', result)
return clean_result

def swap_words(self, step):
try:
tokenized_step = nltk.word_tokenize(str(step))
except IndexError:
print(
f"WARNING: could not lemmatize step {str(step)}. Proceeding to next perturbation."
)
return str(step)
def swap_words(self, tokenized_step):
tags = nltk.pos_tag(tokenized_step)
word_indices = []
for i, tag in enumerate(tags):
Expand Down Expand Up @@ -374,15 +353,22 @@ def perturb(self, example_dict: Dict) -> Dict:
grammatical_error_steps = []

for i, step in enumerate(steps):
try:
tok_step = nltk.word_tokenize(str(step))
except IndexError:
print(
f"WARNING: could not tokenize step {str(step)}. Proceeding to next chain."
)
return str(step)
# perform all possible grammatical errors on each step, then randomly choose 1
lemmatized_step = self.lemmatize_step(step)
if str(step) != lemmatized_step:
lemmatized_step = self.lemmatize_step(tok_step)
if tok_step != lemmatized_step:
grammatical_error_steps.append((i, lemmatized_step))
dropped_verb_step = self.drop_verb(step)
if dropped_verb_step != "" and str(step) != dropped_verb_step:
dropped_verb_step = self.drop_verb(tok_step)
if dropped_verb_step != "" and tok_step != dropped_verb_step:
grammatical_error_steps.append((i, dropped_verb_step))
swapped_word_step = self.swap_words(step)
if swapped_word_step != "" and str(step) != swapped_word_step:
swapped_word_step = self.swap_words(tok_step)
if swapped_word_step != "" and tok_step != swapped_word_step:
grammatical_error_steps.append((i, swapped_word_step))

if not grammatical_error_steps:
Expand Down
4 changes: 3 additions & 1 deletion projects/roscoe/baselines/scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ def load(self, path=None):
)
# Path here to fine-tuend BART Model
try:
self.scorer.load(BART_SCORE_REPO + "/train/reproduce/trained/bart_6000.pth")
self.scorer.load(
BART_SCORE_REPO + "/train/reproduce/trained/fine_tuned_bartscore.pth"
)
except FileNotFoundError:
raise FileNotFoundError(
f"Path here should be to fine tuned BART model from"
Expand Down

0 comments on commit 4ffb29c

Please sign in to comment.