update run_transformer tests to default to using pytorch native UCC instead of torch_ucc (NVIDIA#1495)

* update HAS_TORCH_UCC to HAS_UCC

* add comments for failing tests

* move HAS_UCC to _ucc_util.py

* whitespace

* small changes

* newline

* updated list of failing tests

* update failing tests list
Fuzzkatt authored Oct 27, 2022
1 parent f683961 commit 0d06a73
Showing 8 changed files with 31 additions and 41 deletions.
9 changes: 9 additions & 0 deletions apex/transformer/_ucc_util.py
@@ -0,0 +1,9 @@
from torch import distributed as dist

HAS_UCC = dist.is_ucc_available()
if not HAS_UCC:
    try:
        import torch_ucc
        HAS_UCC = True
    except ImportError:
        HAS_UCC = False
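
For reference, a minimal standalone sketch (not part of this commit) of the same detection order: the native torch.distributed check runs first, and the torch_ucc plugin import is only attempted as a fallback.

from torch import distributed as dist

if dist.is_ucc_available():
    print("native UCC backend is available in this PyTorch build")
else:
    try:
        import torch_ucc  # noqa: F401
        print("falling back to the torch_ucc plugin")
    except ImportError:
        print("no UCC support found; UCC-backed tests will be skipped")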
13 changes: 3 additions & 10 deletions apex/transformer/parallel_state.py
@@ -21,6 +21,7 @@
import torch

from apex.transformer.log_util import get_transformer_logger
from apex.transformer._ucc_util import HAS_UCC


_logger = get_transformer_logger(__name__)
@@ -126,7 +127,8 @@ def initialize_model_parallel(
    assert default_backend is None or default_backend in ("nccl", "ucc")
    assert p2p_backend is None or p2p_backend in ("nccl", "ucc")
    if "ucc" in (default_backend, p2p_backend):
        check_torch_ucc_availability()
        if not HAS_UCC:
            raise ImportError("UCC backend requires pytorch source build with UCC installed and enabled")
        warnings.warn("`ucc` backend support is experimental", ExperimentalWarning)
    if default_backend == "ucc":
        warnings.warn("The UCC's functionality as `default_backend` is not well verified", ExperimentalWarning)
@@ -671,12 +673,3 @@ def destroy_model_parallel():

# Used to warn when the UCC is specified.
class ExperimentalWarning(Warning): pass


def check_torch_ucc_availability() -> None:
    try:
        import torch_ucc  # NOQA
    except ImportError:
        raise ImportError(
            "UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found"
        )
5 changes: 3 additions & 2 deletions apex/transformer/testing/commons.py
@@ -35,7 +35,7 @@
Batch,
)
from apex.transformer.testing import global_vars

from apex.transformer._ucc_util import HAS_UCC

TEST_SUCCESS_MESSAGE = ">> passed the test :-)"

@@ -257,7 +257,8 @@ def initialize_distributed(backend="nccl"):
    if backend not in ("nccl", "ucc"):
        raise RuntimeError(f"Currently only nccl & ucc are supported but {backend}")
    if backend == "ucc":
        import torch_ucc  # NOQA
        if not HAS_UCC:
            raise ImportError("UCC backend requires pytorch source build with UCC installed and enabled")
    args = global_vars.get_args()
    local_rank = args.local_rank

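As a sketch of what the native path amounts to (a hypothetical single-process, env-var rendezvous; apex's initialize_distributed derives these values from its parsed args instead), backend="ucc" can now be passed straight to torch.distributed without importing torch_ucc:

import os
import torch.distributed as dist

# Hypothetical minimal rendezvous for illustration only.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="ucc", rank=0, world_size=1)
print("initialized with backend:", dist.get_backend())
dist.destroy_process_group()
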
13 changes: 4 additions & 9 deletions apex/transformer/testing/distributed_test_base.py
@@ -9,12 +9,7 @@
from torch.testing._internal import common_utils
from torch.testing._internal import common_distributed

HAS_TORCH_UCC = None
try:
    import torch_ucc
    HAS_TORCH_UCC = True
except ImportError:
    HAS_TORCH_UCC = False
from apex.transformer._ucc_util import HAS_UCC

# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version("470.42.01")
@@ -86,16 +81,16 @@ class NcclDistributedTestBase(DistributedTestBase):

DISTRIBUTED_BACKEND = "nccl"


@unittest.skipUnless(
    HAS_TORCH_UCC,
    "Requires [`torch_ucc`](https://github.com/facebookresearch/torch_ucc)",
    HAS_UCC,
    "Requires either torch ucc or pytorch build from source with native ucc installed and enabled",
)
@unittest.skipUnless(
    HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER,
    f"`torch_ucc` requires NVIDIA driver >= {_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION} but {_driver_version} found. "
    "See https://github.com/openucx/ucc/issues/496",
)

class UccDistributedTestBase(DistributedTestBase):

DISTRIBUTED_BACKEND = "ucc"
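A hedged sketch (hypothetical test, not in this diff) of how the base class is consumed: a subclass inherits DISTRIBUTED_BACKEND = "ucc" along with both class-level skip decorators, so it is skipped cleanly on builds without UCC or with an incompatible driver.

from torch.testing._internal import common_utils
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

class UccSmokeTest(UccDistributedTestBase):
    def test_backend_is_ucc(self):
        # DISTRIBUTED_BACKEND is set by the base class, not by this test.
        self.assertEqual(self.DISTRIBUTED_BACKEND, "ucc")

if __name__ == "__main__":
    common_utils.run_tests()
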
12 changes: 3 additions & 9 deletions tests/L0/run_transformer/run_bert_minimal_test.py
@@ -1,13 +1,7 @@
import random
import torch
try:
    import torch_ucc
except ImportError:
    HAS_TORCH_UCC = False
else:
    HAS_TORCH_UCC = True
    print("Use UCC as backend of Pipeline Parallel ProcessGroups")

from apex.transformer._ucc_util import HAS_UCC
from apex.transformer.enums import ModelType
from apex.transformer import tensor_parallel
from apex.transformer import parallel_state
@@ -187,7 +181,7 @@ def train(
init = True
try:
virtual_pipeline_model_parallel_sizes = (None, 2,)
if HAS_TORCH_UCC:
if HAS_UCC:
# Deliberately skipping test with interleaved schedule for BERT model.
# It deadlocks on hybrid UCC/NCCL backend.
virtual_pipeline_model_parallel_sizes = (None,)
@@ -217,7 +211,7 @@ def train(
args.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size,
default_backend="nccl",
p2p_backend="ucc" if HAS_TORCH_UCC else "nccl",
p2p_backend="ucc" if HAS_UCC else "nccl",
)
pipeline_model_parallel_size = (
parallel_state.get_pipeline_model_parallel_world_size()
10 changes: 2 additions & 8 deletions tests/L0/run_transformer/run_gpt_minimal_test.py
@@ -3,14 +3,8 @@
import time

import torch
try:
    import torch_ucc
except ImportError:
    HAS_TORCH_UCC = False
else:
    HAS_TORCH_UCC = True
    print("Use UCC as backend of Pipeline Parallel ProcessGroups")

from apex.transformer._ucc_util import HAS_UCC
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
@@ -198,7 +192,7 @@ def train(model, optim, pipeline_model_parallel_size, async_comm):
tensor_model_parallel_size_=args.tensor_model_parallel_size,
pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
default_backend="nccl",
p2p_backend="ucc" if HAS_TORCH_UCC else "nccl",
p2p_backend="ucc" if HAS_UCC else "nccl",
)

pipeline_model_parallel_size = (
1 change: 1 addition & 0 deletions tests/L0/run_transformer/test_layers.py
@@ -279,6 +279,7 @@ def test_row_parallel_linear_gradient_accumulation_fusion(self) -> None:
def test_row_parallel_linear_gradient_accumulation_fusion_in_fp16(self) -> None:
self._row_parallel_linear_test_impl(True, True, False)

# fails on native ucc and torch ucc: ucc does not support reduce scatter
@unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >=2 GPUs")
def test_row_parallel_linear_sequence_parallel(self) -> None:
self._row_parallel_linear_test_impl(False, False, True)
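Since the note above applies to both UCC flavors, one possible follow-up (an assumption, not something this commit does) is to turn the comment into a programmatic skip keyed off the test class's backend; a self-contained stand-in:

import unittest

class RowParallelLinearTestStandIn(unittest.TestCase):
    # Stand-in for the real apex test class; DISTRIBUTED_BACKEND normally comes from the distributed base classes.
    DISTRIBUTED_BACKEND = "ucc"

    def test_row_parallel_linear_sequence_parallel(self) -> None:
        if self.DISTRIBUTED_BACKEND == "ucc":
            self.skipTest("ucc does not support reduce scatter")
        # The real test would call self._row_parallel_linear_test_impl(False, False, True) here.

if __name__ == "__main__":
    unittest.main()
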
9 changes: 6 additions & 3 deletions tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py
@@ -30,10 +30,9 @@
)
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
from apex.transformer.testing.distributed_test_base import HAS_TORCH_UCC
from apex.transformer.testing.distributed_test_base import HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER
from apex.transformer.testing import commons as testing_utils

from apex.transformer._ucc_util import HAS_UCC

logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
@@ -270,27 +269,31 @@ def test_inference_async_pipelining_without_interleaving(self, sync_batch_comm:
sync_batch_comm=sync_batch_comm,
)

# fails on native ucc: times out
@unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
def test_learning_pipelining_with_interleaving(self, sync_batch_comm: bool = True):
self._forward_backward_test_impl(
False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2,
sync_batch_comm=sync_batch_comm,
)

# fails on native ucc: times out
@unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
def test_inference_pipelining_with_interleaving(self, sync_batch_comm: bool = True):
self._forward_backward_test_impl(
True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2,
sync_batch_comm=sync_batch_comm,
)

# fails on native ucc: times out
@unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
def test_learning_async_pipelining_with_interleaving(self, sync_batch_comm: bool = True):
self._forward_backward_test_impl(
False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True,
sync_batch_comm=sync_batch_comm,
)

# fails on native ucc: times out
@unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
def test_inference_async_pipelining_with_interleaving(self, sync_batch_comm: bool = True):
self._forward_backward_test_impl(
@@ -313,7 +316,7 @@ def _run_hybrid_distributed_backend(self, forward_only: bool) -> None:

@unittest.skipUnless(HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER, "Needs driver >= 470.42.01")
def _test_hybrid_backends(self, forward_only: bool) -> None:
if HAS_TORCH_UCC:
if HAS_UCC:
self._run_hybrid_distributed_backend(forward_only)
else:
with self.assertRaisesRegex(
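The "fails on native ucc: times out" notes above point at a distinction HAS_UCC cannot express by itself, since it is true for both the native backend and the torch_ucc plugin. A small sketch (an assumption about how the two could be told apart, not part of this commit):

import importlib.util
from torch import distributed as dist

# Hypothetical split of the single HAS_UCC flag into its two possible sources.
HAS_NATIVE_UCC = dist.is_ucc_available()
HAS_TORCH_UCC_PLUGIN = importlib.util.find_spec("torch_ucc") is not None

if HAS_NATIVE_UCC:
    print("native ucc: the interleaved-schedule tests above would be candidates for skipping")
elif HAS_TORCH_UCC_PLUGIN:
    print("torch_ucc plugin only")
else:
    print("no UCC support")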
