Skip to content

Commit

Permalink
fixes to pytorch data loader
Browse files Browse the repository at this point in the history
  • Loading branch information
klshuster committed Dec 21, 2017
1 parent 68a7ff1 commit 6cd52d8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
15 changes: 10 additions & 5 deletions examples/build_pytorch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,14 @@ def build_data(opt):
'have a datafile or `--datafile` is not set')

pytorch_datafile = datafile + ".pytorch"
if opt.get('preprocess', True):
preprocess = opt.get('pytorch_preprocess', True)
if preprocess:
pytorch_datafile += agent.getID()
if os.path.isfile(pytorch_datafile):
# Data already built
print("[ pytorch data already built. ]")
return pytorch_datafile
print('----------\n[ setting up pytorch data. ]\n----------')
print('----------\n[ setting up pytorch data, saving to {}. ]\n----------'.format(pytorch_datafile))

num_eps = 0
num_exs = 0
Expand All @@ -85,7 +86,6 @@ def build_data(opt):
include_labels = opt.get('include_labels', True)
context_length = opt.get('context_length', -1)
context = deque(maxlen=context_length if context_length > 0 else None)
preprocess = opt.get('pytorch_preprocess', True)
# pass examples to dictionary
with open(pytorch_datafile, 'w') as pytorch_data:
while not world_data.epoch_done():
Expand Down Expand Up @@ -130,9 +130,14 @@ def main():
help=('The file to be loaded, preprocessed, and saved'))
build.add_argument('--pytorch_buildteacher', type=str, default='',
help='Which teacher to use when building the pytorch data')
build.add_argument('--pytorch_preprocess', type=bool, default=True,
help='Whether the agent should preprocess the data while building'
preprocess = argparser.add_mutually_exclusive_group(required=False)
preprocess.add_argument('--pytorch_preprocess', dest='pytorch_preprocess', action='store_true',
help='Set if the agent should preprocess the data while building'
'the pytorch data')
preprocess.add_argument('--no_pytorch_preprocess', dest='pytorch_preprocess', action='store_false',
help='Set if the agent should NOT preprocess the data while building'
'the pytorch data')
argparser.set_defaults(pytorch_preprocess=True)
opt = argparser.parse_args()
build_data(opt)

Expand Down
12 changes: 9 additions & 3 deletions parlai/core/pytorch_data_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,14 @@ def add_cmdline_args(argparser):
help='how many workers the Pytorch dataloader should use')
arg_group.add_argument('--pytorch_buildteacher', type=str, default='',
help='Which teacher to use when building the pytorch data')
arg_group.add_argument('--pytorch_preprocess', type=bool, default=True,
help='Whether the agent should preprocess the data while building'
preprocess = argparser.add_mutually_exclusive_group(required=False)
preprocess.add_argument('--pytorch_preprocess', dest='pytorch_preprocess', action='store_true',
help='Set if the agent should preprocess the data while building'
'the pytorch data')
preprocess.add_argument('--no_pytorch_preprocess', dest='pytorch_preprocess', action='store_false',
help='Set if the agent should NOT preprocess the data while building'
'the pytorch data')
argparser.set_defaults(pytorch_preprocess=True)

def __init__(self, opt, shared=None):
opt['batch_sort'] = False
Expand All @@ -170,7 +175,8 @@ def __init__(self, opt, shared=None):
collate_fn=collate_fn,
pin_memory=False,
drop_last=False,
timeout=0)
# timeout=0
)
self.lastYs = [None] * self.bsz
else:
self.dataset = shared['dataset']
Expand Down

0 comments on commit 6cd52d8

Please sign in to comment.