Add sync support for dict collections of metrics (#98)
Summary:
Pull Request resolved: #98

Added support for two new sync methods for dict collections of metrics: `sync_and_compute_collection` and `get_synced_state_dict_collection`. These methods use only a single data transfer per sync rather than one transfer per metric.
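
A rough usage sketch (the metric choice, names, and update values are illustrative only; a torch.distributed process group is assumed to be initialized):

    import torch
    from torcheval.metrics.toolkit import (
        get_synced_state_dict_collection,
        sync_and_compute_collection,
    )
    from torcheval.utils.test_utils.dummy_metric import DummySumMetric

    # a dict collection of metrics, each updated locally on its own rank
    metrics = {"train_sum": DummySumMetric(), "eval_sum": DummySumMetric()}
    for m in metrics.values():
        m.update(torch.tensor(1.0))

    # one collective transfer syncs the whole dict; non-recipient ranks receive None
    results = sync_and_compute_collection(metrics, recipient_rank=0)

    # recipient_rank="all" gives every rank the merged state dicts
    state_dicts = get_synced_state_dict_collection(metrics, recipient_rank="all")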

Reviewed By: ananthsub

Differential Revision: D41674853

fbshipit-source-id: d1cb8d81a13aecb6e7c9ab02839a7b105b33c87d
bobakfb authored and facebook-github-bot committed Dec 13, 2022
1 parent 079c0f7 commit 4f78bb1
Showing 2 changed files with 485 additions and 29 deletions.
190 changes: 189 additions & 1 deletion tests/metrics/test_toolkit.py
@@ -7,7 +7,7 @@
import os
import unittest
import uuid
-from typing import Callable, Type, Union
+from typing import Callable, List, Type, Union

import pytest

@@ -19,9 +19,12 @@
clone_metric,
clone_metrics,
get_synced_metric,
get_synced_metric_collection,
get_synced_state_dict,
get_synced_state_dict_collection,
reset_metrics,
sync_and_compute,
sync_and_compute_collection,
to_device,
)
from torcheval.utils.test_utils.dummy_metric import (
@@ -227,3 +230,188 @@ def test_classwise_converter(self) -> None:
"Number of labels [0-9]+ must be equal to the number of classes [0-9]+",
):
classwise_converter(metrics, name, labels)


class MetricCollectionToolkitTest(unittest.TestCase):
@staticmethod
def _test_per_process_metric_collection_sync(
input_tensor: torch.Tensor,
metric_constructors: List[Callable[[], Metric]],
recipient_rank: Union[int, Literal["all"]],
) -> None:
device = init_from_env()
if device.type == "cuda":
torch.cuda.empty_cache()
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
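        # a standalone TestCase instance provides assertion helpers here, since this is a
        # static function run in spawned worker processes rather than a bound test method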
tc = unittest.TestCase()

num_total_updates = len(input_tensor)
num_total_metrics = len(metric_constructors)

# ==========================
# dictwise test:
# ==========================
metric_names = [f"metric_{i}" for i in range(num_total_metrics)]
metrics = {
metric_names[i]: metric_constructors[i]().to(device)
for i in range(num_total_metrics)
}
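        # each rank applies a strided, disjoint subset of the updates; across all ranks the
        # full set of updates is covered exactly once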
for i in range(rank, num_total_updates, world_size):
for j in range(num_total_metrics):
metrics[metric_names[j]].update(input_tensor[i, j])

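        # snapshot the per-metric states up front so we can verify below that syncing does
        # not mutate the input metrics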
state_dicts_before_sync = {
metric_names[i]: metrics[metric_names[i]].state_dict()
for i in range(num_total_metrics)
}
synced_metric_dict = get_synced_metric_collection(
metrics, recipient_rank=recipient_rank
)
compute_result_dict = sync_and_compute_collection(
metrics, recipient_rank=recipient_rank
)
synced_state_dict_dict = get_synced_state_dict_collection(
metrics, recipient_rank=recipient_rank
)

# input metric state unchanged
tc.assertDictEqual(
{
metric_names[i]: metrics[metric_names[i]].state_dict()
for i in range(num_total_metrics)
},
state_dicts_before_sync,
)

if rank == recipient_rank or recipient_rank == "all":

# make sure we get real dicts on recipient ranks
tc.assertIsNotNone(synced_metric_dict)
tc.assertIsNotNone(compute_result_dict)
tc.assertIsNotNone(synced_state_dict_dict)

# construct new empty metrics
metrics_with_all_updates = {
metric_names[i]: metric_constructors[i]().to(device)
for i in range(num_total_metrics)
}

# update all the metrics in the collection so that we have a copy of the
# metrics with all the updates applied on one rank, bypassing syncing logic.
for i in range(num_total_updates):
for j in range(num_total_metrics):
metrics_with_all_updates[metric_names[j]].update(input_tensor[i, j])

            # check that each metric constructed with all updates on a single rank (no syncing)
            # has the same compute() value as the corresponding metric that was updated across
            # ranks and then synced onto this rank.
for metric_id in metric_names:
# test through the get_synced_metric_collection route
torch.testing.assert_close(
synced_metric_dict[metric_id].compute(),
metrics_with_all_updates[metric_id].compute(),
check_device=False,
)
                # test through the sync_and_compute_collection route
torch.testing.assert_close(
compute_result_dict[metric_id],
metrics_with_all_updates[metric_id].compute(),
check_device=False,
)
tc.assertGreater(len(synced_state_dict_dict[metric_id]), 0)
else:
tc.assertIsNone(synced_metric_dict)
tc.assertIsNone(compute_result_dict)
tc.assertIsNone(synced_state_dict_dict)

def _launch_collection_sync_test(
self,
num_processes: int,
input_tensor: torch.Tensor,
metric_classes: List[Type[Metric]],
recipient_rank: Union[int, str] = 0,
) -> None:
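        # spin up num_processes workers via torch.distributed.elastic; each worker runs the
        # per-process sync test above against its own shard of the updates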
lc = pet.LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=num_processes,
run_id=str(uuid.uuid4()),
rdzv_backend="c10d",
rdzv_endpoint="localhost:0",
max_restarts=0,
monitor_interval=1,
)
pet.elastic_launch(
lc, entrypoint=self._test_per_process_metric_collection_sync
)(
input_tensor,
metric_classes,
recipient_rank,
)

def test_metric_collection_sync(self) -> None:
num_processes = torch.cuda.device_count() if torch.cuda.is_available() else 4
num_metrics = 6
num_updates = 3 * num_processes

        input_tensor = torch.rand(size=(num_updates, num_metrics))
# recipient_rank = 0
self._launch_collection_sync_test(
num_processes, input_tensor, [DummySumMetric for i in range(num_metrics)]
)
self._launch_collection_sync_test(
num_processes,
input_tensor,
[DummySumListStateMetric for i in range(num_metrics)],
)

# recipient_rank = 1
self._launch_collection_sync_test(
num_processes, input_tensor, [DummySumMetric for i in range(num_metrics)], 1
)
self._launch_collection_sync_test(
num_processes,
input_tensor,
[DummySumListStateMetric for i in range(num_metrics)],
1,
)

# recipient_rank = "all"
self._launch_collection_sync_test(
num_processes,
input_tensor,
[DummySumMetric for i in range(num_metrics)],
"all",
)
self._launch_collection_sync_test(
num_processes,
input_tensor,
[DummySumListStateMetric for i in range(num_metrics)],
"all",
)

def test_metric_collection_sync_world_size_1(self) -> None:
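        # with no distributed process group, the synced collection should equal the input
        # collection (and likewise for the state_dicts)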
metric_collection: dict[str, Metric] = {
"m1": DummySumMetric(),
"m2": DummySumMetric(),
}
synced_metric_collection = get_synced_metric_collection(metric_collection)
self.assertIsNotNone(synced_metric_collection)
self.assertDictEqual(synced_metric_collection, metric_collection)

state_dict_collection = {
"m1": metric_collection["m1"].state_dict(),
"m2": metric_collection["m2"].state_dict(),
}

self.assertDictEqual(
# pyre-ignore: Incompatible parameter type [6]: In call `unittest.case.TestCase.assertDictEqual`, for 1st positional only parameter expected `Mapping[typing.Any, object]` but got `Optional[Dict[str, Dict[str, typing.Any]]]`.
get_synced_state_dict_collection(metric_collection),
state_dict_collection,
)
self.assertDictEqual(
# pyre-ignore: Incompatible parameter type [6]: In call `unittest.case.TestCase.assertDictEqual`, for 1st positional only parameter expected `Mapping[typing.Any, object]` but got `Optional[Dict[str, Dict[str, typing.Any]]]`.
get_synced_state_dict_collection(metric_collection, recipient_rank="all"),
state_dict_collection,
)