Skip to content

Commit

Permalink
more stats and the cheapest possible speedup - don't use
Browse files Browse the repository at this point in the history
numpy.array_equal when not required!
  • Loading branch information
rizar committed Jan 23, 2019
1 parent 225dc34 commit 7073870
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
15 changes: 9 additions & 6 deletions babyai/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def get_action(self):
# No path found -> Explore the world
return ExploreSubgoal(self.bot).get_action()
else:
target_pos = self.datum
target_pos = tuple(self.datum)

# CASE 1: The position we are on is the one we should go next to
# -> Move away from it
Expand All @@ -377,7 +377,7 @@ def get_action(self):
# CASE 3: we are still far from the target
# Try to find a non-blocker path
path, _, _ = self.bot.shortest_path(
lambda pos, cell: np.array_equal(pos, target_pos)
lambda pos, cell: pos == target_pos,
)

# CASE 3.1: No non-blocker path found, and reexploration is allowed
Expand All @@ -399,7 +399,7 @@ def get_action(self):
# -> Look for blocker paths
if not path:
path, _, _ = self.bot.shortest_path(
lambda pos, cell: np.array_equal(pos, target_pos),
lambda pos, cell: pos == target_pos,
try_with_blockers=True
)

Expand Down Expand Up @@ -468,7 +468,7 @@ def take_action(self, action):
self.bot.stack.append(ExploreSubgoal(self.bot))
return False
else:
target_pos = self.datum
target_pos = tuple(self.datum)

# CASE 1: The position we are on is the one we should go next to
# -> Move away from it
Expand Down Expand Up @@ -516,7 +516,7 @@ def take_action(self, action):
# CASE 5: otherwise
# Try to find a path
path, _, _ = self.bot.shortest_path(
lambda pos, cell: np.array_equal(pos, target_pos),
lambda pos, cell: pos == target_pos,
try_with_blockers=True
)

Expand Down Expand Up @@ -699,6 +699,7 @@ def __init__(self, mission, timeout=10000):
self.process_instr(mission.instrs)

self.bfs_counter = 0
self.bfs_step_counter = 0

def find_obj_pos(self, obj_desc, adjacent=False):
"""
Expand Down Expand Up @@ -791,6 +792,8 @@ def breadth_first_search(self, initial_states, accept_fn, ignore_blockers):
going straight over turning.
"""
self.bfs_counter += 1

queue = [(state, None) for state in initial_states]
grid = self.mission.grid
previous_pos = dict()
Expand All @@ -803,7 +806,7 @@ def breadth_first_search(self, initial_states, accept_fn, ignore_blockers):
if (i, j) in previous_pos:
continue

self.bfs_counter += 1
self.bfs_step_counter += 1

cell = grid.get(i, j)
previous_pos[(i, j)] = prev_pos
Expand Down
11 changes: 8 additions & 3 deletions scripts/eval_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@
num_success = 0
total_reward = 0
total_steps = []
total_bfs = 0
total_episode_steps = 0
total_bfs_steps = 0

for run_no in range(options.num_runs):
Expand Down Expand Up @@ -143,7 +145,9 @@
episode_steps += 1

if done:
total_bfs_steps += expert.bfs_counter
total_episode_steps += episode_steps
total_bfs_steps += expert.bfs_step_counter
total_bfs += expert.bfs_counter
if reward > 0:
num_success += 1
total_steps.append(episode_steps)
Expand Down Expand Up @@ -172,5 +176,6 @@
print('total time: %.1fs' % total_time)
if not all_good:
raise Exception("some tests failed")
print(total_bfs_steps)

print('total episode_steps:', total_episode_steps)
print('total bfs:', total_bfs)
print('total bfs steps:', total_bfs_steps)

0 comments on commit 7073870

Please sign in to comment.