Skip to content

Commit

Permalink
add interactive ranking mode for humans (facebookresearch#1296)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexholdenmiller authored Nov 28, 2018
1 parent 09d82ae commit bdee9fc
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 8 deletions.
14 changes: 7 additions & 7 deletions parlai/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,13 @@ def report(self):
self.metrics['f1'] / max(1, self.metrics['f1_cnt']),
4
)
if self.flags['has_text_cands']:
for k in self.eval_pr:
m['hits@' + str(k)] = round_sigfigs(
self.metrics['hits@' + str(k)] /
max(1, self.metrics['hits@_cnt']),
3
)
if self.flags['has_text_cands']:
for k in self.eval_pr:
m['hits@' + str(k)] = round_sigfigs(
self.metrics['hits@' + str(k)] /
max(1, self.metrics['hits@_cnt']),
3
)
for k in self.metrics_list:
if self.metrics[k + '_cnt'] > 0 and k != 'correct' and k != 'f1':
m[k] = round_sigfigs(
Expand Down
2 changes: 1 addition & 1 deletion parlai/scripts/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Basic example which allows local human keyboard input to talk to a trained model.
"""Basic script which allows local human keyboard input to talk to a trained model.
Examples
--------
Expand Down
103 changes: 103 additions & 0 deletions parlai/scripts/interactive_rank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#!/usr/bin/env python3

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

"""Does human evaluation on a task with label_candidates.
The human can exit with Ctrl+C, after which the metrics are computed and displayed.
Examples
--------
.. code-block:: shell
python parlai/scripts/interactive_rank.py -t babi:task10k:1 -dt valid
When prompted, enter the index of the label_candidate you think is correct.
Candidates are shuffled for each example.
During datatype train, examples are randomly sampled with replacement; use
train:ordered to not repeat examples.
During datatype valid or test, examples are shown in order, not shuffled.
"""
from parlai.core.metrics import Metrics
from parlai.core.params import ParlaiParser
from parlai.core.agents import create_agent, create_task_agent_from_taskname

import random


def setup_args(parser=None):
    """Return the argument parser used by the interactive ranking script.

    :param parser: an existing parser to configure; when ``None`` a fresh
        ``ParlaiParser`` is created.
    :return: the same parser object, after ``set_params`` has been called on
        it with ``model`` pointing at the local human agent.
    """
    if parser is None:
        parser = ParlaiParser()
    # The "model" being evaluated here is a person typing at the keyboard.
    parser.set_params(model='parlai.agents.local_human.local_human:LocalHumanAgent')
    return parser


def _prompt_for_choice(human, cands):
    """Ask the human for a candidate index until a valid one is entered.

    :param human: agent whose ``act()`` returns a dict; its ``'text'`` entry
        is interpreted as the index typed by the user.
    :param cands: non-empty list of candidates, in the order they were shown.
    :return: the element of ``cands`` at the chosen index.
    """
    while True:
        raw = human.act().get('text')
        try:
            index = int(raw)
        except (TypeError, ValueError):
            # Covers both a missing 'text' entry (None) and non-numeric input.
            print('[ Try again: you did not enter a valid index. ]')
            continue
        if 0 <= index < len(cands):
            return cands[index]
        print('[ Try again: you selected {i} but the '
              'candidates are indexed from 0 to {j}. ]'
              ''.format(i=index, j=len(cands) - 1))


def interactive_rank(opt, print_parser=None):
    """Let a human rank label candidates for a task and report metrics.

    For every example of the task the text and a shuffled list of its
    ``label_candidates`` are printed; the human picks one by index and the
    pick is fed to ``Metrics.update`` as the sole entry of
    ``text_candidates``. The running metrics report is printed after each
    answered example and once more when the loop ends, either because the
    task reports ``epoch_done()`` or because the user pressed Ctrl+C.

    Examples that carry no ``label_candidates`` are skipped with a notice
    (there is nothing to choose from) and do not contribute to the metrics
    or to the episode count.

    :param opt: options passed to ``create_agent``,
        ``create_task_agent_from_taskname`` and ``Metrics``.
    :param print_parser: not used by this function; kept so existing callers
        that pass it keep working.
    """
    # Create the human "model" and the task agent that supplies the examples.
    human = create_agent(opt)
    task = create_task_agent_from_taskname(opt)[0]

    metrics = Metrics(opt)
    episodes = 0

    def print_metrics():
        # Reads the enclosing ``episodes`` at call time, so it always shows
        # the current count.
        report = metrics.report()
        report['episodes'] = episodes
        print(report)

    try:
        while not task.epoch_done():
            msg = task.act()
            print('[{id}]: {text}'.format(id=task.getID(), text=msg.get('text', '')))

            # Copy before shuffling so the task's own candidate list is left
            # untouched.
            cands = list(msg.get('label_candidates', []))
            if not cands:
                # Without candidates no index can ever be valid; prompting
                # would trap the user in an endless retry loop.
                print('[ Skipping example: it has no label_candidates. ]')
                continue
            random.shuffle(cands)
            for i, c in enumerate(cands):
                print('    [{i}]: {c}'.format(i=i, c=c))

            print('[ Please choose a response from the list. ]')
            choice = _prompt_for_choice(human, cands)

            print('[ You chose ]: {}'.format(choice))
            reply = {'text_candidates': [choice]}
            labels = msg.get('eval_labels', msg.get('labels'))
            metrics.update(reply, labels)

            episode_done = msg.get('episode_done')
            if episode_done:
                episodes += 1
            print_metrics()
            print('------------------------------')
            if labels:
                # Only reveal the answer when the example actually has one.
                print('[ True reply ]: {}'.format(labels[0]))
            if episode_done:
                print('******************************')

    except KeyboardInterrupt:
        # Ctrl+C is the documented way to stop; fall through to the final
        # metrics report.
        pass

    print()
    print_metrics()


if __name__ == '__main__':
    # Fixed seed so the order in which candidates are shuffled is the same
    # on every run.
    random.seed(42)
    parser = setup_args()
    interactive_rank(parser.parse_args(print_args=False))

0 comments on commit bdee9fc

Please sign in to comment.