Skip to content

Commit

Permalink
Merge pull request mila-iqia#45 from mila-udem/dima-dev
Browse files Browse the repository at this point in the history
Fix two crucial bugs
  • Loading branch information
rizar authored Feb 4, 2019
2 parents 69ea2b3 + 07f2133 commit b12430f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
1 change: 1 addition & 0 deletions babyai/imitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ def initial_status():
if torch.cuda.is_available():
self.acmodel.cpu()
utils.save_model(self.acmodel, self.args.model)
self.obss_preprocessor.vocab.save()
if torch.cuda.is_available():
self.acmodel.cuda()
with open(status_path, 'w') as dst:
Expand Down
11 changes: 6 additions & 5 deletions scripts/make_agent_demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,16 @@ def generate_demos(n_episodes, valid, seed, shift=0):
checkpoint_time = time.time()

just_crashed = False
for i in range(n_episodes):
# Run the expert for one episode
while True:
if len(demos) == n_episodes:
break

done = False
if just_crashed:
logger.info("reset the environment to find a mission that the bot can solve")
env.reset()
else:
env.seed(seed + i)
env.seed(seed + len(demos))
obs = env.reset()
agent.on_reset()

Expand Down Expand Up @@ -140,15 +141,15 @@ def generate_demos(n_episodes, valid, seed, shift=0):
if args.save_interval > 0 and len(demos) < n_episodes and len(demos) % args.save_interval == 0:
logger.info("Saving demos...")
utils.save_demos(demos, demos_path)
logger.info("Demos saved")
logger.info("{} demos saved".format(len(demos)))
# print statistics for the last 100 demonstrations
print_demo_lengths(demos[-100:])


# Save demonstrations
logger.info("Saving demos...")
utils.save_demos(demos, demos_path)
logger.info("Demos saved")
logger.info("{} demos saved".format(len(demos)))
print_demo_lengths(demos[-100:])


Expand Down

0 comments on commit b12430f

Please sign in to comment.