Skip to content

Commit

Permalink
synchronization of examples)
Browse files Browse the repository at this point in the history
  • Loading branch information
klshuster committed Dec 21, 2017
1 parent 6cd52d8 commit 660fb7c
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions parlai/core/pytorch_data_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@
except Exception as e:
raise ModuleNotFoundError('Need to install Pytorch: go to pytorch.org')
from torch.utils.data import Dataset, DataLoader, sampler
from multiprocessing import Lock, RawArray
import ctypes


# Default collate function (for how to prepare a batch)
Expand All @@ -100,9 +102,10 @@ def __init__(self, opt):
self.data_gen = self._data_generator(self.datafile)
self.length_datafile = self.datafile + ".length"
self._load_lens()
self.indices_lock = None
self.indices_seen = None

def __getitem__(self, index):
    """Return the next example from the underlying stream.

    The ``index`` argument is required by the ``Dataset`` interface but
    is deliberately ignored: the data is consumed as a forward-only
    stream, so items are served in generator order regardless of the
    index requested.
    """
    item = next(self.data_gen)
    return item

def __len__(self):
Expand All @@ -122,20 +125,31 @@ def _data_generator(self, datafile):
def _read_episode(self, datafile):
    """Yield episodes (lists of examples) parsed from ``datafile``.

    Each line of ``datafile`` is one JSON-encoded example; an episode is
    emitted whenever an example's ``episode_done`` flag is true.  Lines
    already claimed by another worker — tracked in the shared
    ``indices_seen`` array — are skipped, and the check-and-mark happens
    under ``indices_lock`` so concurrent workers never emit the same
    example twice.  After the file is exhausted, the shared flags are
    reset so a subsequent pass re-reads every line.

    NOTE(review): assumes ``set_sync`` has been called first; otherwise
    ``indices_lock``/``indices_seen`` are still ``None`` — confirm with
    the teacher's ``__init__``.
    """
    # ``with open`` guarantees the descriptor is closed even when this
    # generator is abandoned before exhaustion; the original bare
    # ``open``/``close`` pair leaked the handle in that case, because
    # ``read.close()`` was only reached after the final line.
    with open(datafile) as read:
        episode = []
        for idx, line in enumerate(read):
            # Atomically claim this line: test and set under the lock.
            with self.indices_lock:
                if self.indices_seen[idx]:
                    continue
                self.indices_seen[idx] = True
            example = json.loads(line)
            episode.append(example)
            if example['episode_done']:
                yield episode
                episode = []
    # End of epoch: clear the shared bitmap for the next full pass.
    with self.indices_lock:
        for idx in range(len(self.indices_seen)):
            self.indices_seen[idx] = False

def num_episodes(self):
    """Return the total number of episodes in the dataset."""
    return self.num_eps

def num_examples(self):
    """Return the total number of individual examples in the dataset."""
    return self.num_exs

def set_sync(self, indices_lock, indices_seen):
    """Attach the shared synchronization primitives.

    ``indices_lock`` guards ``indices_seen``, the shared array of
    per-line "already read" flags used to coordinate example reads
    across worker processes.
    """
    self.indices_seen = indices_seen
    self.indices_lock = indices_lock


class PytorchDataTeacher(FixedDialogTeacher):

Expand Down Expand Up @@ -166,6 +180,9 @@ def __init__(self, opt, shared=None):
collate_fn = opt.get('collate_fn', default_collate)
if not shared:
self.dataset = StreamDataset(opt)
self.indices_lock = Lock()
self.indices_seen = RawArray(ctypes.c_bool, self.num_episodes())
self.dataset.set_sync(self.indices_lock, self.indices_seen)
self.pytorch_dataloader = DataLoader(
self.dataset,
batch_size=self.bsz,
Expand Down

0 comments on commit 660fb7c

Please sign in to comment.