Skip to content

Commit

Permalink
Remove more mujoco code
Browse files Browse the repository at this point in the history
  • Loading branch information
PavelCz committed Aug 9, 2022
1 parent fdee2a2 commit 99e6254
Showing 1 changed file with 1 addition and 72 deletions.
73 changes: 1 addition & 72 deletions src/dep/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,73 +201,6 @@ def eval_checkpoint(
}


def eval_from_mujoco_state(
mujoco_state_path: str,
agent_alg: str = "ppo",
render: bool = False,
num_steps: int = 10000,
wandb_project: str = "pbrl-defense",
wandb_group: str = "dev",
env_name="mpe",
scenario_name="simple_push",
artifact_dir: Union[str, Path] = Path("/scratch/pavel/out/tmp"),
video_dir=None,
local_mode: bool = False,
) -> dict:
artifact_dir = Path(artifact_dir)
wandb.init(project=wandb_project, group=wandb_group, dir=artifact_dir)

ray.init(local_mode=local_mode)

trainer_class = trainer_cls_from_str(agent_alg)

mujoco_state_file = Path(mujoco_state_path).open(mode="rb")
mujoco_state = pickle.load(mujoco_state_file)

_ = init_env(
env_name=env_name,
scenario_name=scenario_name,
max_steps=25,
mujoco_state=mujoco_state,
)

update_config = _update_config_for_eval(
{
"env_config": {"scenario_name": scenario_name},
"num_workers": 0,
"evaluation_num_workers": 0,
"custom_eval_function": None,
"evaluation_duration": 1,
},
scenario_name,
)

trainer = trainer_class(
env="current-env", config=update_config, logger_creator=noop_logger_creator
)
# trainer = restore_trainer_from_path(file, scenario_name, trainer_class, config_update=update_config)

# Now the trainer object should have the agent0 weights in policy0 and agent1 weights in policy1
saver = InMemoryRolloutSaver()

# Suppress command line output because rollout() is a bit too verbose for my taste
with contextlib.redirect_stdout(None):
rollout(
trainer,
None, # Parameter is unused
num_steps=num_steps,
saver=saver,
no_render=not render,
video_dir=video_dir,
)

ray.shutdown()

return {
"mean_rewards": saver.mean_rewards,
}


def _update_config_for_eval(config, scenario_name):
"""Updates the config for evaluation. Code from rllib/rollout.py"""

Expand Down Expand Up @@ -299,9 +232,5 @@ def _update_config_for_eval(config, scenario_name):

if __name__ == "__main__":
fire.Fire(
{
"checkpoint": eval_checkpoint,
"mujoco-state": eval_from_mujoco_state,
"multi": multi_eval,
}
{"checkpoint": eval_checkpoint, "multi": multi_eval,}
)

0 comments on commit 99e6254

Please sign in to comment.