forked from mila-iqia/babyai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
enjoy.py
executable file
·131 lines (103 loc) · 3.86 KB
/
enjoy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python3
"""
Visualize the performance of a model on a given environment.
"""
import argparse
import gym
import time
import babyai.utils as utils
# Parse arguments
# CLI for visualizing either a trained model or recorded demonstrations on a
# gym environment. Which kind of agent gets built from these flags is decided
# later by utils.load_agent.
parser = argparse.ArgumentParser()
parser.add_argument("--env", required=True,
help="name of the environment to be run (REQUIRED)")
parser.add_argument("--model", default=None,
help="name of the trained model (REQUIRED or --demos-origin or --demos REQUIRED)")
parser.add_argument("--demos", default=None,
help="demos filename (REQUIRED or --model demos-origin required)")
parser.add_argument("--demos-origin", default=None,
help="origin of the demonstrations: human | agent (REQUIRED or --model or --demos REQUIRED)")
parser.add_argument("--seed", type=int, default=None,
help="random seed (default: 0 if model agent, 1 if demo agent)")
parser.add_argument("--argmax", action="store_true", default=False,
help="action with highest probability is selected for model agent")
parser.add_argument("--pause", type=float, default=0.1,
help="the pause between two consequent actions of an agent")
parser.add_argument("--manual-mode", action="store_true", default=False,
help="Allows you to take control of the agent at any point of time")
args = parser.parse_args()
# Keyboard bindings used in manual mode: pressed-key name -> environment
# action name (looked up in env.actions by keyDownCb).
action_map = dict([
    ("left", "left"),
    ("right", "right"),
    ("up", "forward"),
    ("p", "pickup"),
    ("pageup", "pickup"),
    ("d", "drop"),
    ("pagedown", "drop"),
    (" ", "toggle"),
])
# At least one source of behavior must be given: a trained model, a demos
# file, or a demos origin. The help text above advertises all three, and
# utils.load_agent accepts demos_origin as a standalone option, so it must
# be part of this check (the old assert wrongly rejected --demos-origin
# used alone).
assert args.model is not None or args.demos is not None or args.demos_origin is not None, \
    "--model, --demos or --demos-origin must be specified."
# Default seed differs per agent kind (0 for model, 1 for demos) so the two
# evaluation modes do not accidentally share the same episode sequence.
if args.seed is None:
    args.seed = 0 if args.model is not None else 1
# Set seed for all randomness sources
utils.seed(args.seed)
# Generate environment
env = gym.make(args.env)
env.seed(args.seed)
# NOTE(review): `global` at module scope is a no-op; obs is already a
# module-level name. Kept unchanged here.
global obs
obs = env.reset()
print("Mission: {}".format(obs["mission"]))
# Define agent
# utils.load_agent builds either a model-driven agent or a demo-replaying
# agent depending on which of --model / --demos / --demos-origin was given.
agent = utils.load_agent(env, args.model, args.demos, args.demos_origin, args.argmax, args.env)
# Run the agent
# Module-level episode state. `done` starts True; in the automatic loop below
# it is overwritten by env.step before being read. keyDownCb uses its own
# locals of the same names when handling manual-mode key presses.
done = True
action = None
def keyDownCb(event):
    """Window key handler for manual mode.

    Translates a pressed key into an environment action ("enter" defers to
    the agent's own suggestion) and advances the environment one step,
    resetting it when the episode ends.
    """
    global obs
    key = event.key
    print(key)
    # Ignore key presses that are neither mapped nor "enter", so stray
    # clicks do not feed an observation through the agent.
    if key != "enter" and key not in action_map:
        return
    # Query the agent on every accepted key press — even when the human
    # choice overrides it — exactly as the surrounding loop does, so any
    # internal agent state advances consistently.
    suggested = agent.act(obs)['action']
    if key == "enter":
        # Enter executes the agent's action
        chosen = suggested
    else:
        # Map the key to an action
        chosen = env.actions[action_map[key]]
    obs, reward, finished, _ = env.step(chosen)
    agent.analyze_feedback(reward, finished)
    if finished:
        print("Reward:", reward)
        obs = env.reset()
        print("Mission: {}".format(obs["mission"]))
# Manual mode: open the render window first so its key handler can be
# registered; keyDownCb then drives the environment from key presses.
if args.manual_mode:
    env.render('human')
    env.window.reg_key_handler(keyDownCb)
# Step/episode counters for the automatic (non-manual) evaluation loop.
step = 0
episode_num = 0
while True:
    time.sleep(args.pause)
    env.render("human")
    if not args.manual_mode:
        result = agent.act(obs)
        obs, reward, done, _ = env.step(result['action'])
        agent.analyze_feedback(reward, done)
        if 'dist' in result and 'value' in result:
            # Model agents also expose the action distribution and value
            # estimate; print them for inspection.
            dist, value = result['dist'], result['value']
            dist_str = ", ".join("{:.4f}".format(float(p)) for p in dist.probs[0])
            print("step: {}, mission: {}, dist: {}, entropy: {:.2f}, value: {:.2f}".format(
                step, obs["mission"], dist_str, float(dist.entropy()), float(value)))
        else:
            print("step: {}, mission: {}".format(step, obs['mission']))
        # Episode bookkeeping belongs to this branch only: in manual mode
        # `reward`/`done` are never assigned here (the old top-level check
        # raised NameError on `reward` on the first iteration and would have
        # reset the env every pass), and keyDownCb already handles resets.
        if done:
            print("Reward:", reward)
            episode_num += 1
            # Re-seed per episode so each episode is reproducible yet distinct.
            env.seed(args.seed + episode_num)
            obs = env.reset()
            agent.on_reset()
            step = 0
        else:
            step += 1
    # Exit once the render window has been closed by the user.
    if env.window.closed:
        break