Add torch.distributed.DistBackendError exception type, thrown from C10D_NCCL_CHECK (pytorch#88134)

Currently all of the distributed errors are thrown from the `TORCH_CHECK` macro, which throws a generic `RuntimeError`. This change introduces a new error type, `DistBackendError`, which derives from `RuntimeError` and signifies that there was an error with the backend communication library. This allows for better error handling and analysis at higher levels in the stack. Motivation: https://docs.google.com/document/d/1j6VPOkC6znscliFuiDWMuMV1_fH4Abgdq7TCHMcXai4/edit#heading=h.a9rc38misyx8

Changes:
- Introduce the new error type `DistBackendError`
- Update `C10D_NCCL_CHECK`

Sample script to demonstrate the new error type:

```python
# python -m torch.distributed.run --nproc_per_node=2 <script>.py

import torch
import torch.distributed as dist

if __name__ == "__main__":
    dist.init_process_group("nccl")
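    # Both ranks use the same default CUDA device here, so this broadcast fails with DistBackendError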
    dist.broadcast(torch.tensor([1, 2, 3]).cuda(), 0)
```
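
Since `DistBackendError` derives from `RuntimeError`, callers that want to special-case communication failures can catch it before the generic handler, and existing code that only catches `RuntimeError` keeps working. A minimal sketch of that layered handling; the bodies of the `except` blocks are illustrative only, not part of this change:

```python
# python -m torch.distributed.run --nproc_per_node=2 <script>.py

import torch
import torch.distributed as dist

if __name__ == "__main__":
    dist.init_process_group("nccl")
    try:
        dist.broadcast(torch.tensor([1, 2, 3]).cuda(), 0)
    except dist.DistBackendError:
        # Failure reported by the communication backend (here: NCCL).
        print("backend error during broadcast")
        raise
    except RuntimeError:
        # Non-backend failures still surface as a plain RuntimeError.
        raise
```

Note that the `DistBackendError` clause has to come before the `RuntimeError` clause, since the new type is a subclass of `RuntimeError`.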

Differential Revision: [D40998803](https://our.internmc.facebook.com/intern/diff/D40998803)
Pull Request resolved: pytorch#88134
Approved by: https://github.com/rohan-varma
H-Huang authored and pytorchmergebot committed Nov 8, 2022
1 parent 1a7c4b0 commit bc66ddb
Showing 8 changed files with 43 additions and 3 deletions.
6 changes: 6 additions & 0 deletions c10/util/Exception.h
@@ -275,6 +275,12 @@ class C10_API OutOfMemoryError : public Error {
  using Error::Error;
};

// Used for collective communication library errors from the distributed module.
// These turn into DistBackendError when they cross into Python.
class C10_API DistBackendError : public Error {
  using Error::Error;
};

// A utility function to return an exception std::string by prepending its
// exception type before its what() content
C10_API std::string GetExceptionString(const std::exception& e);
7 changes: 7 additions & 0 deletions docs/source/distributed.rst
@@ -832,6 +832,13 @@ following matrix shows how the log level can be adjusted via the combination of
| ``INFO`` | ``DETAIL`` | Trace (a.k.a. All) |
+-------------------------+-----------------------------+------------------------+

Distributed has a custom exception type derived from `RuntimeError` called `torch.distributed.DistBackendError`. This exception is thrown when a backend-specific error occurs,
for example when the `NCCL` backend is used and the user attempts to use a GPU that is not available to the `NCCL` library.

.. autoclass:: torch.distributed.DistBackendError

.. warning::
    The DistBackendError exception type is an experimental feature and is subject to change.

.. Distributed modules that are missing specific entries.
.. Adding them here for tracking purposes until they are more permanently fixed.
10 changes: 10 additions & 0 deletions test/distributed/test_c10d_nccl.py
@@ -1025,6 +1025,16 @@ def test_send_recv(self):
        with self.assertRaisesRegex(RuntimeError, 'Tensors must be contiguous'):
            dist.send(send_tensor_view, 1)

    @requires_nccl()
    @sandcastle_skip_if(torch.cuda.device_count() < 1, "NCCL test requires 1 GPU")
    @skip_if_lt_x_gpu(1)
    def test_nccl_dist_backend_error(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        self._create_process_group_nccl(store, self.opts())

        # Both rank 0 and 1 will use the same CUDA device resulting in ncclInvalidUsage
        with self.assertRaises(dist.DistBackendError):
            dist.broadcast(torch.tensor([1, 2, 3]).cuda(), 0)


class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
3 changes: 3 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -1504,3 +1504,6 @@ def _current_graph_task_id() -> _int: ...

class _OutOfMemoryError:
    pass

class _DistBackendError:
    pass
12 changes: 11 additions & 1 deletion torch/csrc/Exceptions.cpp
@@ -13,7 +13,7 @@
#include <c10/util/StringUtil.h>

PyObject *THPException_FatalError, *THPException_LinAlgError,
*THPException_OutOfMemoryError;
*THPException_OutOfMemoryError, *THPException_DistBackendError;

#define ASSERT_TRUE(cond) \
if (!(cond)) \
@@ -63,6 +63,16 @@ could not be completed because the input matrix is singular.",
      PyModule_AddObject(
          module, "_OutOfMemoryError", THPException_OutOfMemoryError) == 0);

  ASSERT_TRUE(
      THPException_DistBackendError = PyErr_NewExceptionWithDoc(
          "torch.distributed.DistBackendError",
          "Exception raised when a backend error occurs in distributed",
          PyExc_RuntimeError,
          nullptr));
  ASSERT_TRUE(
      PyModule_AddObject(
          module, "_DistBackendError", THPException_DistBackendError) == 0);

  return true;
}

4 changes: 3 additions & 1 deletion torch/csrc/Exceptions.h
@@ -75,6 +75,8 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
_CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \
_CATCH_GENERIC_ERROR( \
OutOfMemoryError, THPException_OutOfMemoryError, retstmnt) \
_CATCH_GENERIC_ERROR( \
DistBackendError, THPException_DistBackendError, retstmnt) \
_CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt) \
catch (torch::PyTorchError & e) { \
auto msg = torch::processErrorMsg(e.what()); \
@@ -146,7 +148,7 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)

extern PyObject *THPException_FatalError, *THPException_LinAlgError,
*THPException_OutOfMemoryError;
*THPException_OutOfMemoryError, *THPException_DistBackendError;

// Throwing this exception means that the python error flags have been already
// set and control should be immediately returned to the interpreter.
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -55,7 +55,7 @@
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
"\n" + getNcclErrorDetailStr(result, failureReason); \
TORCH_CHECK(false, err); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
} \
} while (0)

2 changes: 2 additions & 0 deletions torch/distributed/__init__.py
@@ -19,6 +19,8 @@ def is_available() -> bool:
if is_available() and not torch._C._c10d_init():
    raise RuntimeError("Failed to initialize torch.distributed")

# Custom Runtime Errors thrown from the distributed package
DistBackendError = torch._C._DistBackendError

if is_available():
    from torch._C._distributed_c10d import (
Expand Down
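
Because `torch/distributed/__init__.py` simply re-exports the binding, the new type is importable as `torch.distributed.DistBackendError` and keeps `RuntimeError` as its base class. A quick sanity check, assuming a build that includes this change:

```python
import torch.distributed as dist

# The binding created in Exceptions.cpp uses PyExc_RuntimeError as its base,
# so `except RuntimeError` handlers continue to catch backend failures.
assert issubclass(dist.DistBackendError, RuntimeError)
print(dist.DistBackendError.__mro__)
```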
