Skip to content

Commit

Permalink
Now the benchmarker re-writes ALL the parameters at EACH row -> MEGA-…
Browse files Browse the repository at this point in the history
…VERBOSE!
  • Loading branch information
erwanlecarpentier committed Dec 7, 2018
1 parent b0e0ed8 commit e86a642
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions dyna_gym/utils/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,20 @@ def run(agent, env, tmax, verbose=False):
break
return cr

def benchmark(env_name, n_env, agent_name_pool, agent_pool, param_pool, n_epi, tmax, save=True, paths_pool=['log.csv'], verbose=True):
def benchmark(env_name, n_env, agent_name_pool, agent_pool, param_pool, param_names_pool, n_epi, tmax, save=True, paths_pool=['log.csv'], verbose=True):
"""
Benchmark a single agent within an environment.
env_name : name of the generated environment
n_env : number of generated environment
agent_name_pool : list containing the names of the agents for saving purpose
agent_pool : list containing the agent objects
param_pool : list containing lists of parameters for each agent object
n_epi : number of episodes per generated environment
tmax : timeout for each episode
save : save the results or not
paths_pool : list containing the saving path for each agent
verbose : if true, display informations during benchmark
env_name : name of the generated environment
n_env : number of generated environments
agent_name_pool : list containing the names of the agents for saving purpose
agent_pool : list containing the agent objects
param_pool : list containing lists of parameters for each agent object
param_names_pool : list containing the parameters names
n_epi : number of episodes per generated environment
tmax : timeout for each episode
save : save the results or not
paths_pool : list containing the saving path for each agent
verbose : if true, display information during benchmark
"""
assert len(agent_name_pool) == len(agent_pool) == len(param_pool)
n_agents = len(param_pool)
Expand All @@ -64,10 +65,11 @@ def benchmark(env_name, n_env, agent_name_pool, agent_pool, param_pool, n_epi, t
if verbose:
print('Created environment', i+1, '/', n_env)
for j in range(n_agents):
if save:
csv_write(['env_name', 'env_number', 'agent_name', 'agent_number', 'epi_number', 'score'], paths_pool[j], 'w')
agent = agent_pool[j]
n_agents_j = len(param_pool[j])
param_names_j = param_names_pool[j]
if save:
csv_write(['env_name', 'env_number', 'agent_name', 'agent_number'] + param_names_j + ['epi_number', 'score'], paths_pool[j], 'w')
for k in range(n_agents_j):
agent.reset(param_pool[j][k])
if verbose:
Expand All @@ -79,7 +81,7 @@ def benchmark(env_name, n_env, agent_name_pool, agent_pool, param_pool, n_epi, t
env.reset()
score = run(agent, env, tmax)
if save:
csv_write([env_name, i, agent_name_pool[j], k, l, score], paths_pool[j], 'a')
csv_write([env_name, i, agent_name_pool[j], k] + param_pool[j][k] + [l, score], paths_pool[j], 'a')

def test():
"""
Expand All @@ -92,10 +94,14 @@ def test():

agent_name_pool = ['UCT','RANDOM']
agent_pool = [uct.UCT(env.action_space), ra.MyRandomAgent(env.action_space)]
param_names_pool = [
['action_space','rollouts','horizon','gamma','ucb_constant','is_model_dynamic'],
['action_space']
]
param_pool = [
[[env.action_space, 10, 100, 0.9, 6.36396103068, True],[env.action_space, 100, 100, 0.9, 6.36396103068, True]],
[[env.action_space],[env.action_space],[env.action_space]]
]
paths_pool = ['uct.csv','random.csv']

benchmark('NSFrozenLakeEnv-v0', nenv, agent_name_pool, agent_pool, param_pool, nepi, tmax, save=True, paths_pool=paths_pool, verbose=True)
benchmark('NSFrozenLakeEnv-v0', nenv, agent_name_pool, agent_pool, param_pool, param_names_pool, nepi, tmax, save=True, paths_pool=paths_pool, verbose=True)

0 comments on commit e86a642

Please sign in to comment.