Skip to content

Commit

Permalink
c10d: retry dns lookup failures (pytorch#74641)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#74641

This makes DNS hostname lookup failures retryable, since in some environments such as Kubernetes hostnames are not guaranteed to be resolvable until the job starts. Retrying the lookup eliminates the race condition.

This also fixes `sandcastle_skip_if` when used on the class instead of the method. Previously, classes decorated this way wouldn't inherit from TestCase, so they just wouldn't run under buck at all.

Fixes pytorch#73682

Test Plan:
Added a unit test

```
buck test //caffe2/test/distributed:test_store
```

Reviewed By: aivanou

Differential Revision: D35092284

fbshipit-source-id: d40bf187e52c41f551e4fe41c536b2b0015588ee
(cherry picked from commit f890830)
  • Loading branch information
d4l3k authored and pytorchmergebot committed Mar 24, 2022
1 parent d0adb5f commit 5b915e8
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 39 deletions.
8 changes: 8 additions & 0 deletions test/distributed/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,14 @@ def test_common_errors(self):
gen = dist.rendezvous("tcp://127.0.0.1:23456?rank=0")
next(gen)

def test_dns_timeout(self):
with self.assertRaisesRegex(TimeoutError, "client socket has timed out after.*dnsnotexist"):
gen = dist.rendezvous(
"tcp://dnsnotexist:23456?world_size=2&rank=0",
timeout=timedelta(seconds=1),
)
next(gen)

@retry_on_connect_failures
def test_nominal(self):
url = self.create_tcp_url()
Expand Down
52 changes: 26 additions & 26 deletions torch/csrc/distributed/c10d/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,28 +613,11 @@ std::unique_ptr<SocketImpl> SocketConnectOp::run() {
}

bool SocketConnectOp::tryConnect(int family) {
::addrinfo hints{}, *naked_result = nullptr;

::addrinfo hints{};
hints.ai_flags = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV;
hints.ai_family = family;
hints.ai_socktype = SOCK_STREAM;

int r = ::getaddrinfo(host_, port_.c_str(), &hints, &naked_result);
if (r != 0) {
const char* gai_err = ::gai_strerror(r);

recordError("The {}network addresses of ({}, {}) cannot be retrieved (gai error: {} - {}).",
family == AF_INET ? "IPv4 " : family == AF_INET6 ? "IPv6 " : "",
host_,
port_,
r,
gai_err);

return false;
}

addrinfo_ptr result{naked_result};

deadline_ = Clock::now() + opts_->connect_timeout();

std::size_t retry_attempt = 1;
Expand All @@ -645,16 +628,33 @@ bool SocketConnectOp::tryConnect(int family) {

errors_.clear();

for (::addrinfo* addr = naked_result; addr != nullptr; addr = addr->ai_next) {
C10D_TRACE("The client socket is attempting to connect to {}.", *addr);
::addrinfo *naked_result = nullptr;
// patternlint-disable cpp-dns-deps
int r = ::getaddrinfo(host_, port_.c_str(), &hints, &naked_result);
if (r != 0) {
const char* gai_err = ::gai_strerror(r);

recordError("The {}network addresses of ({}, {}) cannot be retrieved (gai error: {} - {}).",
family == AF_INET ? "IPv4 " : family == AF_INET6 ? "IPv6 " : "",
host_,
port_,
r,
gai_err);
retry = true;
} else {
addrinfo_ptr result{naked_result};

for (::addrinfo* addr = naked_result; addr != nullptr; addr = addr->ai_next) {
C10D_TRACE("The client socket is attempting to connect to {}.", *addr);

ConnectResult cr = tryConnect(*addr);
if (cr == ConnectResult::Success) {
return true;
}
ConnectResult cr = tryConnect(*addr);
if (cr == ConnectResult::Success) {
return true;
}

if (cr == ConnectResult::Retry) {
retry = true;
if (cr == ConnectResult::Retry) {
retry = true;
}
}
}

Expand Down
23 changes: 10 additions & 13 deletions torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3109,20 +3109,17 @@ def sandcastle_skip_if(condition, reason):
skipping continuously.
"""
def decorator(func):

if not IS_SANDCASTLE and condition:
func.__unittest_skip__ = True
func.__unittest_skip_why__ = reason
return func

@wraps(func)
def wrapper(*args, **kwargs):
if condition and IS_SANDCASTLE:
print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)
return
if condition:
if IS_SANDCASTLE:
@wraps(func)
def wrapper(*args, **kwargs):
print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)
return wrapper
else:
return func(*args, **kwargs)
return wrapper
func.__unittest_skip__ = True
func.__unittest_skip_why__ = reason

return func

return decorator

Expand Down

0 comments on commit 5b915e8

Please sign in to comment.