Add separate training and evaluating scripts
1 parent 15c5fe5 · commit 4da7db1
Showing 2 changed files with 243 additions and 0 deletions.
@@ -0,0 +1,125 @@
""" | ||
Stable Baselines 3 evaluating script for F1Tenth Gym with wrapped environment | ||
""" | ||
|
||
import os | ||
import gym | ||
import time | ||
import glob | ||
import argparse | ||
import numpy as np | ||
|
||
from datetime import datetime | ||
|
||
from stable_baselines3 import PPO | ||
|
||
from code.wrappers import F110_Wrapped, RandomMap | ||
|
||
|
||
TRAIN_DIRECTORY = "./train" | ||
MIN_EVAL_EPISODES = 5 | ||
MAP_PATH = "./f1tenth_gym/examples/example_map" | ||
MAP_EXTENSION = ".png" | ||
MAP_CHANGE_INTERVAL = 3000 | ||
|
||
|
||
def main(): | ||
|
||
# # | ||
# EVALUATE # | ||
# # | ||
|
||
# prepare the environment | ||
def wrap_env(): | ||
# starts F110 gym | ||
env = gym.make("f110_gym:f110-v0", | ||
map=MAP_PATH, | ||
map_ext=MAP_EXTENSION, | ||
num_agents=1) | ||
# wrap basic gym with RL functions | ||
env = F110_Wrapped(env) | ||
env = RandomMap(env, MAP_CHANGE_INTERVAL) | ||
return env | ||
|
||
# create evaluation environment (same as train environment) | ||
eval_env = wrap_env() | ||
|
||
# set random seed | ||
eval_env.seed(np.random.randint(pow(2, 32) - 1)) | ||
|
||
# load or create model | ||
model, _ = load_model(TRAIN_DIRECTORY, | ||
eval_env, | ||
evaluating=True) | ||
|
||
# simulate a few episodes and render them, ctrl-c to cancel an episode | ||
episode = 0 | ||
while episode < MIN_EVAL_EPISODES: | ||
try: | ||
episode += 1 | ||
obs = eval_env.reset() | ||
done = False | ||
while not done: | ||
# use trained model to predict some action, using observations | ||
action, _ = model.predict(obs) | ||
obs, _, done, _ = eval_env.step(action) | ||
eval_env.render() | ||
# this section just asks the user if they want to run more episodes | ||
if episode == (MIN_EVAL_EPISODES - 1): | ||
choice = input("Another episode? (Y/N) ") | ||
if choice.replace(" ", "").lower() in ["y", "yes"]: | ||
episode -= 1 | ||
else: | ||
episode = MIN_EVAL_EPISODES | ||
except KeyboardInterrupt: | ||
pass | ||
|
||
|
||
def load_model(train_directory, envs, tensorboard_path=None, evaluating=False): | ||
# parse arguments to script | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("-l", | ||
"--load", | ||
help="load previous model", | ||
nargs="?", | ||
const="latest") | ||
args = parser.parse_args() | ||
# create new model | ||
if (args.load is None) and (not evaluating): | ||
print("Creating new model...") | ||
reset_num_timesteps = True | ||
model = PPO("MlpPolicy", | ||
envs, | ||
verbose=1, | ||
tensorboard_log=tensorboard_path) | ||
# load model | ||
else: | ||
reset_num_timesteps = False | ||
# get trained model list | ||
trained_models = glob.glob(f"{train_directory}/*") | ||
# latest model | ||
if (args.load == "latest") or (args.load is None): | ||
model_path = max(trained_models, key=os.path.getctime) | ||
else: | ||
trained_models_sorted = sorted(trained_models, | ||
key=os.path.getctime, | ||
reverse=True) | ||
# match user input to model names | ||
model_path = [m for m in trained_models_sorted if args.load in m] | ||
model_path = model_path[0] | ||
# get plain model name for printing | ||
model_name = model_path.replace(".zip", '') | ||
model_name = model_name.replace(f"{train_directory}/", '') | ||
print(f"Loading model ({train_directory}) {model_name}") | ||
# load model from path | ||
model = PPO.load(model_path) | ||
# set and reset environment | ||
model.set_env(envs) | ||
envs.reset() | ||
# return new/loaded model | ||
return model, reset_num_timesteps | ||
|
||
|
||
# necessary for Python multi-processing (not needed in evaluating) | ||
if __name__ == "__main__": | ||
main() |
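A note on the `--load` flag used in `load_model` above: with `nargs="?"` and `const="latest"`, passing `-l` with no value yields "latest", passing `-l some-name` yields that name, and omitting the flag yields None (which, when training, means a new model is created). A minimal sketch of that argparse behaviour in isolation, not part of the diff:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-l", "--load", help="load previous model",
                    nargs="?", const="latest")

print(parser.parse_args([]).load)                   # None -> create a new model (when training)
print(parser.parse_args(["-l"]).load)               # "latest" -> newest file in the train directory
print(parser.parse_args(["-l", "ppo-f110"]).load)   # "ppo-f110" -> matched against saved model names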
@@ -0,0 +1,118 @@
""" | ||
Stable Baselines 3 training script for F1Tenth Gym with vectorised environments | ||
""" | ||
|
||
import os | ||
import gym | ||
import time | ||
import glob | ||
import argparse | ||
import numpy as np | ||
|
||
from datetime import datetime | ||
|
||
from stable_baselines3 import PPO | ||
from stable_baselines3.common.vec_env import SubprocVecEnv | ||
from stable_baselines3.common.env_util import make_vec_env | ||
|
||
from code.wrappers import F110_Wrapped, RandomMap | ||
|
||
|
||
TRAIN_DIRECTORY = "./train" | ||
TRAIN_STEPS = pow(10, 4) # for reference, it takes about one sec per 500 steps | ||
NUM_PROCESS = 1 | ||
MAP_PATH = "./f1tenth_gym/examples/example_map" | ||
MAP_EXTENSION = ".png" | ||
MAP_CHANGE_INTERVAL = 3000 | ||
TENSORBOARD_PATH = "./ppo_tensorboard" | ||
|
||
|
||
def main(): | ||
|
||
# # | ||
# TRAIN # | ||
# # | ||
|
||
# prepare the environment | ||
def wrap_env(): | ||
# starts F110 gym | ||
env = gym.make("f110_gym:f110-v0", | ||
map=MAP_PATH, | ||
map_ext=MAP_EXTENSION, | ||
num_agents=1) | ||
# wrap basic gym with RL functions | ||
env = F110_Wrapped(env) | ||
env = RandomMap(env, MAP_CHANGE_INTERVAL) | ||
return env | ||
|
||
# vectorise environment (parallelise) | ||
envs = make_vec_env(wrap_env, | ||
n_envs=NUM_PROCESS, | ||
seed=np.random.randint(pow(2, 32) - 1), | ||
vec_env_cls=SubprocVecEnv) | ||
|
||
# load or create model | ||
model, reset_num_timesteps = load_model(TRAIN_DIRECTORY, | ||
envs, | ||
TENSORBOARD_PATH) | ||
|
||
# train model and record time taken | ||
start_time = time.time() | ||
model.learn(total_timesteps=TRAIN_STEPS, | ||
reset_num_timesteps=reset_num_timesteps) | ||
print(f"Training time {time.time() - start_time:.2f}s") | ||
print("Training cycle complete.") | ||
|
||
# save model with unique timestamp | ||
timestamp = datetime.now().strftime("%d-%m-%Y-%H-%M-%S") | ||
model.save(f"{TRAIN_DIRECTORY}/ppo-f110-{timestamp}") | ||
|
||
|
||
def load_model(train_directory, envs, tensorboard_path=None, evaluating=False): | ||
# parse arguments to script | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("-l", | ||
"--load", | ||
help="load previous model", | ||
nargs="?", | ||
const="latest") | ||
args = parser.parse_args() | ||
# create new model | ||
if (args.load is None) and (not evaluating): | ||
print("Creating new model...") | ||
reset_num_timesteps = True | ||
model = PPO("MlpPolicy", | ||
envs, | ||
verbose=1, | ||
tensorboard_log=tensorboard_path) | ||
# load model | ||
else: | ||
reset_num_timesteps = False | ||
# get trained model list | ||
trained_models = glob.glob(f"{train_directory}/*") | ||
# latest model | ||
if (args.load == "latest") or (args.load is None): | ||
model_path = max(trained_models, key=os.path.getctime) | ||
else: | ||
trained_models_sorted = sorted(trained_models, | ||
key=os.path.getctime, | ||
reverse=True) | ||
# match user input to model names | ||
model_path = [m for m in trained_models_sorted if args.load in m] | ||
model_path = model_path[0] | ||
# get plain model name for printing | ||
model_name = model_path.replace(".zip", '') | ||
model_name = model_name.replace(f"{train_directory}/", '') | ||
print(f"Loading model ({train_directory}) {model_name}") | ||
# load model from path | ||
model = PPO.load(model_path) | ||
# set and reset environment | ||
model.set_env(envs) | ||
envs.reset() | ||
# return new/loaded model | ||
return model, reset_num_timesteps | ||
|
||
|
||
# necessary for Python multi-processing | ||
if __name__ == "__main__": | ||
main() |
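Taken together, the workflow the two files suggest is: run the training script to produce a timestamped ppo-f110-* checkpoint in ./train, pass --load (optionally with a name) to continue from an earlier checkpoint, then run the evaluating script, which always loads from ./train. Outside the scripts, resuming training would look roughly like the sketch below; the checkpoint filename is a hypothetical timestamp, not one produced by this commit.

import gym
from stable_baselines3 import PPO
from code.wrappers import F110_Wrapped, RandomMap

# rebuild a single wrapped environment exactly as wrap_env() does above
env = gym.make("f110_gym:f110-v0",
               map="./f1tenth_gym/examples/example_map",
               map_ext=".png",
               num_agents=1)
env = RandomMap(F110_Wrapped(env), 3000)

# resume from a saved checkpoint (hypothetical timestamp in the filename)
model = PPO.load("./train/ppo-f110-01-01-2022-00-00-00")
model.set_env(env)
model.learn(total_timesteps=10_000, reset_num_timesteps=False)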