Skip to content

Commit

Permalink
[Client] avoid locking in async send (ray-project#22193)
Browse files Browse the repository at this point in the history
As @iycheng discovered in ray-project#22082 (comment), when a `ClientObjectRef` is being GC'ed, `DataClient.lock` is acquired, which may cause a deadlock. This change avoids acquiring the lock in `DataClient._async_send()`.
  • Loading branch information
mwtian authored Feb 9, 2022
1 parent 20ab918 commit 71f6359
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 63 deletions.
2 changes: 1 addition & 1 deletion python/ray/tests/test_dataclient_disconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def f():
# Force grpc to error by queueing garbage request. This simulates
# the data channel shutting down for connection issues between
# different remote calls.
ray.worker.data_client.request_queue.put(Mock())
ray.worker.data_client.request_queue.put((Mock(), None))

# The following two assertions are relatively brittle. Consider a more
# robust mechanism if they fail with code changes or become flaky.
Expand Down
18 changes: 11 additions & 7 deletions python/ray/util/client/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def __init__(self, id: Union[bytes, Future]):
else:
raise TypeError("Unexpected type for id {}".format(id))

# NOTE: synchronization primitives like threading.Lock should not be used
# transitively by a destructor. Otherwise deadlocks can happen. See
# https://stackoverflow.com/questions/18774401/self-deadlock-due-to-garbage-collector-in-single-threaded-code
def __del__(self):
if self._worker is not None and self._worker.is_connected():
try:
Expand All @@ -106,8 +109,7 @@ def __del__(self):
logger.info(
"Exception in ObjectRef is ignored in destructor. "
"To receive this exception in application code, call "
"a method on the actor reference before its destructor "
"is run."
"a method on the reference before its destructor runs."
)

def binary(self):
Expand Down Expand Up @@ -202,6 +204,9 @@ def __init__(self, id: Union[bytes, Future]):
else:
raise TypeError("Unexpected type for id {}".format(id))

# NOTE: synchronization primitives like threading.Lock should not be used
# transitively by a destructor. Otherwise deadlocks can happen. See
# https://stackoverflow.com/questions/18774401/self-deadlock-due-to-garbage-collector-in-single-threaded-code
def __del__(self):
if self._worker is not None and self._worker.is_connected():
try:
Expand All @@ -211,8 +216,7 @@ def __del__(self):
logger.info(
"Exception from actor creation is ignored in destructor. "
"To receive this exception in application code, call "
"a method on the actor reference before its destructor "
"is run."
"a method on the reference before its destructor runs."
)

def binary(self):
Expand Down Expand Up @@ -853,10 +857,10 @@ def check_cache(self, req_id: int) -> Optional[Any]:
# Request is for an id that has already been cleared from
# cache/acknowledged.
raise RuntimeError(
"Attempting to accesss a cache entry that has already "
"Attempting to access a cache entry that has already "
"cleaned up. The client has already acknowledged "
f"receiving this response. ({req_id}, "
f"{self.last_received})"
f"receiving this response. (this_req_id={req_id}, "
f"acknowledged={self.last_received})"
)
if req_id in self.cache:
cached_resp = self.cache[req_id]
Expand Down
160 changes: 112 additions & 48 deletions python/ray/util/client/dataclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
import queue
import threading
import traceback

import grpc

from collections import OrderedDict
Expand Down Expand Up @@ -38,19 +40,20 @@ def __init__(self, client_worker: "Worker", client_id: str, metadata: list):
self._metadata = metadata
self.data_thread = self._start_datathread()

# Track outstanding requests to resend in case of disconnection
# Track outstanding requests to resend in case of disconnection.
# Maps request serial number to request.
self.outstanding_requests: Dict[int, Any] = OrderedDict()

# Serialize access to all mutable internal states: self.request_queue,
# self.ready_data, self.asyncio_waiting_data,
# self._in_shutdown, self._req_id, self.outstanding_requests and
# calling self._next_id()
# self.ready_data, self.asyncio_waiting_data, self._in_shutdown and
# self.outstanding_requests.
self.lock = threading.Lock()

# Waiting for response or shutdown.
self.cv = threading.Condition(lock=self.lock)

self.request_queue = queue.Queue()
# Maps request ID (Python object address) to response.
self.ready_data: Dict[int, Any] = {}
# NOTE: Dictionary insertion is guaranteed to complete before lookup
# and/or removal because of synchronization via the request_queue.
Expand Down Expand Up @@ -81,6 +84,29 @@ def _start_datathread(self) -> threading.Thread:
daemon=True,
)

