From 707387001a4029bf87012089d797683222136e16 Mon Sep 17 00:00:00 2001 From: rizar Date: Wed, 23 Jan 2019 09:15:12 -0500 Subject: [PATCH] more stats and the cheapest possible speedup - don't use numpy.array_equal when not required! --- babyai/bot.py | 15 +++++++++------ scripts/eval_bot.py | 11 ++++++++--- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/babyai/bot.py b/babyai/bot.py index f5ab4df0..fdbfb063 100644 --- a/babyai/bot.py +++ b/babyai/bot.py @@ -351,7 +351,7 @@ def get_action(self): # No path found -> Explore the world return ExploreSubgoal(self.bot).get_action() else: - target_pos = self.datum + target_pos = tuple(self.datum) # CASE 1: The position we are on is the one we should go next to # -> Move away from it @@ -377,7 +377,7 @@ def get_action(self): # CASE 3: we are still far from the target # Try to find a non-blocker path path, _, _ = self.bot.shortest_path( - lambda pos, cell: np.array_equal(pos, target_pos) + lambda pos, cell: pos == target_pos, ) # CASE 3.1: No non-blocker path found, and reexploration is allowed @@ -399,7 +399,7 @@ def get_action(self): # -> Look for blocker paths if not path: path, _, _ = self.bot.shortest_path( - lambda pos, cell: np.array_equal(pos, target_pos), + lambda pos, cell: pos == target_pos, try_with_blockers=True ) @@ -468,7 +468,7 @@ def take_action(self, action): self.bot.stack.append(ExploreSubgoal(self.bot)) return False else: - target_pos = self.datum + target_pos = tuple(self.datum) # CASE 1: The position we are on is the one we should go next to # -> Move away from it @@ -516,7 +516,7 @@ def take_action(self, action): # CASE 5: otherwise # Try to find a path path, _, _ = self.bot.shortest_path( - lambda pos, cell: np.array_equal(pos, target_pos), + lambda pos, cell: pos == target_pos, try_with_blockers=True ) @@ -699,6 +699,7 @@ def __init__(self, mission, timeout=10000): self.process_instr(mission.instrs) self.bfs_counter = 0 + self.bfs_step_counter = 0 def find_obj_pos(self, obj_desc, adjacent=False): """ @@ -791,6 +792,8 @@ def breadth_first_search(self, initial_states, accept_fn, ignore_blockers): going straight over turning. """ + self.bfs_counter += 1 + queue = [(state, None) for state in initial_states] grid = self.mission.grid previous_pos = dict() @@ -803,7 +806,7 @@ def breadth_first_search(self, initial_states, accept_fn, ignore_blockers): if (i, j) in previous_pos: continue - self.bfs_counter += 1 + self.bfs_step_counter += 1 cell = grid.get(i, j) previous_pos[(i, j)] = prev_pos diff --git a/scripts/eval_bot.py b/scripts/eval_bot.py index a8d36206..39c04ccc 100755 --- a/scripts/eval_bot.py +++ b/scripts/eval_bot.py @@ -98,6 +98,8 @@ num_success = 0 total_reward = 0 total_steps = [] + total_bfs = 0 + total_episode_steps = 0 total_bfs_steps = 0 for run_no in range(options.num_runs): @@ -143,7 +145,9 @@ episode_steps += 1 if done: - total_bfs_steps += expert.bfs_counter + total_episode_steps += episode_steps + total_bfs_steps += expert.bfs_step_counter + total_bfs += expert.bfs_counter if reward > 0: num_success += 1 total_steps.append(episode_steps) @@ -172,5 +176,6 @@ print('total time: %.1fs' % total_time) if not all_good: raise Exception("some tests failed") -print(total_bfs_steps) - +print('total episode_steps:', total_episode_steps) +print('total bfs:', total_bfs) +print('total bfs steps:', total_bfs_steps)