batched validation (facebookresearch#57)
supports batching over ordered data
alexholdenmiller authored May 8, 2017
1 parent c09df72 commit 52b3f8d
Showing 5 changed files with 67 additions and 32 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -75,12 +75,12 @@ python examples/display_model.py -m ir_baseline -t "#moviedd-reddit" -dt valid

Train a simple cpu-based memory network on the "10k training examples" bAbI task 1 with 8 threads (python processes) using Hogwild (requires zmq and Lua Torch):
```bash
-python examples/memnn_luatorch_cpu/full_task_train.py -t babi:task10k:1 -n 8
+python examples/memnn_luatorch_cpu/full_task_train.py -t babi:task10k:1 -nt 8
```

Trains an attentive LSTM model on the SQuAD dataset with a batch size of 32 examples (pytorch and regex):
```bash
-python examples/drqa/train.py -t squad -b 32
+python examples/drqa/train.py -t squad -bs 32
```

## Requirements
4 changes: 2 additions & 2 deletions examples/README.md
@@ -46,10 +46,10 @@ python build_dict.py -t babi:task1k:1 --dict-savepath /tmp/dict.tsv

Train a simple cpu-based memory network on the "10k training examples" bAbI task 1 with 8 threads (python processes) using Hogwild (requires zmq and Lua Torch):
```bash
-python memnn_luatorch_cpu/full_task_train.py -t babi:task10k:1 -n 8
+python memnn_luatorch_cpu/full_task_train.py -t babi:task10k:1 -nt 8
```

Trains an attentive LSTM model on the SQuAD dataset with a batch size of 32 examples (pytorch and regex):
```bash
-python drqa/train.py -t squad -b 32
+python drqa/train.py -t squad -bs 32
```
28 changes: 23 additions & 5 deletions parlai/core/dialog_teacher.py
@@ -57,14 +57,20 @@ def __init__(self, opt, shared=None):
        else:
            self.metrics = Metrics(opt)

+        # for ordered data in batch mode (especially, for validation and
+        # testing), each teacher in the batch gets a start index and a step
+        # size so they all process disparate sets of the data
+        self.step_size = opt.get('batchsize', 1)
+        self.data_offset = opt.get('batchindex', 0)
+
        self.reset()

    def reset(self):
        # Reset the dialog so that it is at the start of the epoch,
        # and all metrics are reset.
        self.metrics.clear()
        self.lastY = None
-        self.episode_idx = -1
+        self.episode_idx = self.data_offset - self.step_size
        self.epochDone = False
        self.episode_done = True

@@ -105,21 +111,33 @@ def observe(self, observation):
            self.lastLabelCandidates = None

    def next_example(self):
+        num_eps = self.data.num_episodes()
        if self.episode_done:
-            num_eps = self.data.num_episodes()
            if self.random:
                # select random episode
                self.episode_idx = random.randrange(num_eps)
            else:
                # select next episode
-                self.episode_idx = (self.episode_idx + 1) % num_eps
+                self.episode_idx = (self.episode_idx + self.step_size) % num_eps
            self.entry_idx = 0
        else:
            self.entry_idx += 1
-        return self.data.get(self.episode_idx, self.entry_idx)
+
+        action, epoch_done = self.data.get(self.episode_idx, self.entry_idx)
+
+        if self.random:
+            epoch_done = False
+        elif (self.episode_idx + self.step_size >= num_eps and
+                action['episode_done']):
+            # this is used for ordered data to check whether there's more data
+            epoch_done = True
+
+        return action, epoch_done

    def act(self):
-        """Send new dialog message. """
+        """Send new dialog message."""
        if self.epochDone:
            return { 'episode_done': True }
        action, self.epochDone = self.next_example()
        self.episode_done = action['episode_done']
        action['id'] = self.getID()
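Taken together, the new `data_offset`/`step_size` fields stride each batch slot through a disjoint slice of the ordered episodes: with batch size `bs`, the teacher at `batchindex` `i` visits episodes `i, i + bs, i + 2*bs, ...`. A minimal standalone sketch of that traversal (a hypothetical helper for illustration only, simplified to one entry per episode):

```python
# Sketch only: mirrors the reset()/next_example() arithmetic above,
# assuming single-entry episodes.

def episodes_for_slot(batchindex, batchsize, num_episodes):
    """Yield the episode indices one batch slot sees, in order."""
    idx = batchindex - batchsize  # reset(): data_offset - step_size
    while True:
        idx = (idx + batchsize) % num_episodes  # next_example() step
        yield idx
        if idx + batchsize >= num_episodes:
            # mirrors the new epoch_done check: the next step would wrap
            return

# 10 episodes, batch size 4:
for i in range(4):
    print(i, list(episodes_for_slot(i, 4, 10)))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6]
# 3 [3, 7]
```

The slots partition the data without overlap but can differ in length by one episode, which is why `BatchWorld.epoch_done` (changed below) now waits for every world to finish rather than the first.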
2 changes: 1 addition & 1 deletion parlai/core/params.py
@@ -30,7 +30,7 @@ def __init__(self, add_parlai_args=True, add_model_args=False):
        self.parlai_home = (os.path.dirname(os.path.dirname(os.path.dirname(
            os.path.realpath(__file__)))))
        os.environ['PARLAI_HOME'] = self.parlai_home
-        
+
        if add_parlai_args:
            self.add_parlai_args()
        if add_model_args:
61 changes: 39 additions & 22 deletions parlai/core/worlds.py
@@ -71,7 +71,7 @@ def __init__(self, opt, agents=None, shared=None):
        self.opt = copy.deepcopy(opt)
        if shared:
            # Create agents based on shared data.
-            self.agents = create_agents_from_shared(shared.agents)
+            self.agents = create_agents_from_shared(shared['agents'])
        else:
            # Add passed in agents to world directly.
            self.agents = agents
@@ -131,7 +131,7 @@ def episode_done(self):
        return False

    def epoch_done(self):
-        """Whether the epoch is done or not. 
+        """Whether the epoch is done or not.
        Not all worlds have the notion of an epoch, but this is useful
        for fixed training, validation or test sets.
        """
@@ -217,15 +217,17 @@ def parley(self):
        acts[1] = agents[1].act()
        agents[0].observe(validate(acts[1]))

+    def epoch_done(self):
+        """ Only the first agent indicates when the epoch is done."""
+        return (self.agents[0].epoch_done()
+                if hasattr(self.agents[0], 'epoch_done') else False)
+
    def episode_done(self):
        """ Only the first agent indicates when the episode is done."""
        if self.acts[0] is not None:
            return self.acts[0].get('episode_done', False)
        else:
            return False

-    def epoch_done(self):
-        """Only the first agent indicates when the epoch is done."""
-        return (self.agents[0].epoch_done()
-                if hasattr(self.agents[0], 'epoch_done') else False)

    def report(self):
        return self.agents[0].report()
@@ -407,6 +409,25 @@ def report(self):
        return m


+def override_opts_in_shared(table, overrides):
+    """Looks recursively for opt dictionaries within shared dict and overrides
+    any key-value pairs with pairs from the overrides dict.
+    """
+    if 'opt' in table:
+        # change values if an 'opt' dict is available
+        for k, v in overrides.items():
+            table['opt'][k] = v
+    for k, v in table.items():
+        # look for sub-dictionaries which also might contain an 'opt' dict
+        if type(v) == dict and k != 'opt':
+            override_opts_in_shared(v, overrides)
+        elif type(v) == list:
+            for item in v:
+                if type(item) == dict:
+                    override_opts_in_shared(item, overrides)
+    return table
+
+
class BatchWorld(World):
    """Creates a separate world for each item in the batch, sharing
    the parameters for each.
@@ -417,17 +438,15 @@ class BatchWorld(World):
    def __init__(self, opt, world):
        self.opt = opt
        self.random = opt.get('datatype', None) == 'train'
-        if not self.random:
-            raise NotImplementedError(
-                'Ordered data not implemented yet in batch mode.')
-
        self.world = world
        shared = world.share()
        self.worlds = []
        for i in range(opt['batchsize']):
-            opti = copy.deepcopy(opt)
-            opti['batchindex'] = i
-            self.worlds.append(shared['world_class'](opti, None, shared))
+            # make sure that any opt dicts in shared have batchindex set to i
+            # this lets all shared agents know which batchindex they have,
+            # which is needed for ordered data (esp valid/test sets)
+            override_opts_in_shared(shared, { 'batchindex': i })
+            self.worlds.append(shared['world_class'](opt, None, shared))
        self.batch_observations = [ None ] * len(self.worlds)

    def __iter__(self):
@@ -448,7 +467,7 @@ def batch_act(self, index, batch_observation):
        # Call update on agent
        a = self.world.get_agents()[index]
        if (batch_observation is not None and len(batch_observation) > 0 and
-            hasattr(a, 'batch_act')):
+                hasattr(a, 'batch_act')):
            batch_reply = a.batch_act(batch_observation)
            # Store the actions locally in each world.
            for w in self.worlds:
@@ -461,7 +480,7 @@ def batch_act(self, index, batch_observation):
                agents = w.get_agents()
                acts = w.get_acts()
                acts[index] = agents[index].act()
-            batch_reply.append(acts[index])
+                batch_reply.append(acts[index])
        return batch_reply

    def parley(self):
@@ -473,7 +492,7 @@ def parley(self):
        for w in self.worlds:
            if hasattr(w, 'parley_init'):
                w.parley_init()
-        
+
        for index in range(num_agents):
            batch_act = self.batch_act(index, batch_observations[index])
            for other_index in range(num_agents):
@@ -486,8 +505,6 @@ def display(self):
        for i, w in enumerate(self.worlds):
            s += ("[batch world " + str(i) + ":]\n")
            s += (w.display() + '\n')
-            if not self.random and w.epoch_done():
-                break
        s += ("[--end of batch--]")
        return s

@@ -499,9 +516,9 @@ def episode_done(self):

    def epoch_done(self):
        for world in self.worlds:
-            if world.epoch_done():
-                return True
-        return False
+            if not world.epoch_done():
+                return False
+        return True

    def report(self):
        return self.worlds[0].report()
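The `override_opts_in_shared` helper added above is what lets `BatchWorld` stamp a distinct `batchindex` into every nested `'opt'` dict of the shared state. A small usage sketch (the shared table here is a made-up example, not the real output of `World.share()`):

```python
from parlai.core.worlds import override_opts_in_shared

# Hypothetical shared state: a world opt plus two shared agents.
shared = {
    'opt': {'batchsize': 4},
    'agents': [
        {'opt': {'batchsize': 4}},
        {'opt': {'batchsize': 4}},
    ],
}

# As in BatchWorld.__init__ above, tag every nested opt dict
# with the slot index for batch slot 1.
override_opts_in_shared(shared, {'batchindex': 1})

assert shared['opt']['batchindex'] == 1
assert all(a['opt']['batchindex'] == 1 for a in shared['agents'])
```

Each shared teacher then reads `batchindex` in its constructor (see `dialog_teacher.py` above) to pick its offset into the ordered data, and the flipped `epoch_done` check ends the batch epoch only once every slot has exhausted its stride.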
