Add utils for distributed init methods
Summary: Add support for TCP and shared file system initialization schemes: https://pytorch.org/docs/stable/distributed.html#initialization
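
A minimal usage sketch (illustrative only; the "gloo" backend and the single-process world_size=1 / rank=0 values are assumptions, not part of this change):

import torch.distributed as dist
from torchtnt.utils.distributed import get_file_init_method, get_tcp_init_method

# TCP rendezvous on localhost with an automatically selected free port.
init_method = get_tcp_init_method(world_size=1, rank=0)
dist.init_process_group(backend="gloo", init_method=init_method)
dist.destroy_process_group()

# Shared file system rendezvous using an auto-generated temporary file name.
init_method = get_file_init_method(world_size=1, rank=0)
dist.init_process_group(backend="gloo", init_method=init_method)
dist.destroy_process_group()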

Reviewed By: daniellepintz

Differential Revision: D46665222

fbshipit-source-id: 37ef615812efbee0fdb7d495c9cabefe5d4f96c6
ananthsub authored and facebook-github-bot committed Jun 13, 2023
1 parent b4c97c2 commit 9b3b7b1
Showing 3 changed files with 251 additions and 2 deletions.
152 changes: 152 additions & 0 deletions tests/utils/test_distributed.py
@@ -9,17 +9,22 @@
import unittest
from typing import Optional
from unittest.mock import patch
from urllib.parse import parse_qs, urlparse

import torch
import torch.distributed as dist
import torch.distributed.launcher as launcher
from pyre_extensions import none_throws
from torchtnt.utils.device import get_device_from_env
from torchtnt.utils.distributed import (
_validate_global_rank_world_size,
all_gather_tensors,
destroy_process_group,
get_file_init_method,
get_global_rank,
get_local_rank,
get_process_group_backend_from_device,
get_tcp_init_method,
get_world_size,
rank_zero_fn,
revert_sync_batchnorm,
@@ -273,6 +278,9 @@ def test_sync_bool_multi_process_coherence_mode_int_false(self) -> None:
self.assertFalse(result[0])
self.assertFalse(result[1])

@unittest.skipUnless(
torch.distributed.is_available(), reason="Torch distributed is needed to run"
)
def test_sync_bool_multi_process_coherence_mode_int_true(self) -> None:
config = get_pet_launch_config(2)
result = launcher.elastic_launch(config, entrypoint=self._full_sync_worker)(1)
@@ -290,9 +298,153 @@ def test_sync_bool_multi_process_coherence_mode_float_true(self) -> None:
self.assertTrue(result[0])
self.assertTrue(result[1])

@unittest.skipUnless(
torch.distributed.is_available(), reason="Torch distributed is needed to run"
)
def test_sync_bool_multi_process_coherence_mode_float_false(self) -> None:
config = get_pet_launch_config(2)
result = launcher.elastic_launch(config, entrypoint=self._full_sync_worker)(1.0)
# Both processes should return False since not all (100%) of the processes input True
self.assertFalse(result[0])
self.assertFalse(result[1])

def test_validate_global_rank_world_size(self) -> None:
with self.assertRaisesRegex(ValueError, "Invalid world_size value provided"):
world_size = -1
rank = 0
_validate_global_rank_world_size(world_size=world_size, rank=rank)

with self.assertRaisesRegex(ValueError, "Invalid rank value provided"):
world_size = 2
rank = -1
_validate_global_rank_world_size(world_size=world_size, rank=rank)

with self.assertRaisesRegex(
ValueError, "Invalid rank and world_size values provided"
):
world_size = 8
rank = 8
_validate_global_rank_world_size(world_size=world_size, rank=rank)

def test_get_file_init_method(self) -> None:
world_size = 10
rank = 2
my_filename = "/tmp/my_filename"
init_method = get_file_init_method(
world_size=world_size, rank=rank, filename=my_filename
)
url = urlparse(init_method)
self.assertEqual(url.scheme, "file")
self.assertEqual(url.netloc, "")
self.assertEqual(url.path, my_filename)
url_qs = parse_qs(url.query)
self.assertIn("world_size", url_qs)
self.assertEqual(url_qs["world_size"], [str(world_size)])
self.assertIn("rank", url_qs)
self.assertEqual(url_qs["rank"], [str(rank)])

world_size = 2
rank = 0
# get temp filename
init_method = get_file_init_method(
world_size=world_size, rank=rank, filename=None
)
url = urlparse(init_method)
self.assertEqual(url.scheme, "file")
self.assertEqual(url.netloc, "")
self.assertNotEqual(url.path, "")
self.assertFalse(os.path.exists(url.path))
url_qs = parse_qs(url.query)
self.assertIn("world_size", url_qs)
self.assertEqual(url_qs["world_size"], [str(world_size)])
self.assertIn("rank", url_qs)
self.assertEqual(url_qs["rank"], [str(rank)])

world_size = 1
rank = 0
# get temp filename (default)
init_method = get_file_init_method(world_size=world_size, rank=rank)
url = urlparse(init_method)
self.assertEqual(url.scheme, "file")
self.assertEqual(url.netloc, "")
self.assertNotEqual(url.path, "")
self.assertFalse(os.path.exists(url.path))
url_qs = parse_qs(url.query)
self.assertIn("world_size", url_qs)
self.assertEqual(url_qs["world_size"], [str(world_size)])
self.assertIn("rank", url_qs)
self.assertEqual(url_qs["rank"], [str(rank)])

def test_get_tcp_init_method(self) -> None:
world_size = 10
rank = 2
my_hostname = "my_hostname"
my_port = 1234
init_method = get_tcp_init_method(
world_size=world_size, rank=rank, hostname=my_hostname, port=my_port
)
url = urlparse(init_method)
self.assertEqual(url.scheme, "tcp")
self.assertEqual(url.hostname, my_hostname)
self.assertEqual(url.port, my_port)
self.assertEqual(url.path, "")
url_qs = parse_qs(url.query)
self.assertIn("world_size", url_qs)
self.assertEqual(url_qs["world_size"], [str(world_size)])
self.assertIn("rank", url_qs)
self.assertEqual(url_qs["rank"], [str(rank)])

world_size = 2
rank = 1
my_hostname = "my_hostname"
# get free port
init_method = get_tcp_init_method(
world_size=world_size, rank=rank, hostname=my_hostname, port=None
)
url = urlparse(init_method)
self.assertEqual(url.scheme, "tcp")
self.assertEqual(url.hostname, my_hostname)
self.assertIsNotNone(url.port)
self.assertTrue(none_throws(url.port) > 0)
self.assertEqual(url.path, "")
url_qs = parse_qs(url.query)
self.assertIn("world_size", url_qs)
self.assertEqual(url_qs["world_size"], [str(world_size)])
self.assertIn("rank", url_qs)
self.assertEqual(url_qs["rank"], [str(rank)])

world_size = 12
rank = 7
my_port = 4321
# get localhost
init_method = get_tcp_init_method(
world_size=world_size, rank=rank, hostname=None, port=my_port
)
url = urlparse(init_method)
self.assertEqual(url.scheme, "tcp")
self.assertIsNotNone(url.hostname)
self.assertTrue(none_throws(url.hostname).startswith("localhost"))
self.assertEqual(url.port, my_port)
self.assertEqual(url.path, "")
url_qs = parse_qs(url.query)
self.assertIn("world_size", url_qs)
self.assertEqual(url_qs["world_size"], [str(world_size)])
self.assertIn("rank", url_qs)
self.assertEqual(url_qs["rank"], [str(rank)])

world_size = 128
rank = 43
# get localhost and free port
init_method = get_tcp_init_method(world_size=world_size, rank=rank)
url = urlparse(init_method)
self.assertEqual(url.scheme, "tcp")
self.assertIsNotNone(url.hostname)
self.assertTrue(none_throws(url.hostname).startswith("localhost"))
self.assertIsNotNone(url.port)
self.assertTrue(none_throws(url.port) > 0)
self.assertEqual(url.path, "")
url_qs = parse_qs(url.query)
self.assertIn("world_size", url_qs)
self.assertEqual(url_qs["world_size"], [str(world_size)])
self.assertIn("rank", url_qs)
self.assertEqual(url_qs["rank"], [str(rank)])
83 changes: 83 additions & 0 deletions torchtnt/utils/distributed.py
@@ -8,6 +8,9 @@
# pyre-ignore-all-errors[2]: Parameter must have a type that does not contain `Any`

import os
import socket
import tempfile
from contextlib import closing
from functools import wraps
from typing import Any, Callable, List, Optional, TypeVar, Union

@@ -170,6 +173,86 @@ def get_process_group_backend_from_device(device: torch.device) -> str:
return "nccl" if device.type == "cuda" else "gloo"


def _get_free_port() -> int:
if socket.has_ipv6:
family = socket.AF_INET6
address = "localhost6"
else:
family = socket.AF_INET
address = "localhost4"
with socket.socket(family, socket.SOCK_STREAM) as s:
# Bind to port 0 so the OS assigns a free ephemeral port, then release it on close.
s.bind((address, 0))
s.listen(0)
with closing(s):
port = s.getsockname()[1]
return port


def _validate_global_rank_world_size(world_size: int, rank: int) -> None:
if world_size < 1:
raise ValueError(
f"Invalid world_size value provided: {world_size}. Value must be greater than 0."
)
if rank < 0:
raise ValueError(
f"Invalid rank value provided: {rank}. Value must be greater than non-negative."
)
if rank >= world_size:
raise ValueError(
f"Invalid rank and world_size values provided: rank={rank}, world_size={world_size}. Rank must be less than world_size."
)


def get_file_init_method(
*,
world_size: Optional[int] = None,
rank: Optional[int] = None,
filename: Optional[str] = None,
) -> str:
"""Gets init method for the TCP protocol for the distributed environment.
For more information, see here: https://pytorch.org/docs/stable/distributed.html#shared-file-system-initialization
Args:
world_size: global number of workers. If ``None``, the default is fetched using :function:`get_world_size`.
rank: Global rank of the worker calling the function. If ``None``, the default is fetched using :function:`get_global_rank`.
filename: The filename to use for synchronization. If ``None``, a new temporary file is used.
"""
world_size = world_size if world_size is not None else get_world_size()
rank = rank if rank is not None else get_global_rank()
_validate_global_rank_world_size(world_size, rank)
if filename is None:
# Create (and immediately delete) a temporary file just to obtain a unique name;
# only the generated name is used to build the file:// rendezvous URL.
with tempfile.NamedTemporaryFile() as tmp_file:
filename = tmp_file.name
init_method = f"file://{filename}?world_size={world_size}&rank={rank}"
return init_method


def get_tcp_init_method(
*,
world_size: Optional[int] = None,
rank: Optional[int] = None,
hostname: Optional[str] = None,
port: Optional[int] = None,
) -> str:
"""Gets init method for the TCP protocol for the distributed environment.
For more information, see here: https://pytorch.org/docs/stable/distributed.html#tcp-initialization.
Args:
world_size: global number of workers. If ``None``, the default is fetched using :function:`get_world_size`.
rank: Global rank of the worker calling the function. If ``None``, the default is fetched using :function:`get_global_rank`.
hostname: an address that belongs to the rank 0 process. If ``None``, then ``localhost`` is used.
port: A free port to use for communication. If ``None``, this port is automatically selected.
"""
world_size = world_size if world_size is not None else get_world_size()
rank = rank if rank is not None else get_global_rank()
_validate_global_rank_world_size(world_size, rank)
host_addr = hostname if hostname is not None else "localhost"
host_port = port if port is not None else _get_free_port()
init_method = f"tcp://{host_addr}:{host_port}?world_size={world_size}&rank={rank}"
return init_method


def _simple_all_gather_tensors(
result: Tensor, group: Optional[dist.ProcessGroup], world_size: int
) -> List[Tensor]:
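
For reference, the init method strings produced by these helpers follow the formats assembled above (the hostname, port, and file path below are illustrative placeholders):

get_tcp_init_method(world_size=2, rank=0, hostname="node0.example.com", port=29500)
# -> "tcp://node0.example.com:29500?world_size=2&rank=0"

get_file_init_method(world_size=2, rank=0, filename="/tmp/pg_rendezvous")
# -> "file:///tmp/pg_rendezvous?world_size=2&rank=0"
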
18 changes: 16 additions & 2 deletions torchtnt/utils/env.py
@@ -17,7 +17,11 @@
from numpy import random
from torch.distributed.constants import default_pg_timeout
from torchtnt.utils.device import get_device_from_env, maybe_enable_tf32
from torchtnt.utils.distributed import get_process_group_backend_from_device
from torchtnt.utils.distributed import (
get_file_init_method,
get_process_group_backend_from_device,
get_tcp_init_method,
)

_log: logging.Logger = logging.getLogger(__name__)

@@ -39,6 +43,7 @@ def _check_dist_env() -> bool:
def init_from_env(
*,
device_type: T.Optional[str] = None,
dist_init_method_type: T.Literal["env", "tcp", "file"] = "env",
pg_backend: T.Optional[str] = None,
pg_timeout: timedelta = default_pg_timeout,
float32_matmul_precision: str = "high",
@@ -56,6 +61,8 @@ def init_from_env(
Args:
device_type (str, optional): Device type to initialize. If None, device will be initialized
based on environment
dist_init_method_type (str, optional): Method to initialize the process group. Must be one of "env", "tcp", or "file".
For more information, see here: https://pytorch.org/docs/stable/distributed.html#initialization
pg_backend (str, optional): The process group backend to use. If None, it will use the
default process group backend from the device
pg_timeout (timedelta, optional): Timeout for operations executed against the process
@@ -88,7 +95,14 @@
if pg_backend is not None
else get_process_group_backend_from_device(device)
)
torch.distributed.init_process_group(backend=pg_backend, timeout=pg_timeout)
init_method: T.Optional[str] = None
if dist_init_method_type == "tcp":
init_method = get_tcp_init_method()
elif dist_init_method_type == "file":
init_method = get_file_init_method()
torch.distributed.init_process_group(
init_method=init_method, backend=pg_backend, timeout=pg_timeout
)
maybe_enable_tf32(float32_matmul_precision)
return device

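
A brief sketch of the new parameter (illustrative; assumes the process was launched through torchrun or another launcher that sets the distributed environment variables checked by _check_dist_env):

from torchtnt.utils.env import init_from_env

# "env" (the default) uses the env:// scheme, "tcp" builds an init method via
# get_tcp_init_method, and "file" builds one via get_file_init_method.
device = init_from_env(dist_init_method_type="tcp")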
