Skip to content

Commit

Permalink
[Gradient Compression] Add unit tests that test default Python comm h…
Browse files Browse the repository at this point in the history
…ook implementations (pytorch#47158)

Summary:
Pull Request resolved: pytorch#47158

1. Test the default Python comm hook implementations ALLREDUCE and FP16_COMPRESS, besides an ad-hoc all-reduce implementation.
2. Typo fix.
3. Reformat default_hooks.py.
4. Publish register_comm_hook API for DDP module (This should be done in a separate diff, but got merged unintentionally.)

The new style can be used for testing any new comm hook like PowerSGD easily.
Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression pytorch#47202

ghstack-source-id: 116012600

Test Plan: buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_default_ddp_comm_hooks_nccl

Reviewed By: rohan-varma

Differential Revision: D24669639

fbshipit-source-id: 048c87084234edc2398f0ea6f01f2f083a707939
  • Loading branch information
Yi Wang authored and facebook-github-bot committed Nov 6, 2020
1 parent 873652d commit fccfe7b
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 60 deletions.
62 changes: 48 additions & 14 deletions test/distributed/test_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import torch.nn.functional as F
import torch.distributed as c10d
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
from torch.nn.parallel import DistributedDataParallel

from torch.testing._internal.common_distributed import MultiProcessTestCase, \
Expand Down Expand Up @@ -2756,7 +2757,7 @@ def _test_accumulate_gradients_no_sync(self, num_iters=2, ddp_comm_hook=None, gr
)

if ddp_comm_hook is not None:
ddp_model._register_comm_hook(process_group, ddp_comm_hook)
ddp_model.register_comm_hook(process_group, ddp_comm_hook)

def step_model(model, input, target):
model.train()
Expand Down Expand Up @@ -3377,7 +3378,7 @@ def test_ddp_comm_hook_future_passing_cpu(self):
)

# Register DDP Communication Hook
cpu_model._register_comm_hook(None, self._simple_hook)
cpu_model.register_comm_hook(None, self._simple_hook)

# check whether the grads are equal to what then callback returns.
# without the comm_hook, result would be 0.25 * torch.ones(2, 2).
Expand All @@ -3394,7 +3395,7 @@ def _gpu_model_with_ddp_comm_hook(self, process_group, hook=None, gradient_as_bu

# Register DDP Communication Hook if defined
if hook is not None:
gpu_model._register_comm_hook(None, hook)
gpu_model.register_comm_hook(None, hook)

return gpu_model

Expand Down Expand Up @@ -3472,7 +3473,7 @@ def test_ddp_comm_hook_future_passing_gpu_nccl(self):
def _test_ddp_comm_hook_allreduce_hook_nccl(self, gradient_as_bucket_view=False):
"""
This unit test verifies whether a DDP communication hook that just calls
allreduce gives the same result result with the case of no hook registered.
allreduce gives the same result with the case of no hook registered.
Without the then callback, the future_value in reducer is no longer
a PyObject, and this unit test verifies future_value is properly checked.
"""
Expand All @@ -3489,16 +3490,37 @@ def allreduce_hook(state: object, bucket: dist._GradBucket) -> torch._C.Future:
# check whether the grads are equal to what DDP without hook would return.
self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))

def _test_default_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False):
"""
This unit test verifies whether default Python DDP communication hooks ALLREDUCE and FP16_COMPRESS
can give the same result with the case of no hook registered.
"""
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

def allreduce_hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
return default.allreduce_hook(process_group, bucket)

def fp16_compress_hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
return default.fp16_compress_hook(process_group, bucket)

for hook in [allreduce_hook, fp16_compress_hook]:
# Get GPU model with the hook registered.
gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, hook, gradient_as_bucket_view)

# check whether the grads are equal to what DDP without hook would return.
self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))

