forked from facebookresearch/ParlAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ACUTE-Eval] Fast Acute OSS Part 2 - Everything Else (facebookresearc…
…h#2573) * fast acute OSS * autoformat * readme changes * remove todos * typing * incorporate matchups-per-pair arg * Update README.md * readme update
- Loading branch information
Showing
7 changed files
with
1,017 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
""" | ||
Model Configuration file for Fast ACUTE Eval. | ||
CONFIG: Dict[str, Dict] | ||
- maps ids to their appropriate options | ||
- for models, please only include options that you would specify on the command line | ||
""" | ||
import os | ||
from typing import Dict | ||
|
||
ROOT_DIR = '/checkpoint/parlai/acute_evals/' | ||
CONFIG: Dict[str, Dict] = { | ||
'example_model_1': { | ||
'model_file': 'zoo:tutorial_transformer_generator/model', | ||
'model': 'transformer/generator', | ||
# general args | ||
'batchsize': 1, | ||
'skip_generation': False, | ||
'interactive_mode': False, | ||
'beam_size': 3, | ||
'beam_min_length': 3, | ||
'inference': 'beam', | ||
'beam_block_ngram': 3, | ||
'beam_context_block_ngram': 3, | ||
}, | ||
'example_model_2': { | ||
'model_file': 'zoo:tutorial_transformer_generator/model', | ||
'model': 'transformer/generator', | ||
# general args | ||
'batchsize': 1, | ||
'skip_generation': False, | ||
'interactive_mode': False, | ||
'inference': 'nucleus', | ||
'topp': 0.9, | ||
}, | ||
'example_model_log': { | ||
'log_path': f"{os.path.dirname(os.path.realpath(__file__))}/example/chat_log.jsonl" | ||
}, | ||
'example_dataset': {'task': 'convai2', 'prepended_context': True}, | ||
} |
140 changes: 140 additions & 0 deletions
140
parlai/mturk/tasks/acute_eval/dump_task_to_acute_format.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
""" | ||
Convert a ParlAI teacher to acute-eval format. | ||
Examples | ||
-------- | ||
.. code-block:: shell | ||
py parlai/mturk/tasks/acute_eval/dump_task_to_acute_format.py -t convai2 | ||
""" | ||
|
||
from parlai.core.params import ParlaiParser | ||
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent | ||
from parlai.core.worlds import create_task | ||
from parlai.utils.conversations import Conversations | ||
from parlai.utils.misc import TimeLogger | ||
import random | ||
import tempfile | ||
|
||
|
||
def setup_args(): | ||
""" | ||
Set up conversion args. | ||
""" | ||
parser = ParlaiParser() | ||
parser.add_argument( | ||
'-n', | ||
'--num-episodes', | ||
default=-1, | ||
type=int, | ||
help='Total number of episodes to convert, -1 to convert \ | ||
all examples', | ||
) | ||
parser.add_argument( | ||
'-of', | ||
'--outfile', | ||
default=None, | ||
type=str, | ||
help='Output file where to save, by default will be \ | ||
created in /tmp', | ||
) | ||
parser.add_argument( | ||
'-s1id', '--speaker-0-id', type=str, help='Speaker id of agent who speaks first' | ||
) | ||
parser.add_argument( | ||
'-s1id', | ||
'--speaker-1-id', | ||
type=str, | ||
help='Speaker id of agent who speaks second', | ||
) | ||
parser.add_argument( | ||
'--prepended-context', | ||
type='bool', | ||
default=False, | ||
help='specify if the context is prepended to the first act', | ||
) | ||
parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10) | ||
parser.set_defaults(datatype='train:ordered') | ||
|
||
return parser | ||
|
||
|
||
def dump_data(opt): | ||
""" | ||
Dump task data to ACUTE-Eval. | ||
""" | ||
# create repeat label agent and assign it to the specified task | ||
agent = RepeatLabelAgent(opt) | ||
world = create_task(opt, agent) | ||
task = opt.get('task') | ||
speaker_0_id = opt.get('speaker_0_id') or f'{task}_as_human' | ||
speaker_1_id = opt.get('speaker_1_id') or f'{task}_as_model' | ||
if opt['outfile'] is None: | ||
outfile = tempfile.mkstemp( | ||
prefix='{}_{}_'.format(opt['task'], opt['datatype']), suffix='.txt' | ||
)[1] | ||
else: | ||
outfile = opt['outfile'] | ||
|
||
num_episodes = ( | ||
world.num_episodes() | ||
if opt['num_episodes'] == -1 | ||
else min(opt['num_episodes'], world.num_episodes()) | ||
) | ||
log_timer = TimeLogger() | ||
|
||
print(f'[ starting to convert, saving output to {outfile} ]') | ||
dialogues = [] | ||
for _ in range(num_episodes): | ||
episode = [] | ||
episode_done = False | ||
while not episode_done: | ||
world.parley() | ||
acts = world.get_acts() | ||
text = acts[0].get('text') | ||
split_text = text.split('\n') | ||
label = random.choice( | ||
acts[0].get('labels', acts[0].pop('eval_labels', None)) | ||
) | ||
if not episode and opt.get('prepended_context'): | ||
# first turn | ||
context = split_text[:-1] | ||
text = split_text[-1] | ||
context_turn = [ | ||
{'text': context, 'episode_done': False, 'id': 'context'} | ||
for _ in range(2) | ||
] | ||
episode.append(context_turn) | ||
turn = [ | ||
{'text': text, 'episode_done': False, 'id': speaker_0_id}, | ||
{'text': label, 'episode_done': False, 'id': speaker_1_id}, | ||
] | ||
episode.append(turn) | ||
if acts[0].get('episode_done', False): | ||
episode[-1][-1]['episode_done'] = True | ||
episode_done = True | ||
dialogues.append(episode) | ||
|
||
if log_timer.time() > opt['log_every_n_secs']: | ||
text, _log = log_timer.log(world.total_parleys, world.num_examples()) | ||
print(text) | ||
|
||
if world.epoch_done(): | ||
break | ||
|
||
Conversations.save_conversations(dialogues, outfile, opt) | ||
|
||
|
||
def main(): | ||
random.seed(42) | ||
# Get command line arguments | ||
parser = setup_args() | ||
opt = parser.parse_args() | ||
dump_data(opt) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.