elastic/rendezvous: make barrier and rank assignment operations O(n) instead of O(n^2) (pytorch#124982)

Summary:
This makes the barrier and rank assignment operations linear instead of quadratic in the number of workers, which drastically improves rendezvous performance when running with over 1000 hosts.

This uses two approaches in two different areas:

* local rank assignment: each worker does one set and one get; local ranks are assigned on the rank 0 host in an O(n) pass, which keeps the total number of store operations linear in the number of workers (see the first sketch below).
* exit_barrier: use a counter plus a final flag so that each worker does at most one set, one get, and one add (see the second sketch below).
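
The first bullet is essentially a gather-then-scatter on the store. Below is a minimal sketch of that pattern; it assumes a `dist.Store`-like object whose `get()` blocks until the key is written, and the helper name, key layout, and JSON encoding are made up for illustration rather than taken from the torchelastic internals.

```
import json
from typing import Any, Dict


def assign_ranks(
    store, rank: int, world_size: int, role: str, local_world_size: int
) -> Dict[str, Any]:
    # Hypothetical key prefix, for illustration only.
    prefix = "rdzv/assign"

    # Step 1: every worker publishes its role info -- one set per worker.
    store.set(
        f"{prefix}/{rank}",
        json.dumps({"role": role, "local_world_size": local_world_size}),
    )

    if rank == 0:
        # Step 2: the rank 0 host gathers all infos and computes every
        # assignment in one O(n) pass, writing one result key per worker.
        infos = [json.loads(store.get(f"{prefix}/{r}")) for r in range(world_size)]
        next_global = 0
        role_offsets: Dict[str, int] = {}
        for r, info in enumerate(infos):
            offset = role_offsets.get(info["role"], 0)
            store.set(
                f"{prefix}/assigned/{r}",
                json.dumps({"global_rank": next_global, "role_rank": offset}),
            )
            next_global += info["local_world_size"]
            role_offsets[info["role"]] = offset + info["local_world_size"]

    # Step 3: each worker blocks on its own assignment key -- one get.
    # The result holds the base global/role rank for this agent's first
    # local worker; local ranks 0..local_world_size-1 are offsets from it.
    return json.loads(store.get(f"{prefix}/assigned/{rank}"))
```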
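
The second bullet is the counter-plus-flag barrier. A sketch follows under the same assumptions (blocking `get()`, atomic `add()`); the key names mirror the ones exercised by the updated util_test.py below, but this is not the verbatim implementation.

```
def barrier(store, world_size: int, key_prefix: str) -> None:
    num_members = f"{key_prefix}/num_members"
    last_member = f"{key_prefix}/last_member"

    # One atomic increment per worker: O(n) store traffic in total.
    arrived = store.add(num_members, 1)
    if arrived == world_size:
        # Only the final arrival writes the completion flag, so no worker
        # ever polls n per-worker keys (the old O(n^2) behavior).
        store.set(last_member, "<val_ignored>")

    # Every worker then blocks on the single flag key: one get per worker.
    store.get(last_member)
```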

At 4000 hosts, torchelastic can now complete in as little as 10 seconds, down from 373 seconds.

Test Plan:
Tested using many small test jobs running on a remote cluster.

{D56549942}

```
torchx run --scheduler mast -- --image=torchelastic_benchmark --j=4000x1
```

Differential Revision: D56605193

Pull Request resolved: pytorch#124982
Approved by: https://github.com/kiukchung, https://github.com/kurman
d4l3k authored and pytorchmergebot committed Apr 27, 2024
1 parent 1a6fef1 commit dc4c75b
Showing 6 changed files with 328 additions and 218 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -579,6 +579,7 @@
"barrier",
"get_all",
"synchronize",
"store_timeout",
# torch.distributed.fsdp.wrap
"always_wrap_policy",
"enable_wrap",
128 changes: 62 additions & 66 deletions test/distributed/elastic/agent/server/test/api_test.py
@@ -11,15 +11,19 @@
import signal
import unittest
import uuid
from typing import Any, Dict
from unittest.mock import call, MagicMock, patch
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List
from unittest.mock import call, patch

import torch.distributed as dist

import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic.agent.server.api import (
_get_fq_hostname,
_RoleInstanceInfo,
RunResult,
SimpleElasticAgent,
Worker,
WorkerGroup,
WorkerSpec,
WorkerState,
@@ -470,6 +474,6 @@ def test_run_unknown_state(self, mock_monitor_workers):
self.assertEqual(1, mock_monitor_workers.call_count)
self.assertEqual(spec.max_restarts, agent._remaining_restarts)

def test_get_ranks(self):
role_infos = [
_RoleInstanceInfo("parameter_server", 0, 4),
_RoleInstanceInfo("trainer", 1, 1),
_RoleInstanceInfo("trainer", 2, 2),
_RoleInstanceInfo("trainer", 3, 3),
_RoleInstanceInfo("parameter_server", 4, 5),
]
spec = self._get_worker_spec(
max_restarts=3, monitor_interval=0.1, role="not_used", local_world_size=8
)
agent = TestAgent(spec)
total_sum, ranks = agent._get_ranks(role_infos, 0, 0, len(role_infos))
self.assertEqual(15, total_sum)
self.assertEqual([0, 1, 2, 3], list(ranks))

def test_assign_worker_ranks(self):
role_infos = [
_RoleInstanceInfo("parameter_server", 0, 4),
@@ -494,56 +482,64 @@ def test_assign_worker_ranks(self):
_RoleInstanceInfo("trainer", 3, 3),
_RoleInstanceInfo("parameter_server", 4, 5),
]
num_agents = len(role_infos)
with patch.object(TestAgent, "_share_and_gather", return_value=role_infos):
self.verify_worker_ranks(
role_infos[0], num_agents, [0, 1, 2, 3], [0, 1, 2, 3]
store = dist.HashStore()

def f(info) -> List[Worker]:
i, role_info = info
spec = self._get_worker_spec(
max_restarts=3,
monitor_interval=0.1,
role=role_info.role,
local_world_size=role_info.local_world_size,
)
self.verify_worker_ranks(role_infos[1], num_agents, [4], [0])
self.verify_worker_ranks(role_infos[2], num_agents, [5, 6], [1, 2])
self.verify_worker_ranks(role_infos[3], num_agents, [7, 8, 9], [3, 4, 5])

def verify_worker_ranks(
self, agent_config, total_agents, expected_global_ranks, expected_role_ranks
):
role, agent_rank, local_world_size = (
agent_config.role,
agent_config.rank,
agent_config.local_world_size,
)
spec = self._get_worker_spec(
max_restarts=3,
monitor_interval=0.1,
role=role,
local_world_size=local_world_size,
)
agent = TestAgent(spec)
workers = agent._assign_worker_ranks(None, agent_rank, total_agents, spec)
self.assertEqual(
expected_global_ranks, [worker.global_rank for worker in workers]
)
self.assertEqual(expected_role_ranks, [worker.role_rank for worker in workers])

@patch("torch.distributed.elastic.utils.store.synchronize")
def test_share_and_gather(self, sync_mock):
# when the state is unknown we exit immediately; no retries
spec = self._get_worker_spec(max_restarts=100, monitor_interval=0.1)
agent = TestAgent(spec)
expected_agent_infos = [
_RoleInstanceInfo("trainer", 0, 10),
_RoleInstanceInfo("trainer", 1, 10),
_RoleInstanceInfo("validator", 2, 10),
]

sync_mock.return_value = [obj.serialize() for obj in expected_agent_infos]
result = agent._share_and_gather(MagicMock(), 1, 3, spec)
sync_mock.assert_called_once()
for expected_role_info, actual_role_info in zip(expected_agent_infos, result):
self.assertEqual(expected_role_info.role, actual_role_info.role)
self.assertEqual(expected_role_info.rank, actual_role_info.rank)
self.assertEqual(
expected_role_info.local_world_size, actual_role_info.local_world_size
agent = TestAgent(spec)
workers = agent._assign_worker_ranks(
store, role_info.rank, len(role_infos), spec
)
return [
(
w.local_rank,
w.role_rank,
w.global_rank,
w.world_size,
w.role_world_size,
)
for w in workers
]

with ThreadPool(len(role_infos)) as pool:
out = pool.map(f, enumerate(role_infos))

self.assertListEqual(
out,
[
[
(0, 0, 0, 15, 9),
(1, 1, 1, 15, 9),
(2, 2, 2, 15, 9),
(3, 3, 3, 15, 9),
],
[
(0, 0, 4, 15, 6),
],
[
(0, 1, 5, 15, 6),
(1, 2, 6, 15, 6),
],
[
(0, 3, 7, 15, 6),
(1, 4, 8, 15, 6),
(2, 5, 9, 15, 6),
],
[
(0, 4, 10, 15, 9),
(1, 5, 11, 15, 9),
(2, 6, 12, 15, 9),
(3, 7, 13, 15, 9),
(4, 8, 14, 15, 9),
],
],
)

def test_get_event(self):
spec = self._get_worker_spec(max_restarts=1)
173 changes: 119 additions & 54 deletions test/distributed/elastic/utils/util_test.py
@@ -7,77 +7,142 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest import mock
import datetime
from multiprocessing.pool import ThreadPool
from typing import List

import torch.distributed as dist

import torch.distributed.elastic.utils.store as store_util
from torch.distributed.elastic.utils.logging import get_logger
from torch.testing._internal.common_utils import run_tests, TestCase


class MockStore:
def __init__(self):
self.ops = []

def set_timeout(self, timeout: float) -> None:
self.ops.append(("set_timeout", timeout))

@property
def timeout(self) -> datetime.timedelta:
self.ops.append(("timeout",))

return datetime.timedelta(seconds=1234)

def set(self, key: str, value: str) -> None:
self.ops.append(("set", key, value))

def get(self, key: str) -> str:
self.ops.append(("get", key))
return "value"

def multi_get(self, keys: List[str]) -> List[str]:
self.ops.append(("multi_get", keys))
return ["value"] * len(keys)

def add(self, key: str, val: int) -> int:
self.ops.append(("add", key, val))
return 3


class StoreUtilTest(TestCase):
def test_get_all_rank_0(self):
store = mock.MagicMock()
world_size = 3

store = MockStore()

store_util.get_all(store, 0, "test/store", world_size)
# omit empty kwargs, get only key
actual_set_call_args = [
call_args[0][0] for call_args in store.set.call_args_list
]
self.assertListEqual(["test/store0.FIN"], actual_set_call_args)

actual_get_call_args = [call_args[0] for call_args in store.get.call_args_list]
expected_get_call_args = [
("test/store0",),
("test/store1",),
("test/store2",),
("test/store0.FIN",),
("test/store1.FIN",),
("test/store2.FIN",),
]
self.assertListEqual(expected_get_call_args, actual_get_call_args)

self.assertListEqual(
store.ops,
[
("multi_get", ["test/store0", "test/store1", "test/store2"]),
("add", "test/store/finished/num_members", 1),
("set", "test/store/finished/last_member", "<val_ignored>"),
("get", "test/store/finished/last_member"),
],
)

def test_get_all_rank_n(self):
store = mock.MagicMock()
store = MockStore()
world_size = 3
store_util.get_all(store, 1, "test/store", world_size)
# omit empty kwargs, get only key
actual_set_call_args = [
call_args[0][0] for call_args in store.set.call_args_list
]
self.assertListEqual(["test/store1.FIN"], actual_set_call_args)

actual_get_call_args = [call_args[0] for call_args in store.get.call_args_list]
expected_get_call_args = [
("test/store0",),
("test/store1",),
("test/store2",),
]
self.assertListEqual(expected_get_call_args, actual_get_call_args)

self.assertListEqual(
store.ops,
[
("multi_get", ["test/store0", "test/store1", "test/store2"]),
("add", "test/store/finished/num_members", 1),
("set", "test/store/finished/last_member", "<val_ignored>"),
],
)

def test_synchronize(self):
store_mock = mock.MagicMock()
store = MockStore()

data = b"data0"
store_util.synchronize(store_mock, data, 0, 3, key_prefix="torchelastic/test")
actual_set_call_args = store_mock.set.call_args_list
# omit empty kwargs
actual_set_call_args = [call_args[0] for call_args in actual_set_call_args]
expected_set_call_args = [
("torchelastic/test0", b"data0"),
("torchelastic/test0.FIN", b"FIN"),
]
self.assertListEqual(expected_set_call_args, actual_set_call_args)

expected_get_call_args = [
("torchelastic/test0",),
("torchelastic/test1",),
("torchelastic/test2",),
("torchelastic/test0.FIN",),
("torchelastic/test1.FIN",),
("torchelastic/test2.FIN",),
]
actual_get_call_args = store_mock.get.call_args_list
actual_get_call_args = [call_args[0] for call_args in actual_get_call_args]
self.assertListEqual(expected_get_call_args, actual_get_call_args)
store_util.synchronize(store, data, 0, 3, key_prefix="test/store")

self.assertListEqual(
store.ops,
[
("timeout",),
("set_timeout", datetime.timedelta(seconds=300)),
("set", "test/store0", data),
("multi_get", ["test/store0", "test/store1", "test/store2"]),
("add", "test/store/finished/num_members", 1),
("set", "test/store/finished/last_member", "<val_ignored>"),
("get", "test/store/finished/last_member"),
("set_timeout", datetime.timedelta(seconds=1234)),
],
)

def test_synchronize_hash_store(self) -> None:
N = 4

store = dist.HashStore()

def f(i: int):
return store_util.synchronize(
store, f"data{i}", i, N, key_prefix="test/store"
)

with ThreadPool(N) as pool:
out = pool.map(f, range(N))

self.assertListEqual(out, [[f"data{i}".encode() for i in range(N)]] * N)

def test_barrier(self):
store = MockStore()

store_util.barrier(store, 3, key_prefix="test/store")

self.assertListEqual(
store.ops,
[
("timeout",),
("set_timeout", datetime.timedelta(seconds=300)),
("add", "test/store/num_members", 1),
("set", "test/store/last_member", "<val_ignored>"),
("get", "test/store/last_member"),
("set_timeout", datetime.timedelta(seconds=1234)),
],
)

def test_barrier_hash_store(self) -> None:
N = 4

store = dist.HashStore()

def f(i: int):
store_util.barrier(store, N, key_prefix="test/store")

with ThreadPool(N) as pool:
out = pool.map(f, range(N))

self.assertEqual(out, [None] * N)


class UtilTest(TestCase):