Skip to content

Commit

Permalink
Default datafile to datatype if not provided. (facebookresearch#3140)
Browse files Browse the repository at this point in the history
* Default datafile to datatype if not provided.

* Beef up docs.

* Typo.

* Some more refactoring

* Partial update.

* Fix location of error.

* Good test.

* Refactor error message.
  • Loading branch information
stephenroller authored Nov 8, 2020
1 parent 01d059d commit 7415a04
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 6 deletions.
3 changes: 3 additions & 0 deletions docs/source/tutorial_task.md
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ class SquadTeacher(DialogTeacher):
self.datatype = opt['datatype']
build(opt) # NOTE: the call to build here
suffix = 'train' if opt['datatype'].startswith('train') else 'dev'
# whatever is placed into datafile will be passed as the argument to
# setup_data in the next section.
opt['datafile'] = os.path.join(opt['datapath'], 'SQuAD', suffix + '-v1.1.json')
self.id = 'squad'
super().__init__(opt, shared)
Expand Down Expand Up @@ -361,6 +363,7 @@ The sample `setup_data` method for our task is presented below.

```python
def setup_data(self, path):
# note that path is the value provided by opt['datafile']
print('loading: ' + path)
with PathManager.open(path) as data_file:
self.squad = json.load(data_file)['data']
Expand Down
55 changes: 49 additions & 6 deletions parlai/core/teachers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@
from typing import List, Tuple, Optional, TypeVar


# Template for the KeyError raised when 'datafile' is missing from the options.
# Fill it in with `.format(class_name=...)`, naming the offending class.
ERROR_MESSAGE_NO_DATAFILE = (
    "{class_name} is expected to set self.opt['datafile'] inside `__init__` "
    "before calling `super().__init__`. This will be passed to setup_data, "
    "indicating what data to load. If you don't know what to use, set "
    "`opt['datafile'] = parlai.utils.data.DatatypeHelper.fold(opt['datatype'])` "
    "to receive the fold name in setup_data."
)


ChunkOutput = TypeVar('ChunkOutput')


Expand Down Expand Up @@ -560,6 +569,10 @@ def __init__(self, opt, shared=None):
if shared and shared.get('data'):
self.data = data_class(opt, shared=shared['data'], **kwargs)
else:
if 'datafile' not in self.opt:
raise KeyError(
ERROR_MESSAGE_NO_DATAFILE.format(class_name=self.__class__.__name__)
)
self.data = data_class(
opt,
data_loader=self.setup_data,
Expand All @@ -569,6 +582,26 @@ def __init__(self, opt, shared=None):

self.reset()

@abstractmethod
def setup_data(self, datafile: str):
    """
    Load the task's data; the core method which the user should override.

    Yields the data, one message at a time, as well as markers indicating
    new episodes.

    :param str datafile:
        The value the initializer stored in ``opt['datafile']`` before
        calling ``super().__init__``. When no file path is needed, the
        initializer may store the fold name instead (for example
        ``DatatypeHelper.fold(opt['datatype'])``, i.e. "train", "valid",
        or "test"). There is no implicit default: a missing 'datafile'
        raises a KeyError during initialization.

    :return:
        Yields pairs (message, new_episode) containing a Message object
        and whether the message marks the beginning of a totally new
        episode.
    """
    pass

def reset(self):
"""
Reset the dialog to the start of the epoch, reset all metrics.
Expand Down Expand Up @@ -696,6 +729,12 @@ def __init__(self, opt, data_loader=None, cands=None, shared=None, **kwargs):
else:
self.image_loader = ImageLoader(opt)
self.data = []

if 'datafile' not in opt:
raise KeyError(
ERROR_MESSAGE_NO_DATAFILE.format(class_name=self.__class__.__name__)
)

self._load(data_loader, opt['datafile'])
self.cands = None if cands is None else set(c for c in cands)

Expand Down Expand Up @@ -914,6 +953,10 @@ def __init__(self, opt, data_loader=None, cands=None, shared=None, **kwargs):
else:
# main instance holds the stream and shares pointer to it
self.data_loader = data_loader
if 'datafile' not in opt:
raise KeyError(
ERROR_MESSAGE_NO_DATAFILE.format(class_name=self.__class__.__name__)
)
self.datafile = opt['datafile']
self.reset_data = None
self.is_reset = True
Expand All @@ -924,8 +967,8 @@ def __init__(self, opt, data_loader=None, cands=None, shared=None, **kwargs):

self.rank = get_rank()
self.num_workers = num_workers()
self.is_distributed_and_is_eval = self.num_workers > 1 and any(
x in opt['datatype'] for x in ('valid', 'test', 'train:evalmode')
self.is_distributed_and_is_eval = (
self.num_workers > 1 and not DatatypeHelper.is_training(opt['datatype'])
)

def share(self):
Expand Down Expand Up @@ -1599,7 +1642,7 @@ def __init__(self, opt, shared=None):
self.task = opt['task'].split(':')[1] if ':' in opt['task'] else opt['task']
self.data_path = self.get_data_path(opt)
self.data = self.load_data(self.data_path, self.opt)
self.datatype = opt.get('datatype').split(':')[0]
self.datatype = DatatypeHelper.fold(opt['datatype'])

# Example of available models: 'resnet152', 'resnext101_32x48d_wsl',
# and ImageLoader supports other resnet and resnext models too
Expand Down Expand Up @@ -1779,7 +1822,7 @@ def load_data(self, data_path, opt):
Can be overridden by subclass.
"""

dt = opt['datatype'].split(':')[0]
dt = DatatypeHelper.fold(opt['datatype'])

# Sometimes file is named "val" instead of "valid"
if dt not in ['train', 'valid', 'val', 'test']:
Expand Down Expand Up @@ -1977,7 +2020,7 @@ def __init__(self, opt: Opt, shared=None):
self.tasks.extend(create_task_agent_from_taskname(opt_singletask))
self.task_idx = -1
self.new_task = True
self.random = opt.get('datatype') == 'train'
self.random = DatatypeHelper.should_shuffle(opt['datatype'])
# Make multi-task task probabilities.
self.cum_task_weights = [1] * len(self.tasks)
self.task_choices = range(len(self.tasks))
Expand Down Expand Up @@ -2158,7 +2201,7 @@ def __init__(self, opt, shared=None):
def _get_data_folder(self):
if not self.opt.get('datafile'):
raise RuntimeError(
'Must specify datafile or override this function '
'Must specify datafile or override this function (_get_data_folder) '
'to return the data folder.'
)

Expand Down
15 changes: 15 additions & 0 deletions parlai/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,21 @@ class DatatypeHelper:
Helper class to determine properties from datatype strings.
"""

@classmethod
def fold(cls, datatype: str) -> str:
    """
    Extract the fold part of the datatype.

    The fold is everything before the first ``:``; a datatype containing
    no ``:`` is returned unchanged.

    :param datatype:
        parlai datatype

    :return: the fold

    >>> DatatypeHelper.fold("train:ordered")
    'train'
    """
    # partition stops at the first separator, so later modifiers are ignored
    return datatype.partition(':')[0]

@classmethod
def should_cycle(cls, datatype: str) -> bool:
"""
Expand Down
18 changes: 18 additions & 0 deletions tests/test_teachers.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,12 +419,30 @@ def setup_data(self, datafile):
yield Message({'text': str(j), 'label': str(j * 2)}), j == 1


class NoDatafileTeacher(DialogTeacher):
    """
    Teacher that defines no initializer of its own, so it never sets
    opt['datafile'] itself.

    Its setup_data echoes the datafile it is given as both text and label.
    """

    def setup_data(self, datafile):
        echoed = Message({'text': datafile, 'label': datafile})
        # a single message, flagged as the start of a new episode
        yield echoed, True


class ViolationTeacher(_MockTeacher):
    """
    Teacher that yields a plain dict carrying an 'episode_done' key.

    NOTE(review): presumably this is input the base teacher should reject;
    confirm against the tests that use this class.
    """

    def setup_data(self, datafile):
        offending_example = {'text': 'foo', 'episode_done': True}
        yield offending_example, True


class TestDialogTeacher(unittest.TestCase):
def test_nodatafile(self):
    """
    Constructing a teacher without opt['datafile'] must raise KeyError,
    whichever datatype is requested.
    """
    datatypes = (
        'train:ordered',
        'train:stream:ordered',
        'valid',
        'test',
        'valid:stream',
        'test:stream',
    )
    for datatype in datatypes:
        options = Opt({'datatype': datatype, 'datapath': '/tmp', 'task': 'test'})
        # callable form of assertRaises: construction itself must fail
        self.assertRaises(KeyError, NoDatafileTeacher, options)

def _verify_act(self, act, goal_text, goal_label, episode_done):
assert 'eval_labels' in act or 'labels' in act
labels = act.get('labels', act.get('eval_labels'))
Expand Down
53 changes: 53 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from copy import deepcopy
import time
import unittest
from parlai.utils.data import DatatypeHelper


class TestUtils(unittest.TestCase):
Expand Down Expand Up @@ -139,5 +140,57 @@ def test_uppercase(self):
assert string_utils.uppercase("tEst") == "TEst"


class TestDatatypeHelper(unittest.TestCase):
    """
    Checks each DatatypeHelper query against a table of datatype strings.
    """

    def test_fold(self):
        """
        fold() returns the part of the datatype before the first colon.
        """
        cases = [
            ("train", "train"),
            ("train:ordered", "train"),
            ("train:stream", "train"),
            ("train:stream:ordered", "train"),
            ("train:evalmode", "train"),
            ("train:stream:evalmode", "train"),
            ("valid", "valid"),
            ("valid:stream", "valid"),
            ("test", "test"),
            ("test:stream", "test"),
        ]
        for datatype, expected in cases:
            assert DatatypeHelper.fold(datatype) == expected

    def test_should_cycle(self):
        """
        should_cycle() is True only for "train" and "train:stream" here.
        """
        cases = [
            ("train", True),
            ("train:evalmode", False),
            ("train:ordered", False),
            ("train:stream", True),
            ("valid", False),
            ("valid:stream", False),
            ("test", False),
            ("test:stream", False),
        ]
        for datatype, expected in cases:
            assert DatatypeHelper.should_cycle(datatype) is expected

    def test_should_shuffle(self):
        """
        should_shuffle() is True only for plain "train" here.
        """
        cases = [
            ("train", True),
            ("train:evalmode", False),
            ("train:ordered", False),
            ("train:stream", False),
            ("valid", False),
            ("valid:stream", False),
            ("test", False),
            ("test:stream", False),
        ]
        for datatype, expected in cases:
            assert DatatypeHelper.should_shuffle(datatype) is expected

    def test_is_training(self):
        """
        is_training() is True for train folds except "train:evalmode".
        """
        cases = [
            ("train", True),
            ("train:evalmode", False),
            ("train:ordered", True),
            ("train:stream", True),
            ("valid", False),
            ("valid:stream", False),
            ("test", False),
            ("test:stream", False),
        ]
        for datatype, expected in cases:
            assert DatatypeHelper.is_training(datatype) is expected


if __name__ == '__main__':
unittest.main()

0 comments on commit 7415a04

Please sign in to comment.