adding the score for the icaps 2021 as part of the utils functions
BDonnot committed Jun 21, 2021
1 parent 479c4d8 commit d5b7b58
Showing 4 changed files with 146 additions and 225 deletions.
46 changes: 44 additions & 2 deletions grid2op/tests/test_utils.py
@@ -14,7 +14,7 @@
from grid2op.MakeEnv import make
from grid2op.dtypes import dt_float
from grid2op.Agent import DoNothingAgent, RecoPowerlineAgent
from grid2op.utils import EpisodeStatistics, ScoreL2RPN2020
from grid2op.utils import EpisodeStatistics, ScoreL2RPN2020, ScoreICAPS2021
from grid2op.Parameters import Parameters

import warnings
@@ -124,7 +124,6 @@ def test_can_compute(self):
assert not os.path.exists(os.path.join(env.get_path_env(),
EpisodeStatistics.get_name_dir(ScoreL2RPN2020.NAME_RP_NO_OVERWLOW)))


def test_donothing_0(self):
"""test that do nothing has a score of 0.00"""
with warnings.catch_warnings():
@@ -321,5 +320,48 @@ def test_reco_noov_80(self):
EpisodeStatistics.get_name_dir(ScoreL2RPN2020.NAME_RP_NO_OVERWLOW)))


class TestICAPSSCORE(HelperTests):
"""test teh grid2op.utils.EpisodeStatistics """
def test_can_compute(self):
"""test that i can initialize the score and then delete the statistics"""
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
with make(os.path.join(PATH_DATA_TEST, "l2rpn_neurips_2020_track1_with_alert"), test=True) as env:
scores = ScoreICAPS2021(env,
nb_scenario=2,
verbose=0,
max_step=50,
env_seeds=[1, 2], # with these seeds do nothing goes till the end
agent_seeds=[3, 4])
my_agent = DoNothingAgent(env.action_space)
scores, n_played, total_ts = scores.get(my_agent)
for (ep_score, op_score, alarm_score) in scores:
assert np.abs(ep_score - 30.) <= self.tol_one, f"wrong score for the episode: {ep_score} vs 30."
assert np.abs(op_score - 0.) <= self.tol_one, f"wrong score for the operational cost: " \
f"{op_score} vs 0."
assert np.abs(alarm_score - 100.) <= self.tol_one, f"wrong score for the alarm: " \
f"{alarm_score} vs 100."

# the statistics have been properly computed
assert os.path.exists(os.path.join(env.get_path_env(),
EpisodeStatistics.get_name_dir(ScoreICAPS2021.NAME_DN)))
assert os.path.exists(os.path.join(env.get_path_env(),
EpisodeStatistics.get_name_dir(ScoreICAPS2021.NAME_DN_NO_OVERWLOW)))

# delete them
stats_0 = EpisodeStatistics(env, ScoreICAPS2021.NAME_DN)
stats_1 = EpisodeStatistics(env, ScoreICAPS2021.NAME_DN_NO_OVERWLOW)
stats_2 = EpisodeStatistics(env, ScoreICAPS2021.NAME_RP_NO_OVERWLOW)
stats_0.clear_all()
stats_1.clear_all()
stats_2.clear_all()
assert not os.path.exists(os.path.join(env.get_path_env(),
EpisodeStatistics.get_name_dir(ScoreICAPS2021.NAME_DN)))
assert not os.path.exists(os.path.join(env.get_path_env(),
EpisodeStatistics.get_name_dir(ScoreICAPS2021.NAME_DN_NO_OVERWLOW)))
assert not os.path.exists(os.path.join(env.get_path_env(),
EpisodeStatistics.get_name_dir(ScoreICAPS2021.NAME_RP_NO_OVERWLOW)))


if __name__ == "__main__":
unittest.main()
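For reference, a minimal usage sketch of the new scorer, distilled from the test above; the dataset name, seeds and max_step are illustrative, and the environment may need to be replaced by the path of a local copy of an alarm-enabled dataset:

import grid2op
from grid2op.Agent import DoNothingAgent
from grid2op.utils import ScoreICAPS2021

# "l2rpn_neurips_2020_track1_with_alert" is the test dataset used above; adjust
# it to a locally available environment that supports the alarm feature.
env = grid2op.make("l2rpn_neurips_2020_track1_with_alert", test=True)
scorer = ScoreICAPS2021(env,
                        nb_scenario=2,
                        max_step=50,
                        env_seeds=[1, 2],
                        agent_seeds=[3, 4])
scores, n_played, total_ts = scorer.get(DoNothingAgent(env.action_space))
for ep_score, op_score, alarm_score in scores:
    print(f"episode: {ep_score:.2f}, operational: {op_score:.2f}, alarm: {alarm_score:.2f}")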
257 changes: 53 additions & 204 deletions grid2op/utils/icaps_2021_scores.py
@@ -13,7 +13,7 @@
import tempfile

from grid2op.dtypes import dt_float
from grid2op.Reward import L2RPNSandBoxScore, AlarmReward
from grid2op.Reward import L2RPNSandBoxScore, _AlarmScore
from grid2op.utils.underlying_statistics import EpisodeStatistics
from grid2op.utils.l2rpn_2020_scores import ScoreL2RPN2020
from grid2op.Episode import EpisodeData
@@ -56,9 +56,9 @@ class ScoreICAPS2021(ScoreL2RPN2020):
"""

NAME_DN = "icaps_dn"
NAME_DN_NO_OVERWLOW = "icaps_no_overflow"
NAME_RP_NO_OVERWLOW = "icaps_no_overflow_reco"
NAME_DN = "icaps2021_dn"
NAME_DN_NO_OVERWLOW = "icaps2021_no_overflow"
NAME_RP_NO_OVERWLOW = "icaps2021_no_overflow_reco"

def __init__(self,
env,
@@ -68,224 +68,73 @@ def __init__(self,
min_losses_ratio=0.8,
verbose=0,
max_step=-1,
nb_process_stats=1):
ScoreL2RPN2020.__init__(self,
env,
env_seeds,
agent_seeds,
nb_scenario,
min_losses_ratio,
verbose, max_step,
nb_process_stats)
self.scores_func = L2RPNSandBoxScore

def _init_stat(self, stat, stat_name, computed_scenarios, parameters=None, nb_process_stats=1, agent=None):
"""will check if the statistics need to be computed"""
need_recompute = True
if EpisodeStatistics.get_name_dir(stat_name) in computed_scenarios:
# the statistics have been computed; I check whether the number of scenarios is big enough
scores, ids_ = stat.get(EpisodeStatistics.SCORES)
metadata = stat.get_metadata()
max_id = np.max(ids_)

# I need to recompute if I did not compute enough scenarios
need_recompute = max_id < self.nb_scenario - 1

# if max
computed_step = int(metadata["max_step"])
if computed_step > 0:
# if I have computed the data with a max_step limit
if self.max_step == -1:
# I now need to compute the whole dataset, so yes I have to recompute it
need_recompute = True

# I need to recompute only if I ask for more steps than what was computed
need_recompute = need_recompute or self.max_step > metadata["max_step"]
nb_process_stats=1,
scale_alarm_score=100.,
weight_op_score=0.7,
weight_alarm_score=0.3,
):

# TODO check for the seeds here too
# TODO and check for the class of the scores too
# TODO check for the parameters too...

if need_recompute:
# i need to compute it
if self.verbose >= 1:
print("I need to recompute the statistics for this environment. This will take a while") # TODO logger
stat.compute(nb_scenario=self.nb_scenario,
pbar=self.verbose >= 2,
env_seeds=self.env_seeds,
agent_seeds=self.agent_seeds,
scores_func=L2RPNSandBoxScore,
max_step=self.max_step,
parameters=parameters,
nb_process=nb_process_stats)
stat.clear_episode_data()
return need_recompute
ScoreL2RPN2020.__init__(self,
env=env,
env_seeds=env_seeds,
agent_seeds=agent_seeds,
nb_scenario=nb_scenario,
min_losses_ratio=min_losses_ratio,
verbose=verbose,
max_step=max_step,
nb_process_stats=nb_process_stats,
scores_func={"grid_operational_cost": L2RPNSandBoxScore,
"alarm_cost": _AlarmScore},
score_names=["grid_operational_cost_scores", "alarm_cost_scores"])
self.scale_alarm_score = scale_alarm_score
self.weight_op_score = weight_op_score
self.weight_alarm_score = weight_alarm_score
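# Note on the `scores_func` / `score_names` pair above: each entry of `scores_func`
# is a reward class that the underlying EpisodeStatistics runs compute as an
# "other reward" on the reference scenarios, and `score_names` gives the names under
# which the corresponding per-step values are stored; `_compute_episode_score` below
# relies on these names to retrieve the operational cost and the alarm score separately.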

def _compute_episode_score(self,
ep_id,  # the ID here, which is an integer and is not the ID from the chronics
meta,
other_rewards,
dn_metadata,
no_ov_metadata):
no_ov_metadata,
score_file_to_use=None):
"""
Performs the rescaling of the score given the information stored in the "statistics" of this
environment.
"""
load_p, ids = self.stat_no_overflow.get("load_p")
prod_p, _ = self.stat_no_overflow.get("prod_p")

scores_dn, ids_dn_sc = self.stat_dn.get(EpisodeStatistics.SCORES)
scores_no_ov, ids_noov_sc = self.stat_no_overflow.get(EpisodeStatistics.SCORES)

# reshape to have 1 dim array
ids = ids.reshape(-1)
ids_dn_sc = ids_dn_sc.reshape(-1)
ids_noov_sc = ids_noov_sc.reshape(-1)

# there is an ugly "1" at the end of each score, due to the "game over" (end of the episode), so I remove it
scores_dn = scores_dn[ids_dn_sc == ep_id][:-1]
scores_no_ov = scores_no_ov[ids_noov_sc == ep_id][:-1]

dn_this = dn_metadata[f"{ep_id}"]
no_ov_this = no_ov_metadata[f"{ep_id}"]

n_played = int(meta["nb_timestep_played"])
dn_step_played = dn_this["nb_step"] - 1
total_ts = no_ov_this["nb_step"] - 1

ep_marginal_cost = np.max(self.env.gen_cost_per_MW).astype(dt_float)
min_losses_ratio = self.min_losses_ratio

# remember that the first observation does not count (it's generated by the environment)
ep_loads = np.sum(load_p[ids == ep_id, :], axis=1)[1:]
ep_losses = np.sum(prod_p[ids == ep_id, :], axis=1)[1:] - ep_loads

if self.max_step > 0:
scores_dn = scores_dn[:self.max_step]
scores_no_ov = scores_no_ov[:self.max_step]
ep_loads = ep_loads[:self.max_step]
ep_losses = ep_losses[:self.max_step]

# do nothing operational cost
ep_do_nothing_operat_cost = np.sum(scores_dn)
ep_do_nothing_operat_cost += np.sum(ep_loads[dn_step_played:]) * ep_marginal_cost

# no overflow disconnection cost
ep_do_nothing_nodisc_cost = np.sum(scores_no_ov)

# this agent's cumulated operational cost
# same as above: I remove the last element, which corresponds to the last state and is therefore irrelevant
ep_cost = np.array([el[EpisodeStatistics.KEY_SCORE] for el in other_rewards]).astype(dt_float)
if dn_metadata["max_step"] == self.max_step:
ep_cost = ep_cost[:-1]
ep_cost = np.sum(ep_cost)
ep_cost += np.sum(ep_loads[n_played:]) * ep_marginal_cost

# Compute ranges
worst_operat_cost = np.sum(ep_loads) * ep_marginal_cost # operational cost corresponding to the min score
zero_operat_score = ep_do_nothing_operat_cost
nodisc_oeprat_cost = ep_do_nothing_nodisc_cost
best_score = np.sum(ep_losses) * min_losses_ratio # operational cost corresponding to the max score

# Linear interp episode reward to codalab score
if zero_operat_score != nodisc_oeprat_cost:
# DoNothing agent doesn't complete the scenario
reward_range = [best_score, nodisc_oeprat_cost, zero_operat_score, worst_operat_cost]
score_range = [100.0, 80.0, 0.0, -100.0]
else:
# DoNothing agent can complete the scenario
reward_range = [best_score, zero_operat_score, worst_operat_cost]
score_range = [100.0, 0.0, -100.0]
ep_score = np.interp(ep_cost, reward_range, score_range)
return ep_score, n_played, total_ts

def get(self, agent, path_save=None, nb_process=1):
"""
Get the score of the agent depending on what has been computed.
TODO The plots will be done later.
Parameters
----------
agent: :class:`grid2op.Agent.BaseAgent`
The agent you want to score
path_save: ``str``
the path where you want to store the logs of your agent.
nb_process: ``int``
Number of process to use for the evaluation
Returns
-------
all_scores: ``list``
List of the score of your agent per scenarios
ts_survived: ``list``
List of the number of step your agent successfully managed for each scenario
total_ts: ``list``
Total number of step for each scenario
"""
if path_save is not None:
need_delete = False # TODO this is soooo dirty
path_save = os.path.abspath(path_save)
else:
need_delete = True
dir_tmp = tempfile.TemporaryDirectory()
path_save = dir_tmp.name

if self.verbose >= 1:
print("Starts the evaluation of the agent") # TODO logger
EpisodeStatistics.run_env(self.env,
env_seeds=self.env_seeds,
agent_seeds=self.agent_seeds,
path_save=path_save,
parameters=self.env.parameters,
scores_func=L2RPNSandBoxScore,
agent=agent,
max_step=self.max_step,
nb_scenario=self.nb_scenario,
pbar=self.verbose >= 2,
nb_process=nb_process,
)
if self.verbose >= 1:
print("Start the evaluation of the scores") # TODO logger

meta_data_dn = self.stat_dn.get_metadata()
no_ov_metadata = self.stat_no_overflow.get_metadata()

all_scores = []
ts_survived = []
total_ts = []
for ep_id in range(self.nb_scenario):
this_ep_nm = meta_data_dn[f"{ep_id}"]["scenario_name"]
with open(os.path.join(path_save, this_ep_nm, EpisodeData.META), "r", encoding="utf-8") as f:
this_epi_meta = json.load(f)
with open(os.path.join(path_save, this_ep_nm, EpisodeData.OTHER_REWARDS), "r", encoding="utf-8") as f:
this_epi_scores = json.load(f)
score_this_ep, nb_ts_survived, total_ts_tmp = \
self._compute_episode_score(ep_id,
meta=this_epi_meta,
other_rewards=this_epi_scores,
dn_metadata=meta_data_dn,
no_ov_metadata=no_ov_metadata)
all_scores.append(score_this_ep)
ts_survived.append(nb_ts_survived)
total_ts.append(total_ts_tmp)
if need_delete:
dir_tmp.cleanup()

return all_scores, ts_survived, total_ts
# compute the operational score
op_score, n_played, total_ts = super()._compute_episode_score(ep_id,
meta,
other_rewards,
dn_metadata,
no_ov_metadata,
# score_file_to_use should match the
# L2RPNSandBoxScore key in
# self.scores_func
score_file_to_use="grid_operational_cost_scores",
)
# should match what underlying_statistics.run_env stores in `dict_kwg["other_rewards"][XXX]`
# XXX is currently f"{EpisodeStatistics.KEY_SCORE}_{nm}", so the same key is rebuilt below
score_file_to_use = "alarm_cost_scores"
real_nm = EpisodeStatistics._nm_score_from_attr_name(score_file_to_use)
key_score_file = f"{EpisodeStatistics.KEY_SCORE}_{real_nm}"

alarm_score = float(other_rewards[-1][key_score_file])
alarm_score = self.scale_alarm_score * alarm_score

ep_score = self.weight_op_score * op_score + self.weight_alarm_score * alarm_score
return (ep_score, op_score, alarm_score), n_played, total_ts


if __name__ == "__main__":
import grid2op
from lightsim2grid import LightSimBackend
from grid2op.Agent import RandomAgent, DoNothingAgent
env = grid2op.make("l2rpn_case14_sandbox", backend=LightSimBackend())
nb_scenario = 16
my_score = ScoreL2RPN2020(env,
env = grid2op.make("/home/benjamin/Documents/grid2op_dev/grid2op/data_test/l2rpn_neurips_2020_track1_with_alert",
backend=LightSimBackend())
# env = grid2op.make("l2rpn_case14_sandbox", backend=LightSimBackend())
nb_scenario = 2
my_score = ScoreICAPS2021(env,
nb_scenario=nb_scenario,
env_seeds=[0 for _ in range(nb_scenario)],
agent_seeds=[0 for _ in range(nb_scenario)]
