forked from mila-iqia/babyai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
enjoy.py
executable file
·120 lines (96 loc) · 3.69 KB
/
enjoy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python3
"""
Visualize the performance of a model on a given environment.
"""
import argparse
import gym
import time
import babyai.utils as utils
# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--env", required=True,
                    help="name of the environment to be run (REQUIRED)")
parser.add_argument("--model", default=None,
                    help="name of the trained model (REQUIRED or --demos-origin or --demos REQUIRED)")
parser.add_argument("--demos", default=None,
                    help="demos filename (REQUIRED or --model demos-origin required)")
parser.add_argument("--demos-origin", default=None,
                    help="origin of the demonstrations: human | agent (REQUIRED or --model or --demos REQUIRED)")
parser.add_argument("--seed", type=int, default=None,
                    help="random seed (default: 0 if model agent, 1 if demo agent)")
parser.add_argument("--argmax", action="store_true", default=False,
                    help="action with highest probability is selected for model agent")
parser.add_argument("--pause", type=float, default=0.1,
                    help="the pause between two consequent actions of an agent")
parser.add_argument("--manual-mode", action="store_true", default=False,
                    help="Allows you to take control of the agent at any point of time")
args = parser.parse_args()

# Mapping from renderer key names to environment action names; used only
# when the user drives the agent by hand via --manual-mode.
action_map = {
    "LEFT":      "left",
    "RIGHT":     "right",
    "UP":        "forward",
    "PAGE_UP":   "pickup",
    "PAGE_DOWN": "drop",
    "SPACE":     "toggle",
}

# An agent source is mandatory. Use parser.error() rather than `assert`:
# asserts are stripped under `python -O`, and parser.error() exits with the
# conventional argparse usage message and status code 2.
if args.model is None and args.demos is None:
    parser.error("--model or --demos must be specified.")

if args.seed is None:
    # Model agents and demo agents default to different seeds.
    args.seed = 0 if args.model is not None else 1
# Set seed for all randomness sources
utils.seed(args.seed)

# Generate environment and seed it so the first episode is reproducible.
env = gym.make(args.env)
env.seed(args.seed)

# NOTE(review): `global` at module level is a no-op — `obs` is already a
# module-level name; the statement is kept for byte-compatibility only.
global obs
obs = env.reset()
print("Mission: {}".format(obs["mission"]))

# Define agent from the model / demos options given on the command line.
agent = utils.load_agent(env, args.model, args.demos, args.demos_origin, args.argmax, args.env)

# Run the agent: `done` starts True and `action` unset; both are reassigned
# by the stepping code below before they are meaningfully read.
done = True
action = None
def keyDownCb(keyName):
    """Renderer key handler for --manual-mode.

    Performs one environment step per recognized key press: a key present
    in ``action_map`` executes that action, RETURN executes whatever the
    agent itself would do. Any other key is ignored entirely.
    """
    global obs

    is_mapped = keyName in action_map
    if not (is_mapped or keyName == "RETURN"):
        # Stray key press: bail out before the agent consumes the observation.
        return

    # Query the agent in every case (even when its action is overridden) so
    # its internal state stays in sync with the environment steps.
    suggested = agent.act(obs)['action']
    chosen = env.actions[action_map[keyName]] if is_mapped else suggested

    obs, step_reward, episode_over, _ = env.step(chosen)
    agent.analyze_feedback(step_reward, episode_over)

    if episode_over:
        print("Reward:", step_reward)
        obs = env.reset()
        print("Mission: {}".format(obs["mission"]))
# Main loop: render once per iteration, then either hand control to the user
# (manual mode registers the key callback and stepping happens in keyDownCb)
# or let the agent act autonomously, one environment step per iteration.
step = 0
episode_num = 0
while True:
    time.sleep(args.pause)
    renderer = env.render("human")

    if args.manual_mode and renderer.window is not None:
        renderer.window.setKeyDownCb(keyDownCb)
    else:
        result = agent.act(obs)
        obs, reward, done, _ = env.step(result['action'])
        agent.analyze_feedback(reward, done)
        if 'dist' in result and 'value' in result:
            # Model agents expose their action distribution and value
            # estimate; print them alongside the step counter.
            dist, value = result['dist'], result['value']
            dist_str = ", ".join("{:.4f}".format(float(p)) for p in dist.probs[0])
            print("step: {}, mission: {}, dist: {}, entropy: {:.2f}, value: {:.2f}".format(
                step, obs["mission"], dist_str, float(dist.entropy()), float(value)))
        else:
            print("step: {}, mission: {}".format(step, obs['mission']))
        if done:
            print("Reward:", reward)
            episode_num += 1
            # Re-seed per episode so each one is reproducible but distinct.
            env.seed(args.seed + episode_num)
            obs = env.reset()
            agent.on_reset()
            step = 0
        else:
            step += 1

    # Exit once the user closes the rendering window.
    if renderer.window is None:
        break