Skip to content

Commit

Permalink
[RLlib] Replace ordinary pygame imports by try_import_..(). (ray-pr…
Browse files Browse the repository at this point in the history
  • Loading branch information
ArturNiederfahrenhorst authored Jan 17, 2023
1 parent 3f5a5c7 commit 5e66cbf
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 14 deletions.
5 changes: 4 additions & 1 deletion rllib/algorithms/alpha_star/tests/test_alpha_star.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pyspiel
import unittest

import ray
import ray.rllib.algorithms.alpha_star as alpha_star
from ray.rllib.env.utils import try_import_pyspiel, try_import_open_spiel
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.utils.test_utils import (
check_compute_single_action,
Expand All @@ -11,6 +11,9 @@
)
from ray.tune import register_env

open_spiel = try_import_open_spiel(error=True)
pyspiel = try_import_pyspiel(error=True)

# Connect-4 OpenSpiel env.
register_env("connect_four", lambda _: OpenSpielEnv(pyspiel.load_game("connect_four")))

Expand Down
55 changes: 55 additions & 0 deletions rllib/env/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,61 @@

from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
from ray.util.annotations import PublicAPI


@PublicAPI
def try_import_pyspiel(error: bool = False):
"""Tries importing pyspiel and returns the module (or None).
Args:
error: Whether to raise an error if pyspiel cannot be imported.
Returns:
The pyspiel module.
Raises:
ImportError: If error=True and pyspiel is not installed.
"""
try:
import pyspiel

return pyspiel
except ImportError:
if error:
raise ImportError(
"Could not import pyspiel! Pygame is not a dependency of RLlib "
"and RLlib requires you to install pygame separately: "
"`pip install pygame`."
)
return None


@PublicAPI
def try_import_open_spiel(error: bool = False):
"""Tries importing open_spiel and returns the module (or None).
Args:
error: Whether to raise an error if open_spiel cannot be imported.
Returns:
The open_spiel module.
Raises:
ImportError: If error=True and open_spiel is not installed.
"""
try:
import open_spiel

return open_spiel
except ImportError:
if error:
raise ImportError(
"Could not import open_spiel! open_spiel is not a dependency of RLlib "
"and RLlib requires you to install open_spiel separately: "
"`pip install open_spiel`."
)
return None


def _gym_env_creator(env_context: EnvContext, env_descriptor: str) -> gym.Env:
Expand Down
9 changes: 6 additions & 3 deletions rllib/env/wrappers/open_spiel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from gymnasium.spaces import Box, Discrete
import numpy as np
import pyspiel
from typing import Optional

import numpy as np
from gymnasium.spaces import Box, Discrete

from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.utils import try_import_pyspiel

pyspiel = try_import_pyspiel(error=True)


class OpenSpielEnv(MultiAgentEnv):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@

import argparse
from pathlib import Path
import pyspiel

import ray
from ray import air, tune
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.env.utils import try_import_pyspiel
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.policy.policy import Policy
from ray.tune import CLIReporter, register_env

pyspiel = try_import_pyspiel(error=True)

parser = argparse.ArgumentParser()
# This should point to a checkpointed policy that plays connect_four.
# Note that this policy may be trained with different algorithms than
Expand Down
16 changes: 11 additions & 5 deletions rllib/examples/self_play_league_based_with_open_spiel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,28 @@
"""

import argparse
import numpy as np
import os
from open_spiel.python.rl_environment import Environment
import pyspiel
import re

import numpy as np

import ray
from ray import air, tune
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.self_play_with_open_spiel import ask_user_for_action
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.env.utils import try_import_pyspiel, try_import_open_spiel
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.examples.self_play_with_open_spiel import ask_user_for_action
from ray.rllib.policy.policy import PolicySpec
from ray.tune import register_env

open_spiel = try_import_open_spiel(error=True)
pyspiel = try_import_pyspiel(error=True)

# Import after try_import_open_spiel, so we can error out with hints
from open_spiel.python.rl_environment import Environment # noqa: E402


def get_cli_args():
"""Create CLI parser and return parsed arguments"""
Expand Down
14 changes: 10 additions & 4 deletions rllib/examples/self_play_with_open_spiel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,27 @@
"""

import argparse
import numpy as np
import os
import pyspiel
from open_spiel.python.rl_environment import Environment
import sys

import numpy as np

import ray
from ray import air, tune
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.env.utils import try_import_pyspiel, try_import_open_spiel
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import PolicySpec
from ray.tune import CLIReporter, register_env

open_spiel = try_import_open_spiel(error=True)
pyspiel = try_import_pyspiel(error=True)

# Import after try_import_open_spiel, so we can error out with hints
from open_spiel.python.rl_environment import Environment # noqa: E402


def get_cli_args():
"""Create CLI parser and return parsed arguments"""
Expand Down

0 comments on commit 5e66cbf

Please sign in to comment.