def _test_builtin_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False):
"""
This unit test verifies whether built-in DDP communication hooks ALLREDUCE and FP16_COMPRESS
can give the same result result with the case of no hook registered.
This unit test verifies whether built-in C++ DDP communication hooks ALLREDUCE and FP16_COMPRESS
can give the same result with the case of no hook registered.
"""
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

for comm_hook_type in [dist.BuiltinCommHookType.ALLREDUCE, dist.BuiltinCommHookType.FP16_COMPRESS]:
# Get GPU model with the built-in allreduce communication hook.
# Get GPU model with the built-in communication hook.
gpu_model = self._gpu_model_with_builtin_ddp_comm_hook(
process_group, comm_hook_type, gradient_as_bucket_view)

Expand All @@ -3511,6 +3533,12 @@ def _test_builtin_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False):
def test_ddp_comm_hook_allreduce_hook_nccl(self):
self._test_ddp_comm_hook_allreduce_hook_nccl()

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_default_ddp_comm_hooks_nccl(self):
self._test_default_ddp_comm_hooks_nccl()

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
Expand All @@ -3523,6 +3551,12 @@ def test_builtin_ddp_comm_hooks_nccl(self):
def test_ddp_comm_hook_allreduce_hook_nccl_grad_is_view(self):
self._test_ddp_comm_hook_allreduce_hook_nccl(gradient_as_bucket_view=True)

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_default_ddp_comm_hooks_nccl_is_view(self):
self._test_default_ddp_comm_hooks_nccl(gradient_as_bucket_view=True)

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
Expand Down Expand Up @@ -3578,7 +3612,7 @@ def test_ddp_invalid_comm_hook_init(self):
model = DistributedDataParallel(ModuleForDdpCommHook(), process_group=process_group)

with self.assertRaisesRegex(TypeError, "Communication hook must be callable."):
model._register_comm_hook(state=None, hook=1)
model.register_comm_hook(state=None, hook=1)

