# main.py
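"""Entry point for training a DQN agent (DeepMind Acme + Sonnet) on Super Mario Bros
and for visualizing trained policies from saved checkpoints."""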
import os

# Must be set before TensorFlow is imported (directly or via Acme/Sonnet) so the
# GPU memory allocator grows on demand instead of reserving all memory up front.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

import pickle

import acme
import matplotlib.pyplot as plt
import numpy as np
import sonnet as snt
import tensorflow as tf
from acme.tf.networks import duelling
from acme.utils import loggers

from dqn.agent import DQN as DQNAgent
from mario_env import MarioEnvironment
from utils import visualize_policy
from utils.env_loop import EnvironmentLoop
from utils.utils import restore_module


def make_dqn(num_actions: int):
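    """Build the Q-network: four 3x3 stride-2 conv layers (32 channels each)
    followed by a dueling MLP head with `num_actions` outputs."""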
    return snt.Sequential([
        snt.Conv2D(32, [3, 3], [2, 2]),
        tf.nn.relu,
        snt.Conv2D(32, [3, 3], [2, 2]),
        tf.nn.relu,
        snt.Conv2D(32, [3, 3], [2, 2]),
        tf.nn.relu,
        snt.Conv2D(32, [3, 3], [2, 2]),
        tf.nn.relu,
        snt.Flatten(),
        duelling.DuellingMLP(num_actions, hidden_sizes=[512]),
    ])


def make_env(colorful_rendering: bool = False):
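    """Instantiate the Super Mario Bros environment (world 3-1, right-only
    movement, rescaled grayscale observations stacked across frames)."""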
    return MarioEnvironment(skip_frames=3,
                            img_rescale_pc=0.5,
                            stack_func=np.dstack,
                            stack_mode="all",
                            grayscale=True,
                            black_background=True,
                            in_game_score_weight=0.02,
                            movement_type="right_only",
                            world_and_level=(3, 1),
                            idle_frames_threshold=1000,
                            colorful_rendering=colorful_rendering)


def train(network=None, expert_data_path=None):
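    """Train a DQN agent on the Mario environment and return the trained network.

    If `expert_data_path` is given, the pickled expert episodes are loaded and
    passed to the agent.
    """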
    env = make_env()
    env_spec = acme.make_environment_spec(env)

    if network is None:
        network = make_dqn(env_spec.actions.num_values)

    # Optionally load pre-recorded expert episodes (e.g. collected from human play).
    expert_data = None
    if expert_data_path is not None:
        with open(expert_data_path, "rb") as handle:
            expert_data = pickle.load(handle)
        num_timesteps = np.sum([1 + len(ep["mid"]) for ep in expert_data])
        print(f"Using expert data from {expert_data_path}. "
              f"Episodes: {len(expert_data)}. Timesteps: {num_timesteps}.")

    agent = DQNAgent(environment_spec=env_spec,
                     network=network,
                     batch_size=32,
                     learning_rate=6.25e-5,
                     logger=loggers.NoOpLogger(),
                     min_replay_size=2500,
                     max_replay_size=int(2e5),
                     target_update_period=2500,
                     epsilon=tf.Variable(0.015),
                     n_step=10,
                     discount=0.9,
                     expert_data=expert_data)

    loop = EnvironmentLoop(environment=env,
                           actor=agent,
                           module2save=network)
    reward_history = loop.run(num_steps=int(1e5),
                              render=True,
                              checkpoint=True,
                              checkpoint_freq=10)

    # Plot a moving average (window of 50 episodes) of the reward history.
    avg_hist = [np.mean(reward_history[i:(i + 50)])
                for i in range(len(reward_history) - 50)]
    plt.plot(list(range(len(avg_hist))), avg_hist)
    plt.show()

    env.close()
    return network


if __name__ == "__main__":
    # collect_data_from_human()
    # policy_path = find_best_policy("checkpoints/checkpoints_2021-03-29-20-36-19")
    policy_path = "checkpoints/best_policies/w3_lv1/w3_lv1_completed_r3175"

    policy_network = make_dqn(make_env().action_spec().num_values)
    restore_module(base_module=policy_network, save_path=policy_path)
    print(f"\nUsing policy checkpoint from: {policy_path}")

    # train(policy_network, expert_data_path=None)

    input("\nPress [ENTER] to continue.")
    visualize_policy(policy_network, env=make_env(colorful_rendering=True),
                     num_episodes=3, fps=60, epsilon_greedy=0)
    # DEBUG: manually step the environment with random actions and inspect
    # observations / raw network outputs.
    # env = make_env()
    # obs = env.reset().observation
    # while True:
    #     env.render()
    #     env.plot_obs(np.hstack([obs[:, :, i] for i in range(obs.shape[-1])]))
    #     obs = env.step(
    #         np.random.randint(low=0, high=env.action_spec().num_values)
    #     ).observation

    # env = make_env()
    # print(env.reset().observation.shape)
    # network = make_dqn(env.action_spec().num_values)
    # out = network(tf.expand_dims(env.reset().observation, axis=0))
    # print(out)
    # print(out.shape)