Skip to content

Commit

Permalink
Generative Wizard fixes & test (facebookresearch#1580)
Browse files Browse the repository at this point in the history
* Fix Wizard TwoStageAgent with old history format.
* Update readme.
* Add a GPU test for wizard.
  • Loading branch information
stephenroller authored Mar 26, 2019
1 parent c14d2b8 commit 8ab911a
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 38 deletions.
10 changes: 6 additions & 4 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ commands:
command: |
pip3 install --progress-bar off numpy
pip3 install --progress-bar off torch
pip3 install --progress-bar off 'git+https://github.com/rsennrich/subword-nmt.git#egg=subword-nmt' # bpe support
installtorchcpu:
description: "Install Torch (CPU)"
steps:
Expand Down Expand Up @@ -91,7 +92,7 @@ jobs:
- installtorchcpu
- run:
name: Data tests
command: python setup.py test -s tests.suites.datatests -q
command: python setup.py test -s tests.suites.datatests -v

unittests:
executor: standard_cpu
Expand All @@ -103,7 +104,7 @@ jobs:
- installtorchcpu
- run:
name: Unit tests
command: python setup.py test -s tests.suites.unittests -q
command: python setup.py test -s tests.suites.unittests -v

lint:
executor: standard_cpu
Expand Down Expand Up @@ -131,7 +132,8 @@ jobs:
- installdeps
- run:
name: Nightly GPU tests
command: python setup.py test -s tests.suites.nightly_gpu -q
no_output_timeout: 30m
command: python setup.py test -s tests.suites.nightly_gpu -v

nightly_cpu_tests:
executor: standard_cpu
Expand All @@ -143,7 +145,7 @@ jobs:
- installtorchcpu
- run:
name: All nightly CPU tests
command: python setup.py test -s tests.suites.nightly_cpu -q
command: python setup.py test -s tests.suites.nightly_cpu -v

deploy_website:
executor: standard_cpu
Expand Down
15 changes: 11 additions & 4 deletions projects/wizard_of_wikipedia/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,26 @@ Vanilla Transformer (no knowledge) | [Dinan et al. (2019)](https://arxiv.org/a
You can evaluate the pretrained End-to-end generative models via:

python examples/eval_model.py \
-bs 64 -t wizard_of_wikipedia:end2end_generator:random_split \
-mf models:wizard_of_wikipedia/wizard_generator/endtoend_model
-bs 64 -t wizard_of_wikipedia:generator:random_split \
-mf models:wizard_of_wikipedia/end2end_generator/model


This produces the following metrics:

{'f1': 0.1717, 'ppl': 61.21, 'know_acc': 0.2201, 'know_chance': 0.02625}

This differs slightly from the results in the paper, as it is a recreation trained
from scratch for public release.

You can also evaluate the model on the unseen topic split too:

python examples/eval_model.py \
-bs 64 -t wizard_of_wikipedia:end2end_generator:topic_split \
-mf models:wizard_of_wikipedia/wizard_generator/model
-bs 64 -t wizard_of_wikipedia:generator:topic_split \
-mf models:wizard_of_wikipedia/end2end_generator/model

This will produce:

{'f1': 0.1498, 'ppl': 103.1, 'know_acc': 0.1123, 'know_chance': 0.02496}

Check back soon for more pretrained models!

Expand Down
50 changes: 20 additions & 30 deletions projects/wizard_of_wikipedia/generator/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,48 +58,38 @@ def batchify(self, obs_batch):
checked_sentences.append(checked_sentence)

batch['checked_sentence'] = checked_sentences

return batch


class TwoStageAgent(_GenericWizardAgent):
def __init__(self, opt, shared):
def __init__(self, opt, shared=None):
super().__init__(opt, shared)
if shared is not None:
# make sure the dialogue token appears
self.dict[TOKEN_DIALOG] = 9999999

def observe(self, obs):
def _set_text_vec(self, obs, history, truncate):
if 'text' not in obs:
return obs

# TODO: resolve this with #1421
# get the dialog stuff
reply = self.last_reply()
self.observation = self.get_dialog_history(obs, reply=reply)
# we need to store the old text so that we can restore it
oldtext = obs['text']

# now we want to force prepend the knowledge stuff
fields = []
if 'chosen_topic' in obs:
fields += [obs['title']]
if 'checked_sentence' in obs:
fields += [TOKEN_KNOWLEDGE, obs['checked_sentence']]
if obs['text'] != '':
fields += [TOKEN_DIALOG, obs['text']]
obs['text'] = ' '.join(fields)

# now vectorize with the extra knowledge. It'll all get stored in the
# text_vec operation, etc
self.vectorize(
obs,
text_truncate=self.text_truncate,
label_truncate=self.label_truncate
)

# finally we need to return the old text to the way it was
obs['text'] = oldtext
assert obs is self.observation
if 'text_vec' not in obs:
fields = []
dialogue_history = history.get_history_str()
if 'chosen_topic' in obs:
fields += [obs['title']]
if 'checked_sentence' in obs:
fields += [TOKEN_KNOWLEDGE, obs['checked_sentence']]
if dialogue_history:
fields += [TOKEN_DIALOG, dialogue_history]
obs['text'] = ' '.join(fields)
obs['text_vec'] = self.dict.txt2vec(obs['text'])

# check truncation
if 'text_vec' in obs:
obs['text_vec'] = th.LongTensor(
self._check_truncate(obs['text_vec'], truncate, True)
)

return obs

Expand Down
51 changes: 51 additions & 0 deletions tests/nightly/gpu/test_wizard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import unittest
import parlai.scripts.display_data as display_data
import parlai.core.testing_utils as testing_utils

# Shared evaluation settings for the pretrained end-to-end generative
# Wizard of Wikipedia model (used for both the download and the eval run).
END2END_OPTIONS = dict(
    task='wizard_of_wikipedia:generator:random_split',
    model_file='models:wizard_of_wikipedia/end2end_generator/model',
    batchsize=32,
    log_every_n_secs=30,
    embedding_type='random',
)


@testing_utils.skipUnlessGPU
class TestWizardModel(unittest.TestCase):
    """Checks that the pretrained end-to-end generative Wizard of Wikipedia
    model can be downloaded and achieves the expected evaluation metrics.
    """

    @classmethod
    def setUpClass(cls):
        """Download the model file and task data once for the whole class."""
        # Displaying a single example forces the model/task downloads up
        # front, so individual tests don't pay (or time out on) that cost.
        with testing_utils.capture_output():
            parser = display_data.setup_args()
            parser.set_defaults(**END2END_OPTIONS)
            opt = parser.parse_args(print_args=False)
            opt['num_examples'] = 1
            display_data.display_data(opt)

    def test_end2end(self):
        """Evaluate the pretrained model and compare against the published
        validation metrics (see projects/wizard_of_wikipedia/README.md).
        """
        stdout, valid, _ = testing_utils.eval_model(END2END_OPTIONS)
        # Use approximate comparison for float metrics: exact float equality
        # is brittle across hardware and library versions.
        self.assertAlmostEqual(
            valid['ppl'], 61.21, places=2,
            msg='valid ppl = {}\nLOG:\n{}'.format(valid['ppl'], stdout)
        )
        self.assertAlmostEqual(
            valid['f1'], 0.1717, places=4,
            msg='valid f1 = {}\nLOG:\n{}'.format(valid['f1'], stdout)
        )
        self.assertGreaterEqual(
            valid['know_acc'], 0.2201,
            'valid know_acc = {}\nLOG:\n{}'.format(valid['know_acc'], stdout)
        )


if __name__ == '__main__':
    unittest.main()

0 comments on commit 8ab911a

Please sign in to comment.