Allow setting of kwargs in init_torch_dist_process_group (ray-project#42099)

This PR allows for more configurability in the `init_torch_dist_process_group` function by enabling user-defined kwargs to be passed through to `torch.distributed.init_process_group`. Crucially, this lets the user specify the `timeout` argument.
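For example, a caller can now forward a custom `timeout` (a minimal sketch; preparing the Ray actor handles in `workers` is assumed and not shown here):

```python
from datetime import timedelta

from ray.air.util.torch_dist import init_torch_dist_process_group

# `workers` is assumed to be a list of Ray actor handles prepared elsewhere;
# creating them is outside the scope of this sketch.
local_ranks = init_torch_dist_process_group(
    workers,
    backend="nccl",
    init_method="env",
    # Any extra kwargs are forwarded to torch.distributed.init_process_group,
    # so the collective timeout can now be tuned per job.
    timeout=timedelta(seconds=300),
)
```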

---------

Signed-off-by: Antoni Baum <[email protected]>
Co-authored-by: Justin Yu <[email protected]>
Yard1 and justinvyu authored Dec 27, 2023
1 parent d8222b2 commit 394d363
Showing 1 changed file with 15 additions and 6 deletions.
python/ray/air/util/torch_dist.py: 21 changes (15 additions, 6 deletions)
```diff
@@ -47,6 +47,7 @@ def _init_torch_distributed(
     master_addr: str,
     master_port: str,
     gpu_ids: List[int],
+    **init_process_group_kwargs,
 ):
     """Initialize torch distributed backend"""
     if init_method == "env":
@@ -71,13 +72,17 @@ def _init_torch_distributed(
     if "NCCL_SOCKET_IFNAME" not in os.environ:
         os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_NCCL_SOCKET_IFNAME
 
-    dist.init_process_group(
-        backend=backend,
-        init_method=url,
-        rank=rank,
-        world_size=world_size,
-        timeout=timedelta(seconds=1800),
+    init_process_group_kwargs.update(
+        dict(
+            backend=backend,
+            init_method=url,
+            rank=rank,
+            world_size=world_size,
+        )
     )
+    init_process_group_kwargs.setdefault("timeout", timedelta(seconds=1800))
+
+    dist.init_process_group(**init_process_group_kwargs)
 
     os.environ["RANK"] = str(rank)
     os.environ["LOCAL_RANK"] = str(local_rank)
@@ -96,6 +101,7 @@ def init_torch_dist_process_group(
     workers: List[ActorHandle],
     backend: str = "gloo",
     init_method: str = "env",
+    **init_process_group_kwargs,
 ) -> List[int]:
     """Initialize a torch distributed process group.
 
@@ -108,6 +114,8 @@ def init_torch_dist_process_group(
             possible choices are "gloo" or "nccl".
         init_method: The initialization method to use,
            possible choices are "env" or "tcp".
+        init_process_group_kwargs: Additional kwargs to pass to the call to
+            :meth:`torch.distributed.init_process_group`.
 
     Returns:
         Local ranks on their respective nodes for the list of workers.
@@ -156,6 +164,7 @@ def init_torch_dist_process_group(
                 # list(set) will sort the gpu ids, so VISIBLE_CUDA_DEVICES
                 # is always sorted.
                 gpu_ids=list(node_to_gpu_ids[node_id]),
+                **init_process_group_kwargs,
             )
         )
         local_ranks.append(local_rank)
```
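One detail worth noting from the diff above: `backend`, `init_method`, `rank`, and `world_size` are written into `init_process_group_kwargs` via `.update()` inside the helper, so user-supplied values for those keys are overwritten, while `timeout` is only filled in as a default via `setdefault()`, so a user-provided timeout wins. A standalone sketch of that merge order (the literal values here are made up for illustration):

```python
from datetime import timedelta

# User-supplied kwargs, as collected by init_torch_dist_process_group.
init_process_group_kwargs = {"timeout": timedelta(seconds=300)}

# The helper then overwrites the reserved keys it manages itself...
init_process_group_kwargs.update(
    dict(backend="nccl", init_method="tcp://10.0.0.1:29500", rank=0, world_size=2)
)
# ...and only applies the 1800-second default when the user did not set a timeout.
init_process_group_kwargs.setdefault("timeout", timedelta(seconds=1800))

assert init_process_group_kwargs["timeout"] == timedelta(seconds=300)
```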
