Skip to content

Commit

Permalink
Added script for evaluating saved models
Browse files Browse the repository at this point in the history
  • Loading branch information
shariqiqbal2810 committed Feb 12, 2018
1 parent 5a482d0 commit 88bb46f
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import argparse
import torch
import time
import numpy as np
from pathlib import Path
from torch.autograd import Variable
from utils.make_env import make_env
from algorithms.maddpg import MADDPG


def run(config):
model_path = (Path('./models') / config.env_id / config.model_name /
('run%i' % config.run_num) / 'model.pt')

env = make_env(config.env_id)
maddpg = MADDPG.init_from_save(model_path)
ifi = 1 / config.fps # inter-frame interval

for ep_i in range(config.n_episodes):
print("Episode %i of %i" % (ep_i + 1, config.n_episodes))
obs = env.reset()
env.render('human')
for t_i in range(config.episode_length):
calc_start = time.time()
# rearrange observations to be per agent, and convert to torch Variable
torch_obs = [Variable(torch.Tensor(obs[i]),
requires_grad=False)
for i in range(maddpg.nagents)]
# get actions as torch Variables
torch_actions = maddpg.step(torch_obs, training=False)
# convert actions to numpy arrays
actions = [ac.data.numpy() for ac in torch_actions]
obs, rewards, dones, infos = env.step(actions)
calc_end = time.time()
elapsed = calc_end - calc_start
if elapsed < ifi:
time.sleep(ifi - elapsed)
env.render('human')

env.close()


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("env_id", help="Name of environment")
parser.add_argument("model_name",
help="Name of directory to store " +
"model/training contents")
parser.add_argument("run_num", default=1, type=int)
parser.add_argument("--n_episodes", default=10, type=int)
parser.add_argument("--episode_length", default=100, type=int)
parser.add_argument("--fps", default=30, type=int)

config = parser.parse_args()

run(config)

0 comments on commit 88bb46f

Please sign in to comment.