[Remove Redis Pubsub 1/n] Remove enable_gcs_pubsub() (ray-project#23189)

GCS pubsub has been the default for a while, and there is little chance that we will need to revert to Redis pubsub in the future. This commit starts removing Redis pubsub by first removing the `enable_gcs_pubsub()` feature guard.
mwtian authored Mar 16, 2022
1 parent 678d23f commit 72ef9f9
Showing 15 changed files with 84 additions and 374 deletions.
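
In practice, every touched call site collapses from a feature-flag branch (GCS path versus Redis fallback) to an unconditional GCS pubsub call. The following is a minimal sketch of the resulting pattern, condensed from the diffs below rather than taken verbatim from any one file; the helper name and the gcs_address argument are illustrative.

import ray._private.gcs_pubsub as gcs_pubsub

def make_log_publisher(gcs_address: str) -> gcs_pubsub.GcsPublisher:
    # Before this commit, callers checked gcs_pubsub.gcs_pubsub_enabled() and
    # fell back to publishing through a Redis client when the flag was off.
    # Now the GCS publisher is constructed unconditionally.
    return gcs_pubsub.GcsPublisher(address=gcs_address)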
dashboard/modules/actor/tests/test_actor.py (11 changes: 1 addition & 10 deletions)
@@ -7,7 +7,6 @@
 import ray
 import pytest
 import ray.dashboard.utils as dashboard_utils
-import ray._private.gcs_utils as gcs_utils
 import ray._private.gcs_pubsub as gcs_pubsub
 from ray.dashboard.tests.conftest import *  # noqa
 from ray.dashboard.modules.actor import actor_consts
@@ -228,15 +227,7 @@ def __init__(self):
     def handle_pub_messages(msgs, timeout, expect_num):
         start_time = time.time()
         while time.time() - start_time < timeout and len(msgs) < expect_num:
-            if gcs_pubsub.gcs_pubsub_enabled():
-                _, actor_data = sub.poll(timeout=timeout)
-            else:
-                msg = sub.get_message()
-                if msg is None:
-                    time.sleep(0.01)
-                    continue
-                pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
-                actor_data = gcs_utils.ActorTableData.FromString(pubsub_msg.data)
+            _, actor_data = sub.poll(timeout=timeout)
             if actor_data is None:
                 continue
             msgs.append(actor_data)
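The helper above, written out as a standalone function for clarity; this is a sketch, not part of the diff, and it assumes sub is an already-subscribed GCS actor subscriber created earlier in the test.

import time

def collect_actor_updates(sub, timeout, expect_num):
    # Poll actor updates from GCS pubsub until enough messages arrive or the
    # timeout expires. The Redis get_message()/proto-parsing fallback is gone.
    msgs = []
    start_time = time.time()
    while time.time() - start_time < timeout and len(msgs) < expect_num:
        _, actor_data = sub.poll(timeout=timeout)
        if actor_data is None:
            continue
        msgs.append(actor_data)
    return msgs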
dashboard/modules/reporter/reporter_head.py (60 changes: 19 additions & 41 deletions)
@@ -5,7 +5,6 @@
 import aiohttp.web
 
 import ray
-import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
 import ray.dashboard.utils as dashboard_utils
 import ray.dashboard.optional_utils as dashboard_optional_utils
 import ray.experimental.internal_kv as internal_kv
@@ -18,7 +17,7 @@
 )
 from ray.core.generated import reporter_pb2
 from ray.core.generated import reporter_pb2_grpc
-from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsAioResourceUsageSubscriber
+from ray._private.gcs_pubsub import GcsAioResourceUsageSubscriber
 from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter
 from ray.dashboard.datacenter import DataSource
 
@@ -148,45 +147,24 @@ async def run(self, server):
         # Need daemon True to avoid dashboard hangs at exit.
         self.service_discovery.daemon = True
         self.service_discovery.start()
-        if gcs_pubsub_enabled():
-            gcs_addr = self._dashboard_head.gcs_address
-            subscriber = GcsAioResourceUsageSubscriber(gcs_addr)
-            await subscriber.subscribe()
-
-            while True:
-                try:
-                    # The key is b'RAY_REPORTER:{node id hex}',
-                    # e.g. b'RAY_REPORTER:2b4fbd...'
-                    key, data = await subscriber.poll()
-                    if key is None:
-                        continue
-                    data = json.loads(data)
-                    node_id = key.split(":")[-1]
-                    DataSource.node_physical_stats[node_id] = data
-                except Exception:
-                    logger.exception(
-                        "Error receiving node physical stats " "from reporter agent."
-                    )
-        else:
-            from aioredis.pubsub import Receiver
-
-            receiver = Receiver()
-            aioredis_client = self._dashboard_head.aioredis_client
-            reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
-            await aioredis_client.psubscribe(receiver.pattern(reporter_key))
-            logger.info(f"Subscribed to {reporter_key}")
-
-            async for sender, msg in receiver.iter():
-                try:
-                    key, data = msg
-                    data = json.loads(ray._private.utils.decode(data))
-                    key = key.decode("utf-8")
-                    node_id = key.split(":")[-1]
-                    DataSource.node_physical_stats[node_id] = data
-                except Exception:
-                    logger.exception(
-                        "Error receiving node physical stats " "from reporter agent."
-                    )
+        gcs_addr = self._dashboard_head.gcs_address
+        subscriber = GcsAioResourceUsageSubscriber(gcs_addr)
+        await subscriber.subscribe()
+
+        while True:
+            try:
+                # The key is b'RAY_REPORTER:{node id hex}',
+                # e.g. b'RAY_REPORTER:2b4fbd...'
+                key, data = await subscriber.poll()
+                if key is None:
+                    continue
+                data = json.loads(data)
+                node_id = key.split(":")[-1]
+                DataSource.node_physical_stats[node_id] = data
+            except Exception:
+                logger.exception(
+                    "Error receiving node physical stats from reporter agent."
+                )
 
     @staticmethod
     def is_minimal_module():
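For reference, the resulting subscription loop can be exercised on its own. This is a hedged sketch under the assumption that a Ray GCS is reachable at gcs_addr; the function name and the printout are illustrative, not part of the diff.

import asyncio
import json

from ray._private.gcs_pubsub import GcsAioResourceUsageSubscriber

async def watch_resource_usage(gcs_addr: str) -> None:
    subscriber = GcsAioResourceUsageSubscriber(gcs_addr)
    await subscriber.subscribe()
    while True:
        # Keys look like "RAY_REPORTER:{node id hex}"; the payload is JSON.
        key, data = await subscriber.poll()
        if key is None:
            continue
        node_id = key.split(":")[-1]
        print(node_id, json.loads(data))

# e.g. asyncio.run(watch_resource_usage("<gcs ip>:<gcs port>"))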
dashboard/tests/test_dashboard.py (11 changes: 3 additions & 8 deletions)
@@ -23,7 +23,6 @@
     run_string_as_driver,
     wait_until_succeeded_without_exception,
 )
-from ray._private.gcs_pubsub import gcs_pubsub_enabled
 from ray.ray_constants import DEBUG_AUTOSCALING_STATUS_LEGACY, DEBUG_AUTOSCALING_ERROR
 from ray.dashboard import dashboard
 import ray.dashboard.consts as dashboard_consts
@@ -701,13 +700,9 @@ def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
 
     gcs_server_proc.kill()
     gcs_server_proc.wait()
-    if gcs_pubsub_enabled():
-        # When pubsub enabled, the exits comes from pubsub errored.
-        # TODO: Fix this exits logic for pubsub
-        assert dashboard_proc.wait(10) != 0
-    else:
-        # The dashboard exits by os._exit(-1)
-        assert dashboard_proc.wait(10) == 255
+
+    # The dashboard exits by os._exit(-1)
+    assert dashboard_proc.wait(10) == 255
 
 
 if __name__ == "__main__":
python/ray/_private/function_manager.py (9 changes: 2 additions & 7 deletions)
@@ -178,13 +178,8 @@ def export_key(self, key):
                 break
         # Notify all subscribers that there is a new function exported. Note
         # that the notification doesn't include any actual data.
-        if self._worker.gcs_pubsub_enabled:
-            # TODO(mwtian) implement per-job notification here.
-            self._worker.gcs_publisher.publish_function_key(key)
-        else:
-            self._worker.redis_client.lpush(
-                make_exports_prefix(self._worker.current_job_id), "a"
-            )
+        # TODO(mwtian) implement per-job notification here.
+        self._worker.gcs_publisher.publish_function_key(key)
 
     def export(self, remote_function):
         """Pickle a remote function and export it to redis.
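A condensed sketch of the notification path after this change. It assumes a reachable GCS at gcs_address; the standalone helper is illustrative, since in the real code the publisher lives on the worker object.

from ray._private.gcs_pubsub import GcsPublisher

def notify_function_exported(gcs_address: str, key: bytes) -> None:
    # The exported-function notification now always goes through GCS pubsub;
    # the Redis LPUSH onto the per-job exports list is gone.
    publisher = GcsPublisher(address=gcs_address)
    publisher.publish_function_key(key)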
python/ray/_private/gcs_pubsub.py (6 changes: 0 additions & 6 deletions)
@@ -16,7 +16,6 @@
 
 import ray._private.gcs_utils as gcs_utils
 import ray._private.logging_utils as logging_utils
-from ray._raylet import Config
 from ray.core.generated.gcs_pb2 import ErrorTableData
 from ray.core.generated import dependency_pb2
 from ray.core.generated import gcs_service_pb2_grpc
@@ -30,11 +29,6 @@
 MAX_GCS_PUBLISH_RETRIES = 60
 
 
-def gcs_pubsub_enabled():
-    """Checks whether GCS pubsub feature flag is enabled."""
-    return Config.gcs_grpc_based_pubsub()
-
-
 def construct_error_message(job_id, error_type, message, timestamp):
     """Construct an ErrorTableData object.
python/ray/_private/gcs_utils.py (6 changes: 1 addition & 5 deletions)
@@ -94,13 +94,9 @@
 
 
 def use_gcs_for_bootstrap():
-    from ray._private.gcs_pubsub import gcs_pubsub_enabled
     from ray._raylet import Config
 
-    ret = Config.bootstrap_with_gcs()
-    if ret:
-        assert gcs_pubsub_enabled()
-    return ret
+    return Config.bootstrap_with_gcs()
 
 
 def get_gcs_address_from_redis(redis) -> str:
python/ray/_private/import_thread.py (36 changes: 7 additions & 29 deletions)
@@ -33,21 +33,9 @@ def __init__(self, worker, mode, threads_stopped):
         self.worker = worker
         self.mode = mode
         self.gcs_client = worker.gcs_client
-        if worker.gcs_pubsub_enabled:
-            self.subscriber = worker.gcs_function_key_subscriber
-            self.subscriber.subscribe()
-            self.exception_type = grpc.RpcError
-        else:
-            import redis
-
-            self.subscriber = worker.redis_client.pubsub()
-            self.subscriber.subscribe(
-                b"__keyspace@0__:"
-                + ray._private.function_manager.make_exports_prefix(
-                    self.worker.current_job_id
-                )
-            )
-            self.exception_type = redis.exceptions.ConnectionError
+        self.subscriber = worker.gcs_function_key_subscriber
+        self.subscriber.subscribe()
+        self.exception_type = grpc.RpcError
         self.threads_stopped = threads_stopped
         self.imported_collision_identifiers = defaultdict(int)
         # Keep track of the number of imports that we've imported.
@@ -72,20 +60,10 @@ def _run(self):
                 # Exit if we received a signal that we should stop.
                 if self.threads_stopped.is_set():
                     return
-
-                if self.worker.gcs_pubsub_enabled:
-                    key = self.subscriber.poll()
-                    if key is None:
-                        # subscriber has closed.
-                        break
-                else:
-                    msg = self.subscriber.get_message()
-                    if msg is None:
-                        self.threads_stopped.wait(timeout=0.01)
-                        continue
-                    if msg["type"] == "subscribe":
-                        continue
-
+                key = self.subscriber.poll()
+                if key is None:
+                    # subscriber has closed.
+                    break
                 self._do_importing()
         except (OSError, self.exception_type) as e:
             logger.error(f"ImportThread: {e}")
python/ray/_private/log_monitor.py (25 changes: 4 additions & 21 deletions)
@@ -1,7 +1,6 @@
 import argparse
 import errno
 import glob
-import json
 import logging
 import logging.handlers
 import os
@@ -13,10 +12,9 @@
 
 import ray.ray_constants as ray_constants
 import ray._private.gcs_pubsub as gcs_pubsub
-import ray._private.gcs_utils as gcs_utils
 import ray._private.services as services
 import ray._private.utils
-from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
+from ray._private.gcs_pubsub import GcsPublisher
 from ray._private.ray_logging import setup_component_logger
 
 # Logger for this module. It should be configured at the entry point
@@ -91,7 +89,6 @@ class LogMonitor:
         host (str): The hostname of this machine. Used to improve the log
             messages published to Redis.
         logs_dir (str): The directory that the log files are in.
-        redis_client: A client used to communicate with the Redis server.
         log_filenames (set): This is the set of filenames of all files in
             open_file_infos and closed_file_infos.
         open_file_infos (list[LogFileInfo]): Info for all of the open files.
@@ -105,10 +102,7 @@ def __init__(self, logs_dir, redis_address, gcs_address, redis_password=None):
         """Initialize the log monitor object."""
         self.ip = services.get_node_ip_address()
         self.logs_dir = logs_dir
-        self.redis_client = None
-        self.publisher = None
-        if gcs_pubsub.gcs_pubsub_enabled():
-            self.publisher = gcs_pubsub.GcsPublisher(address=gcs_address)
+        self.publisher = gcs_pubsub.GcsPublisher(address=gcs_address)
         self.log_filenames = set()
         self.open_file_infos = []
         self.closed_file_infos = []
@@ -293,12 +287,7 @@ def flush():
                     "actor_name": file_info.actor_name,
                     "task_name": file_info.task_name,
                 }
-                if self.publisher:
-                    self.publisher.publish_logs(data)
-                else:
-                    self.redis_client.publish(
-                        gcs_utils.LOG_FILE_CHANNEL, json.dumps(data)
-                    )
+                self.publisher.publish_logs(data)
                 anything_published = True
                 lines_to_publish = []
 
@@ -477,12 +466,7 @@ def run(self):
         log_monitor.run()
     except Exception as e:
         # Something went wrong, so push an error to all drivers.
-        redis_client = ray._private.services.create_redis_client(
-            args.redis_address, password=args.redis_password
-        )
-        gcs_publisher = None
-        if gcs_pubsub_enabled():
-            gcs_publisher = GcsPublisher(address=args.gcs_address)
+        gcs_publisher = GcsPublisher(address=args.gcs_address)
         traceback_str = ray._private.utils.format_error_message(traceback.format_exc())
         message = (
             f"The log monitor on node {platform.node()} "
@@ -491,7 +475,6 @@
         ray._private.utils.publish_error_to_driver(
             ray_constants.LOG_MONITOR_DIED_ERROR,
             message,
-            redis_client=redis_client,
             gcs_publisher=gcs_publisher,
         )
         logger.error(message)
python/ray/_private/test_utils.py (18 changes: 4 additions & 14 deletions)
@@ -31,7 +31,6 @@
 from ray.core.generated import node_manager_pb2
 from ray.core.generated import node_manager_pb2_grpc
 from ray._private.gcs_pubsub import (
-    gcs_pubsub_enabled,
     GcsErrorSubscriber,
     GcsLogSubscriber,
 )
@@ -218,7 +217,6 @@ def run_string_as_driver(driver_script: str, env: Dict = None, encode: str = "ut
     Returns:
         The script's output.
     """
-
     proc = subprocess.Popen(
         [sys.executable, "-"],
         stdin=subprocess.PIPE,
@@ -581,12 +579,8 @@ def get_non_head_nodes(cluster):
 
 def init_error_pubsub():
     """Initialize redis error info pub/sub"""
-    if gcs_pubsub_enabled():
-        s = GcsErrorSubscriber(address=ray.worker.global_worker.gcs_client.address)
-        s.subscribe()
-    else:
-        s = ray.worker.global_worker.redis_client.pubsub(ignore_subscribe_messages=True)
-        s.psubscribe(gcs_utils.RAY_ERROR_PUBSUB_PATTERN)
+    s = GcsErrorSubscriber(address=ray.worker.global_worker.gcs_client.address)
+    s.subscribe()
     return s
 
 
@@ -621,12 +615,8 @@ def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):
 
 def init_log_pubsub():
     """Initialize redis error info pub/sub"""
-    if gcs_pubsub_enabled():
-        s = GcsLogSubscriber(address=ray.worker.global_worker.gcs_client.address)
-        s.subscribe()
-    else:
-        s = ray.worker.global_worker.redis_client.pubsub(ignore_subscribe_messages=True)
-        s.psubscribe(gcs_utils.LOG_FILE_CHANNEL)
+    s = GcsLogSubscriber(address=ray.worker.global_worker.gcs_client.address)
+    s.subscribe()
     return s
 
 
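As a usage note, tests built on these helpers now always receive GCS subscriber objects. A hypothetical snippet, not part of the diff, with illustrative arguments:

import ray
from ray._private.test_utils import get_error_message, init_error_pubsub

def check_recent_errors():
    ray.init()
    subscriber = init_error_pubsub()  # GcsErrorSubscriber, already subscribed
    # get_error_message() polls the subscriber; its signature is shown in the
    # hunk header above.
    errors = get_error_message(subscriber, num=1, timeout=5)
    ray.shutdown()
    return errors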