Skip to content

Commit

Permalink
Add MultiNLI task (facebookresearch#476)
Browse files Browse the repository at this point in the history
  • Loading branch information
apsdehal authored and alexholdenmiller committed Jan 3, 2018
1 parent f16622f commit f9ee62e
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 0 deletions.
5 changes: 5 additions & 0 deletions parlai/tasks/multinli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
73 changes: 73 additions & 0 deletions parlai/tasks/multinli/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from parlai.core.teachers import DialogTeacher
from .build import build

import os
import copy
import json


MULTINLI = 'MultiNLI'
MULTINLI_VERSION = '1.0'
MULTINLI_PREFIX = 'multinli_'
MULTINLI_PREMISE_PREFIX = 'Premise: '
MULTINLI_HYPO_PREFIX = 'Hypothesis: '
MULTINLI_LABELS = ['entailment', 'contradiction', 'neutral']
MULTINLI_PREMISE_KEY = 'sentence1'
MULTINLI_HYPO_KEY = 'sentence2'
MULTINLI_ANSWER_KEY = 'gold_label'


def _path(opt):
build(opt)

dt = opt['datatype'].split(':')[0]

if dt == 'train':
suffix = 'train'
# Using matched set as valid and mismatched set as test
elif dt == 'valid':
suffix = 'dev_matched'
elif dt == 'test':
suffix = 'dev_mismatched'
else:
raise RuntimeError('Not valid datatype.')

data_path = os.path.join(opt['datapath'], MULTINLI,
MULTINLI_PREFIX + MULTINLI_VERSION,
MULTINLI_PREFIX + MULTINLI_VERSION +
'_' + suffix + '.jsonl')

return data_path


class DefaultTeacher(DialogTeacher):
def __init__(self, opt, shared=None):
opt = copy.deepcopy(opt)
data_path = _path(opt)
opt['datafile'] = data_path
self.id = 'MultiNLI'

super().__init__(opt, shared)

def setup_data(self, path):
print('loading: ' + path)

with open(path, 'r') as data_file:
for pair_line in data_file:
pair = json.loads(pair_line)
premise = MULTINLI_PREMISE_PREFIX + pair[MULTINLI_PREMISE_KEY]
hypo = MULTINLI_HYPO_PREFIX + pair[MULTINLI_HYPO_KEY]
answer = [pair[MULTINLI_ANSWER_KEY]]

if answer == '-':
continue

question = premise + '\n' + hypo

yield (question, answer, None, MULTINLI_LABELS), True
37 changes: 37 additions & 0 deletions parlai/tasks/multinli/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

import parlai.core.build_data as build_data
import os


MULTINLI_BASE_URL = 'https://www.nyu.edu/projects/bowman/multinli/'


def build(opt):
dpath = os.path.join(opt['datapath'], 'MultiNLI')
version = '1.0'

if not build_data.built(dpath, version_string=version):
print('[building data: ' + dpath + ']')

if build_data.built(dpath):
# an older version exists, so remove these outdated files.
build_data.remove_dir(dpath)
build_data.make_dir(dpath)

# download the data.
fname = 'multinli_' + version + '.zip'
# dataset URL
url = MULTINLI_BASE_URL + fname
build_data.download(url, dpath, fname)

# uncompress it
build_data.untar(dpath, fname)

# mark the data as built
build_data.mark_done(dpath, version_string=version)
7 changes: 7 additions & 0 deletions parlai/tasks/task_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@
"tags": [ "All", "QA" ],
"description": "Closed-domain QA dataset asking MTurk-derived questions about movies, answerable from Wikipedia. From Li et al. '16. Link: https://arxiv.org/abs/1611.09823"
},
{
"id": "MultiNLI",
"display_name": "MultiNLI",
"task": "multinli",
"tags": [ "All", "Entailment" ],
"description": "A dataset designed for use in the development and evaluation of machine learning models for sentence understanding. Each example contains a premise and hypothesis. Model has to predict whether premise and hypothesis entail, contradict or are neutral to each other. From Williams et al. '17. Link: https://arxiv.org/abs/1704.05426"
},
{
"id": "OpenSubtitles",
"display_name": "Open Subtitles",
Expand Down
23 changes: 23 additions & 0 deletions tests/test_downloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,29 @@ def test_ms_marco(self):

shutil.rmtree(self.TMP_PATH)

def test_multinli(self):
from parlai.core.params import ParlaiParser
from parlai.tasks.multinli.agents import DefaultTeacher

opt = ParlaiParser().parse_args(args=self.args)

for dt in ['train', 'valid', 'test']:
opt['datatype'] = dt

teacher = DefaultTeacher(opt)
reply = teacher.act()
check(opt, reply)
assert len(reply.get('label_candidates')) == 3
assert reply.get('text').find('Premise') != -1
assert reply.get('text').find('Hypothesis') != -1

if dt == 'train':
assert reply.get('labels')[0] in ['entailment',
'contradiction',
'neutral']

shutil.rmtree(self.TMP_PATH)


if __name__ == '__main__':
# clean out temp dir first
Expand Down

0 comments on commit f9ee62e

Please sign in to comment.