diff --git a/babyai/bot.py b/babyai/bot.py
index 1e99789e..9c97065d 100644
--- a/babyai/bot.py
+++ b/babyai/bot.py
@@ -63,7 +63,7 @@ def get_action(self):
         subgoals.
 
         Returns an action.
         """
-        pass
+        raise NotImplementedError()
 
     def simulate_step(self, action):
@@ -143,11 +143,11 @@ def foolish_action_while_exploring(self, action):
                 # need to open it
                 self.bot.stack.append(OpenSubgoal(self.bot))
 
-    def take_action(self, action):
+    def replan(self, given_action):
         """
         Function that updates the bot's stack given the action played.
         Should be overridden in all sub-classes.
 
-        TODO: There are some steps that are common in both take_action and get_action
+        TODO: There are some steps that are common in both replan and get_action
         for certain subgoals, maybe the bot's speed can be improved if we do them
         only once - or at least we can factorize the code a bit
@@ -155,19 +155,7 @@ def take_action(self, action):
         When it returns `False`, the bot continues replanning.
         """
-        self.erroneous_box_opening(action)
-        self.simulate_step(action)
-
-    def erroneous_box_opening(self, action):
-        """
-        When the agent opens a box, we raise an error and mark the task unsolvable.
-        This is a tad conservative, because maybe the box is irrelevant to the mission.
-        TODO: We can relax this by checking if the opened box is crucial for the mission.
-        TODO: We can relax this by checking if a similar box still exists if the box is crucial.
-        """
-        if action == self.actions.toggle and self.fwd_cell is not None and self.fwd_cell.type == 'box':
-            raise DisappearedBoxError('A box was opened. Too Bad :(')
-
+        raise NotImplementedError()
 
 
 class OpenSubgoal(Subgoal):
     def get_action(self):
@@ -206,9 +194,7 @@ def get_action(self):
 
         return self.actions.toggle
 
-    def take_action(self, action):
-        super().take_action(action)
-
+    def replan(self, action):
         # CASE 1: The door is locked
         # we need to fetch the key and return
         # i.e. update the stack REGARDLESS of the action
@@ -282,11 +268,10 @@ class DropSubgoal(Subgoal):
     def get_action(self):
         return self.actions.drop
 
-    def take_action(self, action):
-        super().take_action(action)
-        if action == self.actions.drop:
+    def replan(self, given_action):
+        if given_action == self.actions.drop:
             self.bot.stack.pop()
-        elif action in (self.actions.left, self.actions.right, self.actions.forward):
+        elif given_action in (self.actions.left, self.actions.right, self.actions.forward):
             # Go back to where you were to drop what you got
             self.bot.stack.append(GoNextToSubgoal(self.bot, tuple(self.fwd_pos)))
         # done/pickup actions won't have any effect -> Next step would take us to the same subgoal
@@ -297,11 +282,10 @@ class PickupSubgoal(Subgoal):
     def get_action(self):
         return self.actions.pickup
 
-    def take_action(self, action):
-        super().take_action(action)
-        if action == self.actions.pickup:
+    def replan(self, given_action):
+        if given_action == self.actions.pickup:
             self.bot.stack.pop()
-        elif action in (self.actions.left, self.actions.right):
+        elif given_action in (self.actions.left, self.actions.right):
             # Go back to where you were to pickup what was in front of you
             self.bot.stack.append(GoNextToSubgoal(self.bot, tuple(self.fwd_pos)))
         # done/drop/forward actions won't have any effect -> Next step would take us to the same subgoal
@@ -471,9 +455,7 @@ def closest_wall_or_door_given_dir(position, direction):
 
                 return self.actions.left
             return self.actions.right
 
-    def take_action(self, action):
-        super().take_action(action)
-
+    def replan(self, given_action):
         if isinstance(self.datum, ObjDesc):
             target_pos = self.bot.find_obj_pos(self.datum, self.adjacent)
             if not target_pos:
@@ -486,9 +468,9 @@ def take_action(self, action):
         # CASE 1: The position we are on is the one we should go next to
         # -> Move away from it
         if manhattan_distance(target_pos, self.pos) == (1 if self.adjacent else 0):
-            if action in (self.actions.drop, self.actions.pickup, self.actions.toggle):
+            if given_action in (self.actions.drop, self.actions.pickup, self.actions.toggle):
                 # Update the stack if we did something bad, or do nothing if the action doesn't change anything
-                self.foolish_action_while_exploring(action)
+                self.foolish_action_while_exploring(given_action)
             # Whatever other action, the stack should stay the same, and it's the new action that should be evaluated
             # TODO: Double check what happens in this scenario with all actions
             return True
@@ -514,7 +496,7 @@ def take_action(self, action):
         # CASE 3: the action taken would lead us to face the target cell
         # -> Don't do anything. The stack will be popped at next step anyway
         if np.array_equal(target_pos, self.new_fwd_pos):
-            assert action in (self.actions.forward, self.actions.left, self.actions.right), "Doesn't make sense"
+            assert given_action in (self.actions.forward, self.actions.left, self.actions.right), "Doesn't make sense"
             return True
 
         # CASE 5: otherwise
@@ -556,7 +538,7 @@ def take_action(self, action):
                     drop_pos = self.bot.find_drop_pos()
                     self.bot.stack.append(DropSubgoal(self.bot))
                     self.bot.stack.append(GoNextToSubgoal(self.bot, drop_pos))
-                    if action == self.actions.pickup:
+                    if given_action == self.actions.pickup:
                         return True
                 else:
                     self.bot.stack.append(PickupSubgoal(self.bot))
@@ -567,8 +549,8 @@ def take_action(self, action):
                 # TODO: what if I pickup another blocker (and that's good)
 
             # CASE 5.4: If there is nothing blocking us and we drop/pickup/toggle something for no reason
-            if action in (self.actions.drop, self.actions.pickup, self.actions.toggle):
-                self.foolish_action_while_exploring(action)
+            if given_action in (self.actions.drop, self.actions.pickup, self.actions.toggle):
+                self.foolish_action_while_exploring(given_action)
 
             # CASE 5.5: If we are GoingNextTo something because of exploration
             if self.reason == 'Explore':
@@ -620,9 +602,7 @@ def unopened_door(pos, cell):
 
         assert False, "0nothing left to explore"
 
-    def take_action(self, action):
-        super().take_action(action)
-
+    def replan(self, given_action):
         # Find the closest unseen position
         _, unseen_pos, with_blockers = self.bot.shortest_path(
             lambda pos, cell: not self.bot.vis_mask[pos],
@@ -1052,33 +1032,50 @@ def get_action(self):
                 return self.bot.mission.actions.done
         self.stack.pop()
 
-    def take_action(self, action):
+    def replan(self, given_action):
         """
-        Update agent's internal state. Should always be called after get_action() and before env.step()
+        Update the subgoal stack given the action that will be taken by the agent.
+
+        The `replan` method of the subgoal at the top of the stack is called
+        repeatedly, until one of the subgoals reports that replanning is
+        finished by returning `True`.
+
         """
         self.step_count += 1
 
-        finished_updating = False
-        while not finished_updating:
-            empty_stack = self.empty_stack_update(action)
-            if not empty_stack:
-                subgoal = self.stack[-1]
-                subgoal.update_agent_attributes()
-                finished_updating = subgoal.take_action(action)
-            else:
-                finished_updating = True
+        seen_stack_tops = set()
+        while True:
+            if not self.stack:
+                self.empty_stack_update(given_action)
+                break
+
+            subgoal = self.stack[-1]
+            seen_stack_tops.add(subgoal)
+
+            subgoal.update_agent_attributes()
+            subgoal.simulate_step(given_action)
+            if subgoal.replan(given_action):
+                break
 
     def empty_stack_update(self, action):
         pos = self.mission.agent_pos
         dir_vec = self.mission.dir_vec
         fwd_pos = pos + dir_vec
+        self.stack.append(GoNextToSubgoal(self, fwd_pos))
 
-        if len(self.stack) == 0:
-            if action != self.mission.actions.done:
-                self.stack.append(GoNextToSubgoal(self, fwd_pos))
-            return True
+    def check_erroneous_box_opening(self, action):
+        """
+        When the agent opens a box, we raise an error and mark the task unsolvable.
+        This is a tad conservative, because maybe the box is irrelevant to the mission.
+        TODO: We can relax this by checking if the opened box is crucial for the mission.
+        TODO: We can relax this by checking if a similar box still exists if the box is crucial.
+        """
+        if (action == self.actions.toggle
+                and self.fwd_cell is not None
+                and self.fwd_cell.type == 'box'):
+            raise DisappearedBoxError('A box was opened. Too Bad :(')
 
     def step(self):
         action = self.get_action()
-        self.take_action(action)
+        self.replan(action)
 
diff --git a/babyai/utils/agent.py b/babyai/utils/agent.py
index 595150e9..7992d0cf 100644
--- a/babyai/utils/agent.py
+++ b/babyai/utils/agent.py
@@ -149,7 +149,7 @@ def on_reset(self):
     def act(self, obs=None, update_internal_state=True, *args, **kwargs):
         action = self.bot.get_action()
         if update_internal_state:
-            self.bot.take_action(action)
+            self.bot.replan(action)
         return {'action': action}
 
     def analyze_feedback(self, reward, done):
diff --git a/scripts/eval_bot.py b/scripts/eval_bot.py
index cfe4f271..6b1228c3 100755
--- a/scripts/eval_bot.py
+++ b/scripts/eval_bot.py
@@ -138,7 +138,7 @@ else:
 
             optimal_actions.append(action)
 
-            expert.take_action(action)
+            expert.replan(action)
 
             obs, reward, done, info = mission.step(action)
             total_reward += reward
diff --git a/scripts/gui.py b/scripts/gui.py
index cd856686..f8a1dfbc 100755
--- a/scripts/gui.py
+++ b/scripts/gui.py
@@ -382,7 +382,7 @@ def stepEnv(self, action=None):
         if action is None:
             action = self.bot_advisor_action
 
-        self.bot_advisor_agent.bot.take_action(action)
+        self.bot_advisor_agent.bot.replan(action)
 
         obs, reward, done, info = self.env.step(action)
         self.showEnv(obs)
diff --git a/scripts/train_intelligent_expert.py b/scripts/train_intelligent_expert.py
index e8a0df5a..04bc68af 100755
--- a/scripts/train_intelligent_expert.py
+++ b/scripts/train_intelligent_expert.py
@@ -137,7 +137,7 @@ def generate_dagger_demos(env_name, seeds, fail_obss, fail_actions, mean_steps):
             assert check_obss_equality(obs, new_obs), "Observations {} of seed {} don't match".format(j, seeds[i])
             mission = obs['mission']
             action = agent.act(update_internal_state=False)['action']
-            _ = agent.bot.take_action(fail_actions[i][j])
+            _ = agent.bot.replan(fail_actions[i][j])
             debug_info['actions'].append(fail_actions[i][j])
             new_obs, reward, done, _ = env.step(fail_actions[i][j])
             if done and reward > 0:
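For reference, a minimal usage sketch of the renamed interface (not part of the patch): a caller asks the bot for an action with `get_action()`, lets it update its subgoal stack with `replan()`, and only then steps the environment, mirroring `Bot.step` and the call sites patched above. The level name and the `Bot(env)` construction are illustrative assumptions modeled on `scripts/eval_bot.py`.

```python
# Illustrative sketch, not part of the diff. Assumes a registered BabyAI level
# and that Bot is constructed from the mission/env, as in scripts/eval_bot.py.
import gym
import babyai  # noqa: F401 -- importing babyai registers the BabyAI-* environments
from babyai.bot import Bot

env = gym.make('BabyAI-GoToRedBall-v0')  # any BabyAI level should work here
env.reset()

bot = Bot(env)
done = False
while not done:
    action = bot.get_action()  # optimal action for the current subgoal stack
    bot.replan(action)         # update the stack for the action about to be taken
    obs, reward, done, info = env.step(action)
```

As the `Subgoal.replan` docstring above notes, a subgoal keeps the loop going by returning `False` and signals that replanning for the given action is finished by returning `True`.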