[client] Chunk ClientTasks (ray-project#24555)
Adds support for chunking large schedule calls. This is needed to support ray.remote calls with more than 2 GiB of arguments.

Deprecates the args and kwargs fields of ClientTask and replaces them with a data field that holds the serialized (args, kwargs) tuple, which can be chunked and reassembled with the same logic already used for PutRequest chunks.
ckw017 authored May 11, 2022
1 parent 5a43b07 commit 11650b5
Showing 9 changed files with 172 additions and 47 deletions.
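
For orientation, a minimal sketch of the scheme the description refers to, under two assumptions: that the client packs the pickled (args, kwargs) tuple into the new data field, and that plain pickle stands in for the client's actual serializer. The worker-side packing code is not visible in the rendered portion of this diff, so pack_task_data and num_chunks below are hypothetical helpers; the 64 MiB chunk size is taken from the comment in the new test.

import math
import pickle

OBJECT_TRANSFER_CHUNK_SIZE = 64 * 2**20  # 64 MiB, per the test comment below

def pack_task_data(args, kwargs):
    # Hypothetical helper: serialize positional and keyword arguments into a
    # single bytes blob so the task can be chunked like a PutRequest.
    return pickle.dumps((args, kwargs))

def num_chunks(data: bytes) -> int:
    return math.ceil(len(data) / OBJECT_TRANSFER_CHUNK_SIZE)

data = pack_task_data((b"x" * (130 * 2**20),), {})  # roughly 130 MiB of arguments
assert num_chunks(data) == 3  # 130 MiB / 64 MiB rounds up to 3 chunks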
42 changes: 42 additions & 0 deletions python/ray/tests/test_client.py
@@ -8,6 +8,7 @@
import _thread
from unittest.mock import patch
import numpy as np
from ray.util.client.common import OBJECT_TRANSFER_CHUNK_SIZE

import ray.util.client.server.server as ray_client_server
from ray.tests.client_test_utils import create_remote_signal_actor
@@ -791,5 +792,46 @@ def test_empty_objects(ray_start_regular_shared):
assert ray.get(ref) == obj


def test_large_remote_call(ray_start_regular_shared):
"""
Test remote calls with large (multiple chunk) arguments
"""
with ray_start_client_server() as ray:

@ray.remote
def f(large_obj):
return large_obj.shape

@ray.remote
def f2(*args):
assert args[0] == 123
return args[1].shape

@ray.remote
def f3(*args, **kwargs):
assert args[0] == "a"
assert args[1] == "b"
return kwargs["large_obj"].shape

# 1024x1024x16 f64's =~ 128 MiB. Chunking size is 64 MiB, so guarantees
# that transferring argument requires multiple chunks.
assert OBJECT_TRANSFER_CHUNK_SIZE < 2 ** 20 * 128
large_obj = np.random.random((1024, 1024, 16))
assert ray.get(f.remote(large_obj)) == (1024, 1024, 16)
assert ray.get(f2.remote(123, large_obj)) == (1024, 1024, 16)
assert ray.get(f3.remote("a", "b", large_obj=large_obj)) == (1024, 1024, 16)

@ray.remote
class SomeActor:
def __init__(self, large_obj):
self.inner = large_obj

def some_method(self, large_obj):
return large_obj.shape == self.inner.shape

a = SomeActor.remote(large_obj)
assert ray.get(a.some_method.remote(large_obj))


if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
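
A quick standalone check of the arithmetic in the comment inside test_large_remote_call above, assuming the 64 MiB chunk size stated there:

import math

elements = 1024 * 1024 * 16          # shape of the test array
payload_bytes = elements * 8         # float64 = 8 bytes per element
assert payload_bytes == 128 * 2**20  # exactly 128 MiB of raw data
chunk_size = 64 * 2**20              # 64 MiB, per the comment in the test
assert math.ceil(payload_bytes / chunk_size) == 2  # so at least two chunks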
28 changes: 28 additions & 0 deletions python/ray/tests/test_client_reconnect.py
@@ -430,6 +430,34 @@ def fail_halfway(_):
assert result.shape == (1024, 1024, 128)


def test_disconnect_during_large_schedule():
"""
Disconnect during a remote call with a large (multi-chunk) argument.
"""
i = 0
started = False

def fail_halfway(_):
# Inject an error halfway through the object transfer
nonlocal i, started
if not started:
return
i += 1
if i == 8:
raise RuntimeError

@ray.remote
def f(a):
return a.shape

with start_middleman_server(on_data_request=fail_halfway):
started = True
a = np.random.random((1024, 1024, 128))
result = ray.get(f.remote(a))
assert i > 8 # Check that the failure was injected
assert result == (1024, 1024, 128)


def test_valid_actor_state():
"""
Repeatedly inject errors in the middle of mutating actor calls. Check
2 changes: 1 addition & 1 deletion python/ray/util/client/__init__.py
@@ -16,7 +16,7 @@

# This version string is incremented to indicate breaking changes in the
# protocol that require upgrading the client version.
CURRENT_PROTOCOL_VERSION = "2022-03-16"
CURRENT_PROTOCOL_VERSION = "2022-05-06"


class _ClientContext:
34 changes: 33 additions & 1 deletion python/ray/util/client/dataclient.py
@@ -69,6 +69,35 @@ def chunk_put(req: ray_client_pb2.DataRequest):
yield ray_client_pb2.DataRequest(req_id=req.req_id, put=chunk)


def chunk_task(req: ray_client_pb2.DataRequest):
"""
Chunks a client task. Doing this lazily is important with large arguments,
since taking slices of bytes objects does a copy. This means if we
immediately materialized every chunk of a large argument and inserted them
into the result_queue, we would effectively double the memory needed
on the client to handle the task.
"""
total_size = len(req.task.data)
assert total_size > 0, "Cannot chunk object with missing data"
total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
for chunk_id in range(0, total_chunks):
start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
chunk = ray_client_pb2.ClientTask(
type=req.task.type,
name=req.task.name,
payload_id=req.task.payload_id,
client_id=req.task.client_id,
options=req.task.options,
baseline_options=req.task.baseline_options,
namespace=req.task.namespace,
data=req.task.data[start:end],
chunk_id=chunk_id,
total_chunks=total_chunks,
)
yield ray_client_pb2.DataRequest(req_id=req.req_id, task=chunk)


class ChunkCollector:
"""
This object collects chunks from async get requests via __call__, and
@@ -211,8 +240,11 @@ def _requests(self):
if req is None:
# Stop when client signals shutdown.
return
if req.WhichOneof("type") == "put":
req_type = req.WhichOneof("type")
if req_type == "put":
yield from chunk_put(req)
elif req_type == "task":
yield from chunk_task(req)
else:
yield req

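
To illustrate the memory argument in the chunk_task docstring above: slicing a bytes object copies, so materializing every chunk up front keeps a full second copy of the payload alive at once, while a generator holds at most one chunk-sized copy at a time. A standalone sketch, not Ray code:

payload = bytes(8 * 2**20)  # 8 MiB stand-in for a large serialized argument
CHUNK = 2**20               # 1 MiB chunks, just for this illustration

# Eager: every slice is a copy, and all of them are alive simultaneously,
# roughly doubling the memory needed to hold the payload.
eager = [payload[i:i + CHUNK] for i in range(0, len(payload), CHUNK)]
assert sum(len(c) for c in eager) == len(payload)

# Lazy: only the chunk currently being streamed exists as an extra copy.
def iter_chunks(data, chunk_size=CHUNK):
    for i in range(0, len(data), chunk_size):
        yield data[i:i + chunk_size]

assert b"".join(iter_chunks(payload)) == payload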
39 changes: 31 additions & 8 deletions python/ray/util/client/server/dataservicer.py
@@ -1,4 +1,5 @@
from collections import defaultdict
from ray.util.client.server.server_pickler import loads_from_client
import ray
import logging
import grpc
@@ -55,12 +56,16 @@ def _should_cache(req: ray_client_pb2.DataRequest) -> bool:
wrapped up the data connection by this point.
- puts: We should only cache when we receive the final chunk, since
any earlier chunks won't generate a response
- tasks: We should only cache when we receive the final chunk,
since any earlier chunks won't generate a response
"""
req_type = req.WhichOneof("type")
if req_type == "get" and req.get.asynchronous:
return False
if req_type == "put":
return req.put.chunk_id == req.put.total_chunks - 1
if req_type == "task":
return req.task.chunk_id == req.task.total_chunks - 1
return req_type not in ("acknowledge", "connection_cleanup")


@@ -86,22 +91,25 @@ def fill_queue(

class ChunkCollector:
"""
Helper class for collecting chunks from PutObject calls
Helper class for collecting chunks from PutObject or ClientTask messages
"""

def __init__(self):
self.curr_req_id = None
self.last_seen_chunk_id = -1
self.data = bytearray()

def add_chunk(self, req: ray_client_pb2.DataRequest):
def add_chunk(
self,
req: ray_client_pb2.DataRequest,
chunk: Union[ray_client_pb2.PutRequest, ray_client_pb2.ClientTask],
):
if self.curr_req_id is not None and self.curr_req_id != req.req_id:
raise RuntimeError(
"Expected to receive a chunk from request with id "
f"{self.curr_req_id}, but found {req.req_id} instead."
)
self.curr_req_id = req.req_id
chunk = req.put
next_chunk = self.last_seen_chunk_id + 1
if chunk.chunk_id < next_chunk:
# Repeated chunk, ignore
@@ -139,7 +147,10 @@ def __init__(self, basic_service: "RayletServicer"):
self.stopped = Event()
# Helper for collecting chunks from PutObject calls. Assumes that
# that put requests from different objects aren't interleaved.
self.chunk_collector = ChunkCollector()
self.put_request_chunk_collector = ChunkCollector()
# Helper for collecting chunks from ClientTask calls. Assumes that
# schedule requests from different remote calls aren't interleaved.
self.client_task_chunk_collector = ChunkCollector()

def Datapath(self, request_iterator, context):
start_time = time.time()
@@ -213,13 +224,15 @@ def Datapath(self, request_iterator, context):
get_resp = self.basic_service._get_object(req.get, client_id)
resp = ray_client_pb2.DataResponse(get=get_resp)
elif req_type == "put":
if not self.chunk_collector.add_chunk(req):
if not self.put_request_chunk_collector.add_chunk(req, req.put):
# Put request still in progress
continue
put_resp = self.basic_service._put_object(
self.chunk_collector.data, req.put.client_ref_id, client_id
self.put_request_chunk_collector.data,
req.put.client_ref_id,
client_id,
)
self.chunk_collector.reset()
self.put_request_chunk_collector.reset()
resp = ray_client_pb2.DataResponse(put=put_resp)
elif req_type == "release":
released = []
Expand Down Expand Up @@ -249,7 +262,17 @@ def Datapath(self, request_iterator, context):
continue
elif req_type == "task":
with self.clients_lock:
resp_ticket = self.basic_service.Schedule(req.task, context)
task = req.task
if not self.client_task_chunk_collector.add_chunk(req, task):
# Not all serialized arguments have arrived
continue
arglist, kwargs = loads_from_client(
self.client_task_chunk_collector.data, self.basic_service
)
self.client_task_chunk_collector.reset()
resp_ticket = self.basic_service.Schedule(
req.task, arglist, kwargs, context
)
resp = ray_client_pb2.DataResponse(task_ticket=resp_ticket)
elif req_type == "terminate":
with self.clients_lock:
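
The server-side ChunkCollector above tolerates repeated chunks because, after a reconnect, the client may replay requests that were already partially processed. A simplified standalone model of that reassembly contract (illustrative only; the real class also tracks req_id and accepts either PutRequest or ClientTask chunks):

class SimpleChunkCollector:
    def __init__(self):
        self.last_seen_chunk_id = -1
        self.data = bytearray()

    def add_chunk(self, chunk_id: int, total_chunks: int, payload: bytes) -> bool:
        # Returns True once the final chunk has arrived and data is complete.
        if chunk_id <= self.last_seen_chunk_id:
            # Repeated chunk (e.g. replayed after a reconnect): skip the append.
            return chunk_id == total_chunks - 1
        self.data.extend(payload)
        self.last_seen_chunk_id = chunk_id
        return chunk_id == total_chunks - 1

collector = SimpleChunkCollector()
stream = [(0, b"aa"), (0, b"aa"), (1, b"bb"), (2, b"cc")]  # chunk 0 replayed
done = False
for chunk_id, payload in stream:
    done = collector.add_chunk(chunk_id, total_chunks=3, payload=payload)
assert done and bytes(collector.data) == b"aabbcc"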
48 changes: 24 additions & 24 deletions python/ray/util/client/server/server.py
@@ -10,7 +10,7 @@
import pickle

import threading
from typing import Any
from typing import Any, List
from typing import Dict
from typing import Set
from typing import Optional
@@ -34,7 +34,6 @@
)
from ray import ray_constants
from ray.util.client.server.proxier import serve_proxier
from ray.util.client.server.server_pickler import convert_from_arg
from ray.util.client.server.server_pickler import dumps_from_server
from ray.util.client.server.server_pickler import loads_from_client
from ray.util.client.server.dataservicer import DataServicer
@@ -548,9 +547,12 @@ def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
remaining_object_ids=remaining_object_ids,
)

@_use_response_cache
def Schedule(
self, task: ray_client_pb2.ClientTask, context=None
self,
task: ray_client_pb2.ClientTask,
arglist: List[Any],
kwargs: Dict[str, Any],
context=None,
) -> ray_client_pb2.ClientTaskTicket:
logger.debug(
"schedule: %s %s"
@@ -559,11 +561,11 @@ def Schedule(
try:
with disable_client_hook():
if task.type == ray_client_pb2.ClientTask.FUNCTION:
result = self._schedule_function(task, context)
result = self._schedule_function(task, arglist, kwargs, context)
elif task.type == ray_client_pb2.ClientTask.ACTOR:
result = self._schedule_actor(task, context)
result = self._schedule_actor(task, arglist, kwargs, context)
elif task.type == ray_client_pb2.ClientTask.METHOD:
result = self._schedule_method(task, context)
result = self._schedule_method(task, arglist, kwargs, context)
elif task.type == ray_client_pb2.ClientTask.NAMED_ACTOR:
result = self._schedule_named_actor(task, context)
else:
@@ -580,12 +582,15 @@
)

def _schedule_method(
self, task: ray_client_pb2.ClientTask, context=None
self,
task: ray_client_pb2.ClientTask,
arglist: List[Any],
kwargs: Dict[str, Any],
context=None,
) -> ray_client_pb2.ClientTaskTicket:
actor_handle = self.actor_refs.get(task.payload_id)
if actor_handle is None:
raise Exception("Can't run an actor the server doesn't have a handle for")
arglist, kwargs = self._convert_args(task.args, task.kwargs)
method = getattr(actor_handle, task.name)
opts = decode_options(task.options)
if opts is not None:
@@ -595,13 +600,15 @@ def _schedule_method(
return ray_client_pb2.ClientTaskTicket(return_ids=ids)

def _schedule_actor(
self, task: ray_client_pb2.ClientTask, context=None
self,
task: ray_client_pb2.ClientTask,
arglist: List[Any],
kwargs: Dict[str, Any],
context=None,
) -> ray_client_pb2.ClientTaskTicket:
remote_class = self.lookup_or_register_actor(
task.payload_id, task.client_id, decode_options(task.baseline_options)
)

arglist, kwargs = self._convert_args(task.args, task.kwargs)
opts = decode_options(task.options)
if opts is not None:
remote_class = remote_class.options(**opts)
@@ -612,12 +619,15 @@
return ray_client_pb2.ClientTaskTicket(return_ids=[actor._actor_id.binary()])

def _schedule_function(
self, task: ray_client_pb2.ClientTask, context=None
self,
task: ray_client_pb2.ClientTask,
arglist: List[Any],
kwargs: Dict[str, Any],
context=None,
) -> ray_client_pb2.ClientTaskTicket:
remote_func = self.lookup_or_register_func(
task.payload_id, task.client_id, decode_options(task.baseline_options)
)
arglist, kwargs = self._convert_args(task.args, task.kwargs)
opts = decode_options(task.options)
if opts is not None:
remote_func = remote_func.options(**opts)
@@ -638,16 +648,6 @@ def _schedule_named_actor(
self.named_actors.add(bin_actor_id)
return ray_client_pb2.ClientTaskTicket(return_ids=[actor._actor_id.binary()])

def _convert_args(self, arg_list, kwarg_map):
argout = []
for arg in arg_list:
t = convert_from_arg(arg, self)
argout.append(t)
kwargout = {}
for k in kwarg_map:
kwargout[k] = convert_from_arg(kwarg_map[k], self)
return argout, kwargout

def lookup_or_register_func(
self, id: bytes, client_id: str, options: Optional[Dict]
) -> ray.remote_function.RemoteFunction:
5 changes: 0 additions & 5 deletions python/ray/util/client/server/server_pickler.py
@@ -26,7 +26,6 @@

if TYPE_CHECKING:
from ray.util.client.server.server import RayletServicer
import ray.core.generated.ray_client_pb2 as ray_client_pb2

if sys.version_info < (3, 8):
try:
@@ -130,7 +129,3 @@ def loads_from_client(
return ClientUnpickler(
server_instance, file, fix_imports=fix_imports, encoding=encoding
).load()


def convert_from_arg(pb: "ray_client_pb2.Arg", server: "RayletServicer") -> Any:
return loads_from_client(pb.data, server)