Skip to content

Commit

Permalink
[core] move function and actor importer away from pubsub (ray-project…
Browse files Browse the repository at this point in the history
…#24132)

This PR moves function import to a lazy way. Several benefits of this:
- worker start up is faster since it doesn't need to go through all functions exported
- gcs pressure is smaller since 1) we don't need to export key and 2) all loads are done when needed.
- get rid of function table channel
  • Loading branch information
fishbone authored Apr 26, 2022
1 parent 61a9de7 commit f112b52
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 183 deletions.
68 changes: 21 additions & 47 deletions python/ray/_private/function_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,12 @@ def export(self, remote_function):
self._worker.gcs_client.internal_kv_put(
key, val, True, KV_NAMESPACE_FUNCTION_TABLE
)
self.export_key(key)

def fetch_and_register_remote_function(self, key):
"""Import a remote function."""
vals = self._worker.gcs_client.internal_kv_get(key, KV_NAMESPACE_FUNCTION_TABLE)
if vals is None:
vals = {}
return False
else:
vals = pickle.loads(vals)
fields = [
Expand Down Expand Up @@ -307,6 +306,7 @@ def f(*args, **kwargs):
self._function_execution_info[function_id] = FunctionExecutionInfo(
function=function, function_name=function_name, max_calls=max_calls
)
return True

def get_execution_info(self, job_id, function_descriptor):
"""Get the FunctionExecutionInfo of a remote function.
Expand Down Expand Up @@ -389,14 +389,23 @@ def _wait_for_function(self, function_descriptor, job_id, timeout=10):
warning_sent = False
while True:
with self.lock:
if self._worker.actor_id.is_nil() and (
function_descriptor.function_id in self._function_execution_info
):
break
elif not self._worker.actor_id.is_nil() and (
self._worker.actor_id in self._worker.actors
):
if self._worker.actor_id.is_nil():
if function_descriptor.function_id in self._function_execution_info:
break
else:
key = make_function_table_key(
b"RemoteFunction",
job_id,
function_descriptor.function_id.binary(),
)
if self.fetch_and_register_remote_function(key) is True:
break
else:
assert not self._worker.actor_id.is_nil()
# Actor loading will happen when execute_task is called.
assert self._worker.actor_id in self._worker.actors
break

if time.time() - start_time > timeout:
warning_message = (
"This worker was asked to execute a function "
Expand All @@ -419,22 +428,6 @@ def _wait_for_function(self, function_descriptor, job_id, timeout=10):
self._worker.import_thread._do_importing()
time.sleep(0.001)

def _publish_actor_class_to_key(self, key, actor_class_info):
"""Push an actor class definition to Redis.
The is factored out as a separate function because it is also called
on cached actor class definitions when a worker connects for the first
time.
Args:
key: The key to store the actor class info at.
actor_class_info: Information about the actor class.
"""
# We set the driver ID here because it may not have been available when
# the actor class was defined.
self._worker.gcs_client.internal_kv_put(
key, pickle.dumps(actor_class_info), True, KV_NAMESPACE_FUNCTION_TABLE
)
self.export_key(key)

def export_actor_class(
self, Class, actor_creation_function_descriptor, actor_method_names
):
Expand Down Expand Up @@ -493,7 +486,9 @@ def export_actor_class(
self._worker,
)

self._publish_actor_class_to_key(key, actor_class_info)
self._worker.gcs_client.internal_kv_put(
key, pickle.dumps(actor_class_info), True, KV_NAMESPACE_FUNCTION_TABLE
)
# TODO(rkn): Currently we allow actor classes to be defined
# within tasks. I tried to disable this, but it may be necessary
# because of https://github.com/ray-project/ray/issues/1146.
Expand Down Expand Up @@ -610,27 +605,6 @@ def _load_actor_class_from_gcs(self, job_id, actor_creation_function_descriptor)
job_id,
actor_creation_function_descriptor.function_id.binary(),
)
# Only wait for the actor class if it was exported from the same job.
# It will hang if the job id mismatches, since we isolate actor class
# exports from the import thread. It's important to wait since this
# guarantees import order, though we fetch the actor class directly.
# Import order isn't important across jobs, as we only need to fetch
# the class for `ray.get_actor()`.
if job_id.binary() == self._worker.current_job_id.binary():
# Wait for the actor class key to have been imported by the
# import thread. TODO(rkn): It shouldn't be possible to end
# up in an infinite loop here, but we should push an error to
# the driver if too much time is spent here.
while key not in self.imported_actor_classes:
try:
# If we're in the process of deserializing an ActorHandle
# and we hold the function_manager lock, we may be blocking
# the import_thread from loading the actor class. Use wait
# to temporarily yield control to the import thread.
self.cv.wait()
except RuntimeError:
# We don't hold the function_manager lock, just sleep
time.sleep(0.001)

# Fetch raw data from GCS.
vals = self._worker.gcs_client.internal_kv_get(key, KV_NAMESPACE_FUNCTION_TABLE)
Expand Down
3 changes: 3 additions & 0 deletions python/ray/_private/import_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def __init__(self, worker, mode, threads_stopped):
self.num_imported = 0
# Protect writes to self.num_imported.
self._lock = threading.Lock()
# Try to load all FunctionsToRun so that these functions will be
# run before accepting tasks.
self._do_importing()

def start(self):
"""Start the import thread."""
Expand Down
34 changes: 0 additions & 34 deletions python/ray/tests/test_basic_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytest

import ray.cluster_utils
from ray._private.gcs_pubsub import GcsFunctionKeySubscriber
from ray._private.test_utils import wait_for_condition
from ray.autoscaler._private.constants import RAY_PROCESSES
from pathlib import Path
Expand Down Expand Up @@ -99,39 +98,6 @@ def get_num_workers():
time.sleep(0.1)


def test_function_unique_export(ray_start_regular):
@ray.remote
def f():
pass

@ray.remote
def g():
ray.get(f.remote())

subscriber = GcsFunctionKeySubscriber(
address=ray.worker.global_worker.gcs_client.address
)
subscriber.subscribe()

ray.get(g.remote())

# Poll pubsub channel for messages generated from running task g().
num_exports = 0
while True:
key = subscriber.poll(timeout=1)
if key is None:
break
else:
num_exports += 1
print(f"num_exports after running g(): {num_exports}")
assert num_exports > 0, "Function export notification is not received"

ray.get([g.remote() for _ in range(5)])

key = subscriber.poll(timeout=1)
assert key is None, f"Unexpected function key export: {key}"


def test_function_import_without_importer_thread(shutdown_only):
"""Test that without background importer thread, dependencies can still be
imported in workers."""
Expand Down
37 changes: 20 additions & 17 deletions python/ray/tests/test_basic_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,34 +124,37 @@ def test_internal_kv(ray_start_regular):
kv._internal_kv_list("@namespace_abc", namespace="n")


def test_run_on_all_workers(ray_start_regular):
def test_run_on_all_workers(ray_start_regular, tmp_path):
# This test is to ensure run_function_on_all_workers are executed
# on all workers.
@ray.remote
class Actor:
def __init__(self):
self.jobs = []

def record(self, job_id=None):
if job_id is not None:
self.jobs.append(job_id)
return self.jobs

a = Actor.options(name="recorder", namespace="n").remote() # noqa: F841
driver_script = """
lock_file = tmp_path / "lock"
data_file = tmp_path / "data"
driver_script = f"""
import ray
from filelock import FileLock
from pathlib import Path
import pickle
lock_file = r"{str(lock_file)}"
data_file = Path(r"{str(data_file)}")
def init_func(worker_info):
a = ray.get_actor("recorder", namespace="n")
a.record.remote(worker_info['worker'].worker_id)
with FileLock(lock_file):
if data_file.exists():
old = pickle.loads(data_file.read_bytes())
else:
old = []
old.append(worker_info['worker'].worker_id)
data_file.write_bytes(pickle.dumps(old))
ray.worker.global_worker.run_function_on_all_workers(init_func)
ray.init(address='auto')
@ray.remote
def ready():
a = ray.get_actor("recorder", namespace="n")
assert ray.worker.global_worker.worker_id in ray.get(a.record.remote())
with FileLock(lock_file):
worker_ids = pickle.loads(data_file.read_bytes())
assert ray.worker.global_worker.worker_id in worker_ids
ray.get(ready.remote())
"""
Expand Down
85 changes: 0 additions & 85 deletions python/ray/tests/test_failure_2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import os
import signal
import sys
Expand Down Expand Up @@ -105,90 +104,6 @@ def g(remote_waits, nested_waits):
p.close()


def test_warning_for_many_duplicate_remote_functions_and_actors(shutdown_only):
ray.init(num_cpus=1)

@ray.remote
def create_remote_function():
@ray.remote
def g():
return 1

return ray.get(g.remote())

for _ in range(ray_constants.DUPLICATE_REMOTE_FUNCTION_THRESHOLD - 1):
ray.get(create_remote_function.remote())

import io

log_capture_string = io.StringIO()
ch = logging.StreamHandler(log_capture_string)

# TODO(rkn): It's terrible to have to rely on this implementation detail,
# the fact that the warning comes from ray._private.import_thread.logger.
# However, I didn't find a good way to capture the output for all loggers
# simultaneously.
ray._private.import_thread.logger.addHandler(ch)

ray.get(create_remote_function.remote())

start_time = time.time()
while time.time() < start_time + 10:
log_contents = log_capture_string.getvalue()
if len(log_contents) > 0:
break

ray._private.import_thread.logger.removeHandler(ch)

assert "remote function" in log_contents
assert (
"has been exported {} times.".format(
ray_constants.DUPLICATE_REMOTE_FUNCTION_THRESHOLD
)
in log_contents
)

# Now test the same thing but for actors.

@ray.remote
def create_actor_class():
# Require a GPU so that the actor is never actually created and we
# don't spawn an unreasonable number of processes.
@ray.remote(num_gpus=1)
class Foo:
pass

Foo.remote()

for _ in range(ray_constants.DUPLICATE_REMOTE_FUNCTION_THRESHOLD - 1):
ray.get(create_actor_class.remote())

log_capture_string = io.StringIO()
ch = logging.StreamHandler(log_capture_string)

# TODO(rkn): As mentioned above, it's terrible to have to rely on this
# implementation detail.
ray._private.import_thread.logger.addHandler(ch)

ray.get(create_actor_class.remote())

start_time = time.time()
while time.time() < start_time + 10:
log_contents = log_capture_string.getvalue()
if len(log_contents) > 0:
break

ray._private.import_thread.logger.removeHandler(ch)

assert "actor" in log_contents
assert (
"has been exported {} times.".format(
ray_constants.DUPLICATE_REMOTE_FUNCTION_THRESHOLD
)
in log_contents
)


# Note that this test will take at least 10 seconds because it must wait for
# the monitor to detect enough missed heartbeats.
def test_warning_for_dead_node(ray_start_cluster_2_nodes, error_pubsub):
Expand Down

0 comments on commit f112b52

Please sign in to comment.