Skip to content

Commit

Permalink
[Core] Port GcsSubscriber to Cython (ray-project#35094)
Browse files Browse the repository at this point in the history
Next step in ray-project#34393

---------

Signed-off-by: Philipp Moritz <[email protected]>
  • Loading branch information
pcmoritz authored May 31, 2023
1 parent f1f714c commit db64a12
Show file tree
Hide file tree
Showing 19 changed files with 575 additions and 324 deletions.
3 changes: 1 addition & 2 deletions dashboard/modules/actor/tests/test_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import requests

import ray
import ray._private.gcs_pubsub as gcs_pubsub
import ray.dashboard.utils as dashboard_utils
from ray._private.test_utils import format_web_url, wait_until_server_available
from ray.dashboard.modules.actor import actor_consts
Expand Down Expand Up @@ -126,7 +125,7 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
address_info = ray_start_with_dashboard

sub = gcs_pubsub.GcsActorSubscriber(address=address_info["gcs_address"])
sub = ray._raylet._TestOnly_GcsActorSubscriber(address=address_info["gcs_address"])
sub.subscribe()

@ray.remote
Expand Down
12 changes: 6 additions & 6 deletions dashboard/tests/test_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,9 @@ def test_agent_report_unexpected_raylet_death(shutdown_only):
errors = get_error_message(p, 1, ray_constants.RAYLET_DIED_ERROR)
assert len(errors) == 1, errors
err = errors[0]
assert err.type == ray_constants.RAYLET_DIED_ERROR
assert "Termination is unexpected." in err.error_message, err.error_message
assert "Raylet logs:" in err.error_message, err.error_message
assert err["type"] == ray_constants.RAYLET_DIED_ERROR
assert "Termination is unexpected." in err["error_message"], err["error_message"]
assert "Raylet logs:" in err["error_message"], err["error_message"]
assert (
os.path.getsize(os.path.join(node.get_session_dir_path(), "logs", "raylet.out"))
< 1 * 1024**2
Expand Down Expand Up @@ -268,9 +268,9 @@ def test_agent_report_unexpected_raylet_death_large_file(shutdown_only):
errors = get_error_message(p, 1, ray_constants.RAYLET_DIED_ERROR)
assert len(errors) == 1, errors
err = errors[0]
assert err.type == ray_constants.RAYLET_DIED_ERROR
assert "Termination is unexpected." in err.error_message, err.error_message
assert "Raylet logs:" in err.error_message, err.error_message
assert err["type"] == ray_constants.RAYLET_DIED_ERROR
assert "Termination is unexpected." in err["error_message"], err["error_message"]
assert "Raylet logs:" in err["error_message"], err["error_message"]


@pytest.mark.parametrize(
Expand Down
254 changes: 1 addition & 253 deletions python/ray/_private/gcs_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from collections import deque
import logging
import random
import threading
from typing import Optional, Tuple, List
from typing import Tuple, List

import grpc
from ray._private.utils import get_or_create_event_loop
Expand Down Expand Up @@ -158,257 +157,6 @@ def _pop_actors(queue, batch_size=100):
return msgs


class _SyncSubscriber(_SubscriberBase):
    """Blocking GCS pubsub subscriber bound to a single channel type.

    Received messages are buffered in ``self._queue``. Subclasses in this
    module implement ``poll()`` by calling ``_poll_locked()`` while holding
    ``self._lock`` and then popping from the queue.
    """

    def __init__(
        self,
        pubsub_channel_type,
        worker_id: bytes = None,
        address: str = None,
        channel: grpc.Channel = None,
    ):
        """Create the subscriber and its gRPC stub.

        Exactly one of ``address`` and ``channel`` must be given.

        Args:
            pubsub_channel_type: Channel type to subscribe to. Messages
                carrying any other channel type are dropped while polling.
            worker_id: Forwarded to ``_SubscriberBase.__init__``.
            address: GCS address. When truthy, a new channel is created with
                ``gcs_utils.create_gcs_channel``.
            channel: An existing gRPC channel, used when ``address`` is not
                given.

        Raises:
            AssertionError: If both or neither of ``address`` and ``channel``
                are specified.
        """
        super().__init__(worker_id)

        if address:
            assert channel is None, "address and channel cannot both be specified"
            channel = gcs_utils.create_gcs_channel(address)
        else:
            assert channel is not None, "One of address and channel must be specified"
        # GRPC stub to GCS pubsub.
        self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)

        # Type of the channel.
        self._channel = pubsub_channel_type
        # Protects multi-threaded read and write of self._queue.
        self._lock = threading.Lock()
        # A queue of received PubMessage.
        self._queue = deque()
        # Indicates whether the subscriber has closed.
        self._close = threading.Event()

    def subscribe(self) -> None:
        """Registers a subscription for the subscriber's channel type.

        Before the registration, published messages in the channel will not be
        saved for the subscriber. Does nothing if the subscriber has already
        been closed.
        """
        with self._lock:
            if self._close.is_set():
                return
            req = self._subscribe_request(self._channel)
            self._stub.GcsSubscriberCommandBatch(req, timeout=30)

    def _poll_locked(self, timeout=None) -> None:
        """Fill ``self._queue`` by long-polling GCS. Caller must hold the lock.

        Returns once at least one message is queued, the subscriber has been
        closed, or ``_should_terminate_polling`` accepts the RPC error.

        Args:
            timeout: Forwarded as the gRPC timeout of every
                ``GcsSubscriberPoll`` request.

        Raises:
            grpc.RpcError: If polling fails with an error for which
                ``_should_terminate_polling`` returns a falsy value.
        """
        assert self._lock.locked()

        # Poll until data becomes available.
        while not self._queue:
            if self._close.is_set():
                return

            fut = self._stub.GcsSubscriberPoll.future(
                self._poll_request(), timeout=timeout
            )
            # Wait for result to become available, or cancel if the
            # subscriber has closed.
            while True:
                try:
                    # Use 1s timeout to check for subscriber closing
                    # periodically.
                    fut.result(timeout=1)
                    break
                except grpc.FutureTimeoutError:
                    # Subscriber has closed. Cancel inflight request and
                    # return from polling.
                    if self._close.is_set():
                        fut.cancel()
                        return
                    # GRPC has not replied, continue waiting.
                    continue
                except grpc.RpcError as e:
                    if self._should_terminate_polling(e):
                        return
                    raise

            if fut.done():
                # Fetch the reply once instead of re-calling fut.result().
                reply = fut.result()
                self._last_batch_size = len(reply.pub_messages)
                if reply.publisher_id != self._publisher_id:
                    if self._publisher_id != "":
                        logger.debug(
                            f"replied publisher_id {reply.publisher_id} "
                            f"different from {self._publisher_id}, this should "
                            "only happens during gcs failover."
                        )
                    # A different publisher restarts sequence numbering.
                    self._publisher_id = reply.publisher_id
                    self._max_processed_sequence_id = 0

                for msg in reply.pub_messages:
                    if msg.sequence_id <= self._max_processed_sequence_id:
                        # Logger.warn is a deprecated alias of Logger.warning.
                        logger.warning(f"Ignoring out of order message {msg}")
                        continue
                    self._max_processed_sequence_id = msg.sequence_id
                    if msg.channel_type != self._channel:
                        logger.warning(
                            f"Ignoring message from unsubscribed channel {msg}"
                        )
                        continue
                    self._queue.append(msg)

    def close(self) -> None:
        """Closes the subscriber and its active subscription.

        Calling this more than once is a no-op after the first call.
        """

        # Mark close to terminate inflight polling and prevent future requests.
        if self._close.is_set():
            return
        self._close.set()
        req = self._unsubscribe_request(channels=[self._channel])
        try:
            self._stub.GcsSubscriberCommandBatch(req, timeout=5)
        except Exception:
            # Unsubscribing is best effort: a failure here is ignored so that
            # close() itself never raises.
            pass
        self._stub = None


class GcsErrorSubscriber(_SyncSubscriber):
    """Subscriber to error info. Thread safe.

    Subscribes to ``pubsub_pb2.RAY_ERROR_INFO_CHANNEL``; ``poll()`` holds the
    subscriber lock while it fetches and pops messages.

    Usage example:
        subscriber = GcsErrorSubscriber(address=gcs_address)
        # Subscribe to the error channel.
        subscriber.subscribe()
        ...
        while running:
            error_id, error_data = subscriber.poll()
            ......
        # Unsubscribe from the error channels.
        subscriber.close()
    """

    def __init__(
        self,
        worker_id: bytes = None,
        address: str = None,
        channel: grpc.Channel = None,
    ):
        """Bind the subscriber to the error-info channel.

        Args:
            worker_id: Forwarded to ``_SyncSubscriber.__init__``.
            address: Forwarded to ``_SyncSubscriber.__init__``.
            channel: Forwarded to ``_SyncSubscriber.__init__``.
        """
        super().__init__(
            pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
            worker_id=worker_id,
            address=address,
            channel=channel,
        )

    def poll(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
        """Polls for new error messages.

        Args:
            timeout: Forwarded to ``_poll_locked``.

        Returns:
            A tuple of error message ID and ErrorTableData proto message,
            or None, None if polling times out or subscriber closed.
        """
        with self._lock:
            self._poll_locked(timeout=timeout)
            error_info = self._pop_error_info(self._queue)
        return error_info


class GcsLogSubscriber(_SyncSubscriber):
    """Subscriber to logs. Thread safe.

    Subscribes to ``pubsub_pb2.RAY_LOG_CHANNEL``; ``poll()`` holds the
    subscriber lock while it fetches and pops messages.

    Usage example:
        subscriber = GcsLogSubscriber(address=gcs_address)
        # Subscribe to the log channel.
        subscriber.subscribe()
        ...
        while running:
            log = subscriber.poll()
            ......
        # Unsubscribe from the log channel.
        subscriber.close()
    """

    def __init__(
        self,
        worker_id: bytes = None,
        address: str = None,
        channel: grpc.Channel = None,
    ):
        """Bind the subscriber to the log channel.

        Args:
            worker_id: Forwarded to ``_SyncSubscriber.__init__``.
            address: Forwarded to ``_SyncSubscriber.__init__``.
            channel: Forwarded to ``_SyncSubscriber.__init__``.
        """
        super().__init__(
            pubsub_pb2.RAY_LOG_CHANNEL,
            worker_id=worker_id,
            address=address,
            channel=channel,
        )

    def poll(self, timeout=None) -> Optional[dict]:
        """Polls for new log messages.

        Args:
            timeout: Forwarded to ``_poll_locked``.

        Returns:
            A dict containing a batch of log lines and their metadata,
            or None if polling times out or subscriber closed.
        """
        with self._lock:
            self._poll_locked(timeout=timeout)
            log_batch = self._pop_log_batch(self._queue)
        return log_batch


class GcsFunctionKeySubscriber(_SyncSubscriber):
    """Subscriber to function (and actor class) dependency keys. Thread safe.

    Subscribes to ``pubsub_pb2.RAY_PYTHON_FUNCTION_CHANNEL``; ``poll()`` holds
    the subscriber lock while it fetches and pops messages.

    Usage example:
        subscriber = GcsFunctionKeySubscriber(address=gcs_address)
        # Subscribe to the function key channel.
        subscriber.subscribe()
        ...
        while running:
            key = subscriber.poll()
            ......
        # Unsubscribe from the function key channel.
        subscriber.close()
    """

    def __init__(
        self,
        worker_id: bytes = None,
        address: str = None,
        channel: grpc.Channel = None,
    ):
        """Bind the subscriber to the Python function channel.

        Args:
            worker_id: Forwarded to ``_SyncSubscriber.__init__``.
            address: Forwarded to ``_SyncSubscriber.__init__``.
            channel: Forwarded to ``_SyncSubscriber.__init__``.
        """
        super().__init__(
            pubsub_pb2.RAY_PYTHON_FUNCTION_CHANNEL,
            worker_id=worker_id,
            address=address,
            channel=channel,
        )

    def poll(self, timeout=None) -> Optional[bytes]:
        """Polls for new function key messages.

        Args:
            timeout: Forwarded to ``_poll_locked``.

        Returns:
            A byte string of function key.
            None if polling times out or subscriber closed.
        """
        with self._lock:
            self._poll_locked(timeout=timeout)
            function_key = self._pop_function_key(self._queue)
        return function_key


# Test-only
class GcsActorSubscriber(_SyncSubscriber):
    """Subscriber to actor updates. Thread safe.

    Subscribes to ``pubsub_pb2.GCS_ACTOR_CHANNEL``; ``poll()`` holds the
    subscriber lock while it fetches and pops messages.

    Usage example:
        subscriber = GcsActorSubscriber(address=gcs_address)
        # Subscribe to the actor channel.
        subscriber.subscribe()
        ...
        while running:
            actor_data = subscriber.poll()
            ......
        # Unsubscribe from the channel.
        subscriber.close()
    """

    def __init__(
        self,
        worker_id: bytes = None,
        address: str = None,
        channel: grpc.Channel = None,
    ):
        """Bind the subscriber to the actor channel.

        Args:
            worker_id: Forwarded to ``_SyncSubscriber.__init__``.
            address: Forwarded to ``_SyncSubscriber.__init__``.
            channel: Forwarded to ``_SyncSubscriber.__init__``.
        """
        super().__init__(pubsub_pb2.GCS_ACTOR_CHANNEL, worker_id, address, channel)

    def poll(self, timeout=None) -> List[Tuple[bytes, str]]:
        """Polls for new actor messages.

        Args:
            timeout: Forwarded to ``_poll_locked``.

        Returns:
            The result of ``_pop_actors`` on the internal queue with a batch
            size of 1, annotated as a list of ``(bytes, str)`` tuples.
            NOTE(review): the exact tuple contents and the value returned
            when polling times out or the subscriber is closed are decided
            by ``_pop_actors`` -- confirm against that helper.
        """
        with self._lock:
            self._poll_locked(timeout=timeout)
            return self._pop_actors(self._queue, batch_size=1)


class GcsAioPublisher(_PublisherBase):
"""Publisher to GCS. Uses async io."""

Expand Down
11 changes: 7 additions & 4 deletions python/ray/_private/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import ray._private.memory_monitor as memory_monitor
import ray._private.services
import ray._private.utils
from ray._private.gcs_pubsub import GcsErrorSubscriber, GcsLogSubscriber
from ray._private.internal_api import memory_summary
from ray._private.tls_utils import generate_self_signed_tls_certs
from ray._raylet import GcsClientOptions, GlobalStateAccessor
Expand Down Expand Up @@ -890,7 +889,9 @@ def get_non_head_nodes(cluster):

def init_error_pubsub():
    """Initialize error info pub/sub.

    Returns:
        A ``ray._raylet.GcsErrorSubscriber`` created for the address of the
        global worker's GCS client, with ``subscribe()`` already called.
    """
    # Construct the subscriber exactly once; a second, superseded constructor
    # call would leave an extra subscriber behind that is never closed.
    gcs_address = ray._private.worker.global_worker.gcs_client.address
    s = ray._raylet.GcsErrorSubscriber(address=gcs_address)
    s.subscribe()
    return s

Expand All @@ -908,7 +909,7 @@ def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):
if not error_data:
# Timed out before any data is received.
break
if error_type is None or error_type == error_data.type:
if error_type is None or error_type == error_data["type"]:
msgs.append(error_data)
else:
time.sleep(0.01)
Expand All @@ -918,7 +919,9 @@ def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):

def init_log_pubsub():
    """Initialize log pub/sub.

    Returns:
        A ``ray._raylet.GcsLogSubscriber`` created for the address of the
        global worker's GCS client, with ``subscribe()`` already called.
    """
    # Construct the subscriber exactly once; a second, superseded constructor
    # call would leave an extra subscriber behind that is never closed.
    gcs_address = ray._private.worker.global_worker.gcs_client.address
    s = ray._raylet.GcsLogSubscriber(address=gcs_address)
    s.subscribe()
    return s

Expand Down
Loading

0 comments on commit db64a12

Please sign in to comment.