[SGD] add share_cuda_visible_devices config flag (ray-project#18958)
matthewdeng authored Sep 29, 2021
1 parent 505aa89 commit 91a5f67
Showing 7 changed files with 39 additions and 6 deletions.
26 changes: 22 additions & 4 deletions python/ray/util/sgd/v2/backends/backend.py
@@ -12,7 +12,7 @@
 from ray.util.sgd.v2.checkpoint import CheckpointStrategy
 from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \
     TUNE_INSTALLED, TUNE_CHECKPOINT_FILE_NAME, \
-    TUNE_CHECKPOINT_ID
+    TUNE_CHECKPOINT_ID, ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV
 from ray.util.sgd.v2.session import TrainingResultType, TrainingResult
 from ray.util.sgd.v2.session import init_session, get_session, shutdown_session
 from ray.util.sgd.v2.utils import construct_path, check_for_failure
@@ -275,15 +275,21 @@ def start(self,
             if initialization_hook:
                 self._initialization_hook = initialization_hook
                 self.worker_group.execute(initialization_hook)
-            if self._num_gpus_per_worker > 0:
-                self._setup_gpus()
+
+            share_cuda_visible_devices_enabled = bool(
+                env_integer(ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
+                            self._backend.share_cuda_visible_devices))
+
+            if (self._num_gpus_per_worker > 0
+                    and share_cuda_visible_devices_enabled):
+                self._share_cuda_visible_devices()
             self._backend.on_start(self.worker_group, self._backend_config)
         except RayActorError as exc:
             logger.exception(str(exc))
             self._increment_failures()
             self._restart()
 
-    def _setup_gpus(self):
+    def _share_cuda_visible_devices(self):
         """Sets CUDA_VISIBLE_DEVICES on all workers.
 
         For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs
@@ -685,6 +691,18 @@ def _increment_failures(self):
 
 
 class Backend(metaclass=abc.ABCMeta):
+    """Metaclass for distributed communication backend.
+
+    Attributes:
+        share_cuda_visible_devices (bool): If True, each worker
+            process will have CUDA_VISIBLE_DEVICES set as the visible device
+            IDs of all workers on the same node for this training instance.
+            If False, each worker will have CUDA_VISIBLE_DEVICES set to the
+            device IDs allocated by Ray for that worker.
+    """
+
+    share_cuda_visible_devices: bool = False
+
     def on_start(self, worker_group: WorkerGroup,
                  backend_config: BackendConfig):
         """Logic for starting this backend."""
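A note on the enablement check added to start() above: the environment variable takes precedence over the backend's class-level default, and device sharing only applies when workers are actually assigned GPUs. A minimal standalone sketch of that resolution, with env_integer re-implemented here under the assumption that it parses the variable as an int and falls back to the default when unset:

import os

# Hypothetical, self-contained re-implementation of the check in
# BackendExecutor.start(); names mirror the diff, but this is a sketch.
def env_integer(key, default):
    # Assumed behavior: int(env value) if the variable is set, else default.
    return int(os.environ[key]) if key in os.environ else default

def share_cuda_visible_devices_enabled(backend_default, num_gpus_per_worker):
    enabled = bool(
        env_integer("SGD_ENABLE_SHARE_CUDA_VISIBLE_DEVICES", backend_default))
    # Sharing is only relevant when each worker holds at least one GPU.
    return num_gpus_per_worker > 0 and enabled

With a backend that defaults to True (e.g. torch, below) and one GPU per worker, this resolves to True unless the variable is set to "0".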
2 changes: 2 additions & 0 deletions python/ray/util/sgd/v2/backends/horovod.py
@@ -52,6 +52,8 @@ def init_env_vars(world_rank: int, world_size: int, node_id: str):
 
 
 class HorovodBackend(Backend):
+    share_cuda_visible_devices: bool = True
+
     def on_start(self, worker_group: WorkerGroup,
                  backend_config: HorovodConfig):
 
2 changes: 2 additions & 0 deletions python/ray/util/sgd/v2/backends/torch.py
@@ -92,6 +92,8 @@ def shutdown_torch(destroy_process_group=False):
 
 
 class TorchBackend(Backend):
+    share_cuda_visible_devices: bool = True
+
    def on_start(self, worker_group: WorkerGroup, backend_config: TorchConfig):
         if len(worker_group) > 1 and dist.is_available():
             # Set the appropriate training backend.
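Both built-in GPU-aware backends opt in by overriding the new class attribute, as the two diffs above show. A third-party backend could do the same; the sketch below is hypothetical (MyBackend is not part of this commit) and assumes Backend's hook methods have no-op default implementations:

from ray.util.sgd.v2.backends.backend import Backend, BackendConfig
from ray.util.sgd.v2.worker_group import WorkerGroup

class MyBackend(Backend):
    # Let every worker on a node see all of that node's assigned GPU IDs,
    # as torch and horovod now do by default.
    share_cuda_visible_devices: bool = True

    def on_start(self, worker_group: WorkerGroup,
                 backend_config: BackendConfig):
        pass  # backend-specific process-group setup would go here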
4 changes: 4 additions & 0 deletions python/ray/util/sgd/v2/constants.py
@@ -44,3 +44,7 @@
 # This needs to be added to the checkpoint dictionary so if the Tune trial
 # is restarted, the checkpoint_id can continue to increment.
 TUNE_CHECKPOINT_ID = "_current_checkpoint_id"
+
+# Integer value which if set will override the value of
+# Backend.share_cuda_visible_devices. 1 for True, 0 for False.
+ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV = "SGD_ENABLE_SHARE_CUDA_VISIBLE_DEVICES"
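Since this environment variable overrides Backend.share_cuda_visible_devices, the behavior can be flipped without subclassing a backend. A usage sketch under that assumption (the training function is a placeholder):

import os
from ray.util.sgd.v2 import Trainer

# "0" forces sharing off even for backends that default to True
# (torch, horovod); "1" forces it on. Set before the executor starts.
os.environ["SGD_ENABLE_SHARE_CUDA_VISIBLE_DEVICES"] = "0"

def train_func():
    pass  # placeholder training loop

trainer = Trainer(backend="torch", num_workers=2, use_gpu=True)
trainer.start()
trainer.run(train_func)
trainer.shutdown()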
4 changes: 2 additions & 2 deletions python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py
@@ -72,7 +72,7 @@ def train_func(config):
     return results
 
 
-def train_tensorflow_mnist(num_workers=1, use_gpu=False):
+def train_tensorflow_mnist(num_workers=2, use_gpu=False):
     trainer = Trainer(
         backend="tensorflow", num_workers=num_workers, use_gpu=use_gpu)
     trainer.start()
@@ -98,7 +98,7 @@ def train_tensorflow_mnist(num_workers=1, use_gpu=False):
         "--num-workers",
         "-n",
         type=int,
-        default=1,
+        default=2,
         help="Sets number of workers for training.")
     parser.add_argument(
         "--use-gpu",
4 changes: 4 additions & 0 deletions python/ray/util/sgd/v2/tests/test_backend.py
@@ -8,6 +8,7 @@
 from ray.util.sgd import v2 as sgd
 from ray.util.sgd.v2.backends.backend import BackendConfig, BackendExecutor
 from ray.util.sgd.v2.backends.tensorflow import TensorflowConfig
+from ray.util.sgd.v2.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV
 from ray.util.sgd.v2.worker_group import WorkerGroup
 from ray.util.sgd.v2.backends.torch import TorchConfig
 
@@ -321,6 +322,7 @@ def get_resources():
 
     num_workers, expected_results = worker_results
 
+    os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1"
     e = BackendExecutor(
         config,
         num_workers=num_workers,
@@ -349,6 +351,7 @@ def get_resources():
 
     num_workers, expected_results = worker_results
 
+    os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1"
     e = BackendExecutor(
         config,
         num_workers=num_workers,
@@ -374,6 +377,7 @@ def get_resources():
 
     num_workers, expected_results = worker_results
 
+    os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1"
     e = BackendExecutor(
         config,
         num_workers=num_workers,
3 changes: 3 additions & 0 deletions python/ray/util/sgd/v2/tests/test_trainer.py
@@ -15,6 +15,7 @@
 from ray.util.sgd.v2.backends.backend import BackendConfig, Backend, \
     BackendExecutor
 from ray.util.sgd.v2.callbacks.callback import SGDCallback
+from ray.util.sgd.v2.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV
 from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \
     tensorflow_mnist_train_func
 from ray.util.sgd.v2.examples.train_fashion_mnist_example import train_func \
@@ -1006,6 +1007,8 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra):
     def get_resources():
         return os.environ["CUDA_VISIBLE_DEVICES"]
 
+    os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1"
+
     # 0 GPUs will be requested and should not raise an error.
     trainer = Trainer(TestConfig(), num_workers=2, use_gpu=False)
     trainer.start()
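For intuition about what these tests assert: suppose two workers with one GPU each land on the same node and Ray assigns them device IDs 0 and 1 (hypothetical placement). Per the Backend docstring added in this commit, the probe the tests run on each worker would observe:

import os

def get_resources():
    # The same probe the tests execute on each worker.
    return os.environ["CUDA_VISIBLE_DEVICES"]

# Sharing enabled (attribute or env var resolves to True):
#   worker 0 -> "0,1"    worker 1 -> "0,1"
# Sharing disabled:
#   worker 0 -> "0"      worker 1 -> "1"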