# A helper that takes in (request, callback) from queue, sets up tracking
# of the requests and yields the requests as a stream.
#
# The resulting generator is what `_data_main()` passes to `stub.Datapath()`
# as `request_iterator`. Each queue item is either a `(request, callback)`
# tuple or `None`, the sentinel that `close()` enqueues to end the stream.
#
# NOTE(review): indentation was lost in this copy of the code, so the exact
# nesting of the statements below (in particular which of them run while
# `self.lock` is held) should be confirmed against the real file.
def _requests(self):
while True:
# `queue.Queue.get()` blocks until an item is available.
request = self.request_queue.get()
if request is None:
# Stop when client signals shutdown.
return
req, callback = request
assert req is not None
# req.req_id >= 1 when reconnecting, and is 0 for special types.
# Only requests that do not have an id yet are assigned one and tracked
# here; "acknowledge" and "connection_cleanup" requests are excluded from
# tracking by the condition below.
if req.req_id < 1 and req.WhichOneof("type") not in [
"acknowledge",
"connection_cleanup",
]:
# `self.lock` is the lock documented in `__init__` as serializing access
# to `self.outstanding_requests` and `self.asyncio_waiting_data`.
with self.lock:
req_id = self._next_id()
req.req_id = req_id
# Keep the request together with its callback, keyed by req_id.
self.outstanding_requests[req_id] = (req, callback)
# A callback is registered only when one was supplied with the request.
if callback is not None:
self.asyncio_waiting_data[req_id] = callback
yield req

def _data_main(self) -> None:
reconnecting = False
try:
Expand All @@ -89,8 +115,9 @@ def _data_main(self) -> None:
self.client_worker.channel
)
metadata = self._metadata + [("reconnecting", str(reconnecting))]

resp_stream = stub.Datapath(
iter(self.request_queue.get, None),
request_iterator=self._requests(),
metadata=metadata,
wait_for_ready=True,
)
Expand All @@ -101,11 +128,11 @@ def _data_main(self) -> None:
except grpc.RpcError as e:
reconnecting = self._can_reconnect(e)
if not reconnecting:
self._last_exception = e
self._last_exception = traceback.format_exc()
return
self._reconnect_channel()
except Exception as e:
self._last_exception = e
except Exception:
self._last_exception = traceback.format_exc()
finally:
logger.debug("Shutting down data channel.")
self._shutdown()
Expand All @@ -118,6 +145,17 @@ def _process_response(self, response: Any) -> None:
# This is not being waited for.
logger.debug(f"Got unawaited response {response}")
return
with self.lock:
# Update outstanding requests
if response.req_id in self.outstanding_requests:
req, callback = self.outstanding_requests.pop(response.req_id)
# Acknowledge response
self._acknowledge(response.req_id)
else:
logger.warning(
f"Receiving response without outstanding request: {response}"
)
return
if response.req_id in self.asyncio_waiting_data:
try:
# NOTE: calling self.asyncio_waiting_data.pop() results
Expand All @@ -130,15 +168,9 @@ def _process_response(self, response: Any) -> None:
callback(response)
except Exception:
logger.exception("Callback error:")
with self.lock:
# Update outstanding requests
if response.req_id in self.outstanding_requests:
del self.outstanding_requests[response.req_id]
# Acknowledge response
self._acknowledge(response.req_id)
else:
with self.lock:
self.ready_data[response.req_id] = response
self.ready_data[id(req)] = response
self.cv.notify_all()