with self.assertRaisesRegex(
ValueError, "bucket annotation should be dist._GradBucket."
Expand All @@ -3587,7 +3621,7 @@ def test_ddp_invalid_comm_hook_init(self):
def comm_hook(state: object, bucket: int) -> torch.futures.Future:
return torch.futures.Future()

model._register_comm_hook(state=None, hook=comm_hook)
model.register_comm_hook(state=None, hook=comm_hook)

@requires_gloo()
def test_ddp_invalid_comm_hook_return_type(self):
Expand All @@ -3609,7 +3643,7 @@ def test_ddp_invalid_comm_hook_return_type(self):
def comm_hook(state: object, bucket: dist._GradBucket) -> int:
return torch.futures.Future()

model._register_comm_hook(state=None, hook=comm_hook)
model.register_comm_hook(state=None, hook=comm_hook)

with self.assertRaisesRegex(
RuntimeError,
Expand All @@ -3619,7 +3653,7 @@ def comm_hook(state: object, bucket: dist._GradBucket) -> int:
def comm_hook(state: object, bucket: dist._GradBucket):
return 1

model._register_comm_hook(state=None, hook=comm_hook)
model.register_comm_hook(state=None, hook=comm_hook)

# Run forward
output = model(8, self.rank)
Expand All @@ -3643,12 +3677,12 @@ def dummy_hook(state, bucket):
fut.set_result(bucket.get_tensors())
return fut

model._register_comm_hook(None, dummy_hook)
model.register_comm_hook(None, dummy_hook)

with self.assertRaisesRegex(
RuntimeError, "register_comm_hook or register_builtin_comm_hook can only be called once."
):
model._register_comm_hook(None, dummy_hook)
model.register_comm_hook(None, dummy_hook)

@requires_gloo()
def test_ddp_comm_hook_sparse_gradients(self):
Expand Down Expand Up @@ -3679,7 +3713,7 @@ def allreduce_hook_gloo(state: object, bucket: dist._GradBucket) -> torch.future
fut.set_result([t / self.world_size for t in bucket.get_tensors()])
return fut

ddp_model._register_comm_hook(None, allreduce_hook_gloo)
ddp_model.register_comm_hook(None, allreduce_hook_gloo)

self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)

Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/distributed/c10d/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1100,12 +1100,12 @@ that adds a prefix to each key inserted to the store.
>>> work = process_group.allreduce(tensors)
>>> return work.get_future()
>>> ddp_model._register_comm_hook(state = None, hook = allreduce)
>>> ddp_model._egister_comm_hook(state = None, hook = allreduce)
.. warning ::
``get_future`` API supports only NCCL backend and single-process single-device mode.
The ``torch._C.Future`` object returned by this API can be used in
``DistributedDataParallel._register_comm_hook``, but it is subject to some subtle
``DistributedDataParallel.register_comm_hook``, but it is subject to some subtle
differences compared to ``torch.futures.Future`` due to compromises made for performance
reasons.
Expand Down
Empty file.
6 changes: 3 additions & 3 deletions torch/distributed/algorithms/ddp_comm_hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from functools import partial

import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
import torch.distributed.algorithms.ddp_comm_hooks.quantization_hooks as quantization
from . import default_hooks as default
from . import quantization_hooks as quantization
from torch.nn.parallel import DistributedDataParallel


def _ddp_comm_hook_wrapper(comm_hook, model, state):
model._register_comm_hook(state, comm_hook)
model.register_comm_hook(state, comm_hook)


class DDPCommHookType(Enum):
Expand Down
72 changes: 37 additions & 35 deletions torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ def allreduce_hook(
process_group: object, bucket: dist._GradBucket
) -> torch.futures.Future:
"""
This DDP communication hook just calls ``allreduce`` using ``GradBucket``
tensors. Once gradient tensors are aggregated across all workers, its ``then``
callback takes the mean and returns the result. If user registers this hook,
DDP results is expected to be same as the case where no hook was registered.
Hence, this won't change behavior of DDP and user can use this as a reference
or modify this hook to log useful information or any other purposes while
unaffecting DDP behavior.
Example::
>>> ddp_model._register_comm_hook(process_group, allreduce_hook)
This DDP communication hook just calls ``allreduce`` using ``GradBucket``
tensors. Once gradient tensors are aggregated across all workers, its ``then``
callback takes the mean and returns the result. If user registers this hook,
DDP results is expected to be same as the case where no hook was registered.
Hence, this won't change behavior of DDP and user can use this as a reference
or modify this hook to log useful information or any other purposes while
unaffecting DDP behavior.
Example::
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = (
Expand All @@ -31,17 +31,19 @@ def then_callback(fut):
return fut.then(then_callback)


def fp16_compress_hook(process_group: object, bucket: dist._GradBucket):
def fp16_compress_hook(
process_group: object, bucket: dist._GradBucket
) -> torch.futures.Future:
"""
This DDP communication hook implements a simple gradient compression
approach that converts ``GradBucket`` tensors whose type is assumed to be
``torch.float32`` to half-precision floating point format (``torch.float16``).
It allreduces those ``float16`` gradient tensors. Once compressed gradient
tensors are allreduced, its then callback called ``decompress`` converts the
aggregated result back to ``float32`` and takes the mean.
Example::
>>> ddp_model._register_comm_hook(process_group, fp16_compress_hook)
This DDP communication hook implements a simple gradient compression
approach that converts ``GradBucket`` tensors whose type is assumed to be
``torch.float32`` to half-precision floating point format (``torch.float16``).
It allreduces those ``float16`` gradient tensors. Once compressed gradient
tensors are allreduced, its then callback called ``decompress`` converts the
aggregated result back to ``float32`` and takes the mean.
Example::
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = (
Expand Down Expand Up @@ -80,21 +82,21 @@ def _allgather_then_aggregate_hook(
process_group: object, bucket: dist._GradBucket
) -> torch.futures.Future:
"""
Similar to ``allreduce_hook``, this hook first gathers ``GradBucket`` tensors
and its ``then`` callback aggregates the gathered gradient tensors and takes
mean. Instead of ``allreduce`` this hook uses ``allgather``. Note that with
W workers, both the computation and communication time scale as O(W) for
allgather compared to O(logW) for allreduce. Therefore, this hook is expected
to be much slower than ``allreduce_hook`` although both essentially do the
same thing with the gradients.
.. warning ::
This is for test and experiments. User is suggested to use a faster
alternative called ``allreduce_hook`` that uses ``allreduce`` protocol
instead of ``allgather`` protocol.
Example::
>>> ddp_model._register_comm_hook(process_group, allreduce_hook)
Similar to ``allreduce_hook``, this hook first gathers ``GradBucket`` tensors
and its ``then`` callback aggregates the gathered gradient tensors and takes
mean. Instead of ``allreduce`` this hook uses ``allgather``. Note that with
W workers, both the computation and communication time scale as O(W) for
allgather compared to O(logW) for allreduce. Therefore, this hook is expected
to be much slower than ``allreduce_hook`` although both essentially do the
same thing with the gradients.
.. warning ::
This is for test and experiments. User is suggested to use a faster
alternative called ``allreduce_hook`` that uses ``allreduce`` protocol
instead of ``allgather`` protocol.
Example::
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
rank = process_group.rank() if process_group is not None else dist.get_rank()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def quantization_pertensor_hook(
``allreduce`` protocol. It works only with flattened grads.
Example::
>>> ddp_model._register_comm_hook(process_group, quantization_pertensor_hook)
>>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook)
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
rank = process_group.rank() if process_group is not None else dist.get_rank()
Expand Down Expand Up @@ -140,7 +140,7 @@ def quantization_perchannel_hook(
``allreduce`` protocol. It works only with flattened grads.
Example::
>>> ddp_model._register_comm_hook(process_group, quantization_perchannel_hook)
>>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook)
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
rank = process_group.rank() if process_group is not None else dist.get_rank()
Expand Down
8 changes: 4 additions & 4 deletions torch/nn/parallel/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ def join(self, divide_by_initial_world_size=True, enable=True):
# All procs joined. Agree on authoritative rank and broadcast the model.
self._sync_final_model(is_last_joiner)

def _register_comm_hook(self, state: object, hook: callable):
def register_comm_hook(self, state: object, hook: callable):
r"""
Registers a communication hook which is an enhancement that provides a
flexible hook to users where they can specify how DDP aggregates gradients
Expand Down Expand Up @@ -1024,7 +1024,7 @@ def _register_comm_hook(self, state: object, hook: callable):
.. warning ::
``get_future`` API supports only NCCL backend and will return a ``torch._C.Future``
which is an internal type and should be used with caution. It can still be used by
``_register_comm_hook`` API, but it is subject to some subtle differences compared
``register_comm_hook`` API, but it is subject to some subtle differences compared
to ``torch.futures.Future``.
.. warning ::
Expand All @@ -1038,7 +1038,7 @@ def _register_comm_hook(self, state: object, hook: callable):
>>> fut.set_result(bucket.get_tensors())
>>> return fut
>>> ddp._register_comm_hook(state = None, hook = noop)
>>> ddp.register_comm_hook(state = None, hook = noop)
Example::
Below is an example of a Parallel SGD algorithm where gradients are encoded before
Expand All @@ -1054,7 +1054,7 @@ def _register_comm_hook(self, state: object, hook: callable):
>>> return decoded_tensors
>>> return fut.then(decode)
>>> ddp._register_comm_hook(state = None, hook = encode_and_decode)
>>> ddp.register_comm_hook(state = None, hook = encode_and_decode)
"""
self._check_comm_hook(hook)
Expand Down

0 comments on commit fccfe7b

Please sign in to comment.