Skip to content

Commit

Permalink
Merge pull request facebookresearch#62 from facebookresearch/batch_fix
Browse files Browse the repository at this point in the history
fixes to BatchWorld
  • Loading branch information
ajfisch authored May 10, 2017
2 parents fad2334 + d6988b4 commit d9f0602
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 17 deletions.
1 change: 0 additions & 1 deletion examples/drqa/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
def main(opt):
# Check options
opt['datatype'] = 'valid'
opt['batchsize'] = 1
assert('pretrained_model' in opt)

# Load document reader
Expand Down
1 change: 0 additions & 1 deletion examples/drqa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def build_dict(opt):
def validate(opt, agent, n_iter):
opt = copy.deepcopy(opt)
opt['datatype'] = 'valid'
opt['batchsize'] = 1
valid_world = create_task(opt, agent)

logger.info('[ Running validation... ]')
Expand Down
5 changes: 5 additions & 0 deletions parlai/agents/drqa/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def observe(self, observation):
observation['text'] = '\n'.join(dialogue)
self.observation = observation
self.episode_done = observation['episode_done']
return observation

def act(self):
"""Update or predict on a single example (batchsize = 1)."""
Expand Down Expand Up @@ -227,6 +228,10 @@ def _build_ex(self, ex):
"""Find the token span of the answer in the context for this example.
If a token span cannot be found, return None. Otherwise, torchify.
"""
# Check if empty input (end of epoch)
if not 'text' in ex:
return

# Split out document + question
inputs = {}
fields = ex['text'].split('\n')
Expand Down
6 changes: 4 additions & 2 deletions parlai/core/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(self, opt, shared=None):

def observe(self, observation):
    """Record the incoming observation and hand it back to the caller.

    The observation is stored on ``self.observation`` and returned
    unchanged, so callers can collect what each agent actually observed.
    """
    self.observation = observation
    return observation

def act(self):
"""Return state/action table based upon given observation."""
Expand Down Expand Up @@ -266,8 +267,8 @@ def __next__(self):
if self.epoch_done():
raise StopIteration()

def observe(self, obs):
self.tasks[self.task_idx].observe(obs)
def observe(self, observation):
self.tasks[self.task_idx].observe(observation)
if self.new_task:
self.new_task = False
if self.random:
Expand All @@ -281,6 +282,7 @@ def observe(self, obs):
start_idx != self.task_idx)
if start_idx == self.task_idx:
return {'text': 'There are no more examples remaining.'}
return observation

def act(self):
t = self.tasks[self.task_idx].act()
Expand Down
1 change: 1 addition & 0 deletions parlai/core/dialog_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def observe(self, observation):
obs, self.lastY, self.lastLabelCandidates)
self.lastY = None
self.lastLabelCandidates = None
return observation

def next_example(self):
num_eps = self.data.num_episodes()
Expand Down
24 changes: 14 additions & 10 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def __init__(self, opt, world):
# which is needed for ordered data (esp valid/test sets)
override_opts_in_shared(shared, { 'batchindex': i })
self.worlds.append(shared['world_class'](opt, None, shared))
self.batch_observations = [ None ] * len(self.worlds)
self.batch_observations = [ None ] * len(self.world.get_agents())

def __iter__(self):
    """Return this object itself as the iterator."""
    return self
Expand All @@ -456,32 +456,36 @@ def __next__(self):
if self.epoch_done():
raise StopIteration()

def batch_observe(self, index, batch_actions):
    """Show each world's action to that world's agent at position ``index``.

    For the i-th world in ``self.worlds``, ``batch_actions[i]`` is passed
    through ``validate`` and then given to ``agents[index].observe``.

    Returns:
        A list with the value each agent's ``observe`` returned, in the
        same order as ``self.worlds``.

    Raises:
        ValueError: if an agent's ``observe`` returns ``None`` instead of
            the observation it received.
    """
    batch_observations = []
    for i, w in enumerate(self.worlds):
        agents = w.get_agents()
        # Index the actions by world (i), not by agent position (index).
        observation = agents[index].observe(validate(batch_actions[i]))
        if observation is None:
            raise ValueError('Agents should return what they observed.')
        batch_observations.append(observation)
    return batch_observations

def batch_act(self, index, batch_observation):
    """Produce one action per world for the agent at position ``index``.

    If ``batch_observation`` is non-empty and the agent taken from
    ``self.world`` has a ``batch_act`` method, that method is called once
    with the whole batch. Otherwise each world's own copy of the agent
    has ``act()`` called individually.

    In both cases the resulting action is stored at ``acts[index]`` of the
    corresponding world.

    Returns:
        The list of actions, in the same order as ``self.worlds``.
    """
    a = self.world.get_agents()[index]
    if (batch_observation is not None and len(batch_observation) > 0 and
            hasattr(a, 'batch_act')):
        batch_actions = a.batch_act(batch_observation)
        # Store each world's own action locally. The list is indexed by
        # world (i); ``index`` is the agent's position inside a world.
        for i, w in enumerate(self.worlds):
            acts = w.get_acts()
            acts[index] = batch_actions[i]
    else:
        # Reverts to running on each world individually.
        batch_actions = []
        for w in self.worlds:
            agents = w.get_agents()
            acts = w.get_acts()
            acts[index] = agents[index].act()
            batch_actions.append(acts[index])
    return batch_actions

def parley(self):
# Collect batch together for each agent, and do update.
Expand Down
2 changes: 1 addition & 1 deletion parlai/mturk/tasks/model_evaluator/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def observe(self, observation):
# The rating given by turker
# Because we only have one turn in this conversation, we don't need to track turn_index
# print(self.observation)
pass
return observation

def act(self):
# All agents act once in the world
Expand Down
5 changes: 3 additions & 2 deletions parlai/mturk/tasks/qa_data_collection/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, opt, shared=None):
self.opt = copy.deepcopy(opt)
self.id = 'QA Collector'
self.turn_index = -1

# Initialize a SQuAD teacher agent, which we will later get context from
module_name = 'parlai.tasks.squad.agents'
class_name = 'DefaultTeacher'
Expand All @@ -40,6 +40,7 @@ def observe(self, observation):
# Turker's answer, from the second turn
# print(self.observation)
pass
return observation

def act(self):
self.turn_index = (self.turn_index + 1) % 2; # Each turn starts from the QA Collector agent
Expand All @@ -55,7 +56,7 @@ def act(self):
context = '\n'.join(qa['text'].split('\n')[:-1])

# Wrap the context with a prompt telling the turker what to do next
ad['text'] = (context +
ad['text'] = (context +
'\n\nPlease provide a question given this context.')

if self.turn_index == 1:
Expand Down

0 comments on commit d9f0602

Please sign in to comment.