def _can_reconnect(self, e: grpc.RpcError) -> bool:
Expand All @@ -147,8 +179,7 @@ def _can_reconnect(self, e: grpc.RpcError) -> bool:
Returns True if the error can be recovered from, False otherwise.
"""
if not self.client_worker._can_reconnect(e):
logger.error("Unrecoverable error in data channel.")
logger.debug(e)
logger.exception("Unrecoverable error in data channel.")
return False
logger.debug("Recoverable error in data channel.")
logger.debug(e)
Expand Down Expand Up @@ -179,6 +210,16 @@ def _shutdown(self) -> None:
for callback in callbacks:
if callback:
callback(err)
while True:
try:
_, callback = self.request_queue.get_nowait()
except queue.Empty:
break
except Exception:
logger.exception("Bad input data")
continue
if callback:
callback(err)
# Since self._in_shutdown is set to True, no new item
# will be added to self.asyncio_waiting_data

Expand All @@ -195,8 +236,11 @@ def _acknowledge(self, req_id: int) -> None:
self._acknowledge_counter += 1
if self._acknowledge_counter % ACKNOWLEDGE_BATCH_SIZE == 0:
self.request_queue.put(
ray_client_pb2.DataRequest(
acknowledge=ray_client_pb2.AcknowledgeRequest(req_id=req_id)
(
ray_client_pb2.DataRequest(
acknowledge=ray_client_pb2.AcknowledgeRequest(req_id=req_id)
),
None,
)
)

Expand Down Expand Up @@ -233,10 +277,22 @@ def _reconnect_channel(self) -> None:

# Recreate the request queue, and resend outstanding requests
with self.lock:
self.request_queue = queue.Queue()
for request in self.outstanding_requests.values():
# Resend outstanding requests
self.request_queue.put(request)
new_queue = queue.Queue()
# Fill the new request queue first with outstanding requests, which
# have lower req_id. Must use the order of req_id.
for req_id, request in sorted(self.outstanding_requests.items()):
# Re-queue outstanding requests (and callbacks).
new_queue.put(request)
self.request_queue, prev_queue = new_queue, self.request_queue
# Transfer remaining requests from the previous request queue.
# NOTE: prev_queue has concurrent consumers.
while True:
try:
req, callback = prev_queue.get_nowait()
except queue.Empty:
return
if req.req_id not in self.outstanding_requests:
self.request_queue.put((req, callback))

def close(self) -> None:
thread = None
Expand All @@ -248,10 +304,14 @@ def close(self) -> None:
if self.request_queue is not None:
# Intentional shutdown, tell server it can clean up the
# connection immediately and ignore the reconnect grace period.
cleanup_request = ray_client_pb2.DataRequest(
connection_cleanup=ray_client_pb2.ConnectionCleanupRequest()
self.request_queue.put(
(
ray_client_pb2.DataRequest(
connection_cleanup=ray_client_pb2.ConnectionCleanupRequest()
),
None,
)
)
self.request_queue.put(cleanup_request)
self.request_queue.put(None)
if self.data_thread is not None:
thread = self.data_thread
Expand All @@ -262,37 +322,40 @@ def close(self) -> None:
def _blocking_send(
self, req: ray_client_pb2.DataRequest
) -> ray_client_pb2.DataResponse:
self.request_queue.put((req, None))
with self.lock:
self._check_shutdown()
req_id = self._next_id()
req.req_id = req_id
self.request_queue.put(req)
self.outstanding_requests[req_id] = req

self.cv.wait_for(lambda: req_id in self.ready_data or self._in_shutdown)
self.cv.wait_for(lambda: id(req) in self.ready_data or self._in_shutdown)
self._check_shutdown()

data = self.ready_data[req_id]
del self.ready_data[req_id]
del self.outstanding_requests[req_id]
self._acknowledge(req_id)
resp = self.ready_data[id(req)]
del self.ready_data[id(req)]

return data
return resp

def _async_send(
self,
req: ray_client_pb2.DataRequest,
callback: Optional[ResponseCallable] = None,
check_shutdown: bool = True,
) -> None:
with self.lock:
self._check_shutdown()
req_id = self._next_id()
req.req_id = req_id
self.asyncio_waiting_data[req_id] = callback
self.outstanding_requests[req_id] = req
self.request_queue.put(req)
def nop(ignored):
return

# Must hold self.lock when calling this function.
if callback is None:
callback = nop
self.request_queue.put((req, callback))

# TODO: pass exception to the callback as well?
if check_shutdown:
with self.lock:
self._check_shutdown()

# The purpose of this function is to disconnect the Ray client when a
# connection issue is encountered. It avoids running in the data streaming
# thread (self.data_thread) which may result in deadlock, but it
# opportunistically runs when a blocking request is attempted.
def _check_shutdown(self):
assert self.lock.locked()
if not self._in_shutdown:
Expand Down Expand Up @@ -382,13 +445,14 @@ def PutObject(
resp = self._blocking_send(datareq)
return resp.put

def ReleaseObject(
self, request: ray_client_pb2.ReleaseRequest, context=None
) -> None:
def ReleaseObject(self, request: ray_client_pb2.ReleaseRequest) -> None:
datareq = ray_client_pb2.DataRequest(
release=request,
)
self._async_send(datareq)
# ReleaseObject() is called inside ClientObjectRef destructor, so it
# cannot acquire a lock. Avoiding checking shutdown which acquires a
# lock.
self._async_send(datareq, check_shutdown=False)

def Schedule(self, request: ray_client_pb2.ClientTask, callback: ResponseCallable):
datareq = ray_client_pb2.DataRequest(task=request)
Expand Down
22 changes: 15 additions & 7 deletions python/ray/util/client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,21 @@ def _can_reconnect(self, e: grpc.RpcError) -> bool:
# Unrecoverable error -- These errors are specifically raised
# by the server's application logic
return False
if e.code() == grpc.StatusCode.INTERNAL:
details = e.details()
if details == "Exception serializing request!":
# The client failed tried to send a bad request (for example,
# passing "None" instead of a valid grpc message). Don't
# try to reconnect/retry.
return False
if (
e.code() == grpc.StatusCode.INTERNAL
and e.details() == "Exception serializing request!"
):
# The client tried to send a bad request. Don't
# try to reconnect/retry.
return False
if (
e.code() == grpc.StatusCode.UNKNOWN
and e.details() == "Exception iterating requests!"
):
# The client tried to send a bad request (for example,
# passing "None" instead of a valid grpc message). Don't try to
# reconnect/retry.
return False
# All other errors can be treated as recoverable
return True

Expand Down

0 comments on commit 71f6359

Please sign in to comment.