Skip to content

Commit

Permalink
fix script for demo evaluation
Browse files: browse the repository at this point in the history
  • Loading branch information
saleml committed Oct 18, 2018
1 parent f2f5c1b commit 8ea10b0
Showing 1 changed file with 36 additions and 18 deletions.
54 changes: 36 additions & 18 deletions scripts/evaluate.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,7 +10,7 @@
import datetime

import babyai.utils as utils
from babyai.evaluate import evaluate, batch_evaluate
from babyai.evaluate import evaluate_demo_agent, batch_evaluate, evaluate
# Parse arguments

parser = argparse.ArgumentParser()
Expand All @@ -30,6 +30,9 @@
help="action with highest probability is selected for model agent")
parser.add_argument("--contiguous-episodes", action="store_true", default=False,
help="Make sure episodes on which evaluation is done are contiguous")
parser.add_argument("--worst-episodes-to-show", type=int, default=10,
help="The number of worse episodes to show")


def main(args, seed, episodes):
# Set seed for all randomness sources
Expand All @@ -45,10 +48,13 @@ def main(args, seed, episodes):
episodes = len(agent.demos)

# Evaluate
if isinstance(agent, utils.ModelAgent) and not args.contiguous_episodes:
logs = batch_evaluate(agent, args.env, seed, episodes)
if isinstance(agent, utils.ModelAgent):
if not args.contiguous_episodes:
logs = batch_evaluate(agent, args.env, seed, episodes)
else:
logs = evaluate(agent, env, episodes, False)
else:
logs = evaluate(agent, env, episodes, False)
logs = evaluate_demo_agent(agent, episodes)

return logs

Expand All @@ -69,19 +75,31 @@ def main(args, seed, episodes):
fps = num_frames/(end_time - start_time)
ellapsed_time = int(end_time - start_time)
duration = datetime.timedelta(seconds=ellapsed_time)
return_per_episode = utils.synthesize(logs["return_per_episode"])
success_per_episode = utils.synthesize(
[1 if r > 0 else 0 for r in logs["return_per_episode"]])

if args.model is not None:
return_per_episode = utils.synthesize(logs["return_per_episode"])
success_per_episode = utils.synthesize(
[1 if r > 0 else 0 for r in logs["return_per_episode"]])

num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])

print("F {} | FPS {:.0f} | D {} | R:xsmM {:.2f} {:.2f} {:.2f} {:.2f} | S {:.2f} | F:xsmM {:.1f} {:.1f} {} {}"
.format(num_frames, fps, duration,
*return_per_episode.values(),
success_per_episode['mean'],
*num_frames_per_episode.values()))

indexes = sorted(range(len(logs["return_per_episode"])), key=lambda k: logs["return_per_episode"][k])
n = 10
print("{} worst episodes:".format(n))
for i in indexes[:n]:
print("- episode {}: R={}, F={}".format(i, logs["return_per_episode"][i], logs["num_frames_per_episode"][i]))
if args.model is not None:
print("F {} | FPS {:.0f} | D {} | R:xsmM {:.2f} {:.2f} {:.2f} {:.2f} | S {:.2f} | F:xsmM {:.1f} {:.1f} {} {}"
.format(num_frames, fps, duration,
*return_per_episode.values(),
success_per_episode['mean'],
*num_frames_per_episode.values()))
else:
print("F {} | FPS {:.0f} | D {} | F:xsmM {:.1f} {:.1f} {} {}"
.format(num_frames, fps, duration, *num_frames_per_episode.values()))

indexes = sorted(range(len(logs["num_frames_per_episode"])), key=lambda k: - logs["num_frames_per_episode"][k])

n = args.worst_episodes_to_show
if n > 0:
print("{} worst episodes:".format(n))
for i in indexes[:n]:
if args.model is not None:
print("- episode {}: R={}, F={}".format(i, logs["return_per_episode"][i], logs["num_frames_per_episode"][i]))
else:
print("- episode {}: F={}".format(i, logs["num_frames_per_episode"][i]))

0 comments on commit 8ea10b0

Please sign in to comment.