[tune] Use generic _ObjectCache for actor reuse (ray-project#33045)
Actor reuse is currently implemented within the RayTrialExecutor and tightly coupled to its concepts. To simplify the code, this PR introduces a new internal _ObjectCache class that generically caches objects (e.g., actors and their placement groups) under a grouping key (e.g., a resource request).

By moving the caching logic into a separate component, we can write focused unit tests that ensure actor reuse works as expected.

Signed-off-by: Kai Fricke <[email protected]>
krfricke authored Mar 7, 2023
1 parent 340f7b2 commit c0f6068
Showing 6 changed files with 332 additions and 67 deletions.
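As orientation before the per-file diffs: the executor tracks expected demand per resource request via increase_max/decrease_max, parks and reuses actors via cache_object/pop_cached_object/has_cached_object, and evicts surplus actors via flush_cached_objects. Below is a minimal usage sketch of that flow; the key "4-cpu" and the string actors are placeholders, and the semantics are inferred from the call sites and unit tests in this commit rather than quoted from it.

```python
from ray.tune.utils.object_cache import _ObjectCache

# Strings stand in for resource requests and (actor, acquired_resources) pairs.
cache = _ObjectCache(may_keep_one=True)

cache.increase_max("4-cpu")  # a trial with this resource request was staged
assert cache.cache_object("4-cpu", "actor-1")  # trial stops; park its actor
assert cache.has_cached_object("4-cpu")
assert cache.pop_cached_object("4-cpu") == "actor-1"  # next trial reuses it

cache.decrease_max("4-cpu")  # trial unstaged; no expected demand remains
# With may_keep_one=True, a single spare object can still be cached eagerly:
assert cache.cache_object("4-cpu", "actor-2")
assert list(cache.flush_cached_objects(force_all=True)) == ["actor-2"]
```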
8 changes: 8 additions & 0 deletions python/ray/tune/BUILD
@@ -296,6 +296,14 @@ py_test(
     tags = ["team:ml", "exclusive"],
 )
 
+py_test(
+    name = "test_util_object_cache",
+    size = "small",
+    srcs = ["tests/test_util_object_cache.py"],
+    deps = [":tune_lib"],
+    tags = ["team:ml", "exclusive"],
+)
+
 py_test(
     name = "test_syncer",
     size = "medium",
84 changes: 29 additions & 55 deletions python/ray/tune/execution/ray_trial_executor.py
@@ -6,15 +6,15 @@
 import random
 import time
 import traceback
-from collections import deque, defaultdict, Counter
+from collections import deque
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial
-from typing import Callable, Dict, Iterable, List, Optional, Set, Union, Tuple
+from typing import Callable, Dict, Iterable, Optional, Set, Union
 
 import ray
 from ray.actor import ActorHandle
-from ray.air import Checkpoint, AcquiredResources, ResourceRequest
+from ray.air import Checkpoint, AcquiredResources
 from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
 from ray.air.constants import (
     COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
@@ -35,6 +35,7 @@
 from ray.tune.result import STDERR_FILE, STDOUT_FILE, TRIAL_INFO
 from ray.tune.experiment.trial import Trial, _Location, _TrialInfo
 from ray.tune.utils import warn_if_slow
+from ray.tune.utils.object_cache import _ObjectCache
 from ray.tune.utils.resource_updater import _ResourceUpdater
 from ray.tune.trainable.util import TrainableUtil
 from ray.util import log_once
@@ -234,13 +235,10 @@ def __init__(
         # Actor re-use.
         # For details, see docstring of `_maybe_cache_trial_actor()`
         self._reuse_actors = reuse_actors
-        self._resource_request_to_cached_actors: Dict[
-            ResourceRequest, List[Tuple[ray.actor.ActorHandle, AcquiredResources]]
-        ] = defaultdict(list)
+        self._actor_cache = _ObjectCache(may_keep_one=True)
 
         # Trials for which we requested resources
         self._staged_trials = set()  # Staged trials
-        self._staged_resources = Counter()  # Resources of staged trials
         self._trial_to_acquired_resources: Dict[Trial, AcquiredResources] = {}
 
         # Result buffer
@@ -261,7 +259,7 @@
     def setup(
         self, max_pending_trials: int, trainable_kwargs: Optional[Dict] = None
     ) -> None:
-        if len(self._resource_request_to_cached_actors) > 0:
+        if self._actor_cache.num_cached_objects:
             logger.warning(
                 "Cannot update maximum number of queued actors for reuse "
                 "during a run."
@@ -327,7 +325,7 @@ def _stage_and_update_status(self, trials: Iterable[Trial]):
             resource_request = trial.placement_group_factory
 
             self._staged_trials.add(trial)
-            self._staged_resources[trial.placement_group_factory] += 1
+            self._actor_cache.increase_max(resource_request)
             self._resource_manager.request_resources(resource_request=resource_request)
 
         self._resource_manager.update_state()
@@ -344,7 +342,7 @@ def get_ready_trial(self) -> Optional[Trial]:
         for trial in self._staged_trials:
             resource_request = trial.placement_group_factory
             # If we have a cached actor for these resources, return
-            if self._resource_request_to_cached_actors[resource_request]:
+            if self._actor_cache.has_cached_object(resource_request):
                 return trial
 
             # If the resources are available from the resource manager, return
@@ -360,12 +358,13 @@ def _maybe_use_cached_actor(self, trial, logger_creator) -> Optional[ActorHandle
             return None
 
         resource_request = trial.placement_group_factory
-        if not self._resource_request_to_cached_actors[resource_request]:
+
+        if not self._actor_cache.has_cached_object(resource_request):
             return None
 
-        actor, acquired_resources = self._resource_request_to_cached_actors[
-            resource_request
-        ].pop(0)
+        actor, acquired_resources = self._actor_cache.pop_cached_object(
+            resource_request
+        )
 
         logger.debug(f"Trial {trial}: Reusing cached actor " f"{actor}")
 
@@ -541,7 +540,7 @@ def _unstage_trial_with_resources(self, trial: Trial):
         # Case 1: The trial we started was staged. Just remove it
         if trial in self._staged_trials:
             self._staged_trials.remove(trial)
-            self._staged_resources[trial.placement_group_factory] -= 1
+            self._actor_cache.decrease_max(trial.placement_group_factory)
             return
 
         # Case 2: We staged a trial "A" with the same resources, but our trial "B"
@@ -560,7 +559,7 @@ def _unstage_trial_with_resources(self, trial: Trial):
 
         if candidate_trial:
             self._staged_trials.remove(candidate_trial)
-            self._staged_resources[candidate_trial.placement_group_factory] -= 1
+            self._actor_cache.decrease_max(candidate_trial.placement_group_factory)
             return
 
         raise RuntimeError(
@@ -593,16 +592,8 @@ def _maybe_cache_trial_actor(self, trial: Trial) -> bool:
         acquired_resources = self._trial_to_acquired_resources[trial]
         cached_resource_request = acquired_resources.resource_request
 
-        staged_resource_count = self._count_staged_resources()
-        if (
-            # If we have at least one cached actor already
-            any(v for v in self._resource_request_to_cached_actors.values())
-            # and we haven't requested resources for an actor with the
-            # same resources as the actor we want to cache
-            and len(self._resource_request_to_cached_actors[cached_resource_request])
-            >= staged_resource_count[cached_resource_request]
-            # then we don't have an immediate need for the actor and don't
-            # want to cache it.
+        if not self._actor_cache.cache_object(
+            cached_resource_request, (trial.runner, acquired_resources)
         ):
             logger.debug(
                 f"Could not cache actor of trial {trial} for "
@@ -613,9 +604,6 @@
 
         logger.debug(f"Caching actor of trial {trial} for re-use")
 
-        self._resource_request_to_cached_actors[cached_resource_request].append(
-            (trial.runner, acquired_resources)
-        )
         self._trial_to_acquired_resources.pop(trial)
 
         trial.set_runner(None)
@@ -833,7 +821,7 @@ def has_resources_for_trial(self, trial: Trial) -> bool:
 
         return (
             trial in self._staged_trials
-            or self._resource_request_to_cached_actors[resource_request]
+            or self._actor_cache.has_cached_object(resource_request)
             or len(self._staged_trials) < self._max_staged_actors
             or self._resource_manager.has_resources_ready(resource_request)
         )
@@ -861,9 +849,6 @@ def on_step_end(self, search_ended: bool = False) -> None:
         self._cleanup_cached_actors(search_ended=search_ended)
         self._do_force_trial_cleanup()
 
-    def _count_staged_resources(self):
-        return self._staged_resources
-
     def _cleanup_cached_actors(
         self, search_ended: bool = False, force_all: bool = False
     ):
@@ -902,21 +887,16 @@ def _cleanup_cached_actors(
             # (if the search ended).
             return
 
-        staged_resources = self._count_staged_resources()
-
-        for resource_request, actors in self._resource_request_to_cached_actors.items():
-            while len(actors) > staged_resources.get(resource_request, 0) or (
-                force_all and len(actors)
-            ):
-                actor, acquired_resources = actors[-1]
-                actors.pop()
-                future = actor.stop.remote()
-                self._futures[future] = (
-                    _ExecutorEventType.STOP_RESULT,
-                    acquired_resources,
-                )
-                if self._trial_cleanup:  # force trial cleanup within a deadline
-                    self._trial_cleanup.add(future)
+        for (actor, acquired_resources) in self._actor_cache.flush_cached_objects(
+            force_all=force_all
+        ):
+            future = actor.stop.remote()
+            self._futures[future] = (
+                _ExecutorEventType.STOP_RESULT,
+                acquired_resources,
+            )
+            if self._trial_cleanup:  # force trial cleanup within a deadline
+                self._trial_cleanup.add(future)
 
     def _resolve_stop_event(
         self,
@@ -1196,18 +1176,12 @@ def get_next_executor_event(
             # when next_trial_exists and there are cached resources
             ###################################################################
             # There could be existing PGs from either
-            # `self._resource_request_to_cached_actors`
+            # `self._actor_cache`
             # or from ready trials. If so and if there is indeed
             # a next trial to run, we return `PG_READY` future for trial
             # runner. The next trial can then be scheduled on this PG.
             if next_trial_exists:
-                if (
-                    sum(
-                        len(cached)
-                        for cached in self._resource_request_to_cached_actors.values()
-                    )
-                    > 0
-                ):
+                if self._actor_cache.num_cached_objects > 0:
                     return _ExecutorEvent(_ExecutorEventType.PG_READY)
                 # TODO(xwjiang): Expose proper API when we decide to do
                 # ActorPool abstraction.
9 changes: 1 addition & 8 deletions python/ray/tune/tests/test_ray_trial_executor.py
@@ -648,14 +648,7 @@ def train(config):
         executor._stage_and_update_status([trial1, trial2, trial3])
         executor.pause_trial(trial1)  # Caches the PG
 
-        assert (
-            len(
-                executor._resource_request_to_cached_actors[
-                    trial1.placement_group_factory
-                ]
-            )
-            == 1
-        )
+        assert executor._actor_cache.num_cached_objects == 1
 
         # Second trial remains staged, it will only be removed from staging when it
         # is started
5 changes: 1 addition & 4 deletions python/ray/tune/tests/test_trial_runner_pg.py
@@ -112,10 +112,7 @@ def on_step_end(self, iteration, trials, **info):
             len(s) for s in resource_manager._request_to_ready_pgs.values()
         )
         num_in_use = len(resource_manager._acquired_pgs)
-        num_cached = sum(
-            len(a)
-            for a in trial_executor._resource_request_to_cached_actors.values()
-        )
+        num_cached = trial_executor._actor_cache.num_cached_objects
 
         total_num_tracked = num_staging + num_ready + num_in_use + num_cached
124 changes: 124 additions & 0 deletions python/ray/tune/tests/test_util_object_cache.py
@@ -0,0 +1,124 @@
import pytest

from ray.tune.utils.object_cache import _ObjectCache


@pytest.mark.parametrize("eager", [False, True])
def test_no_may_keep_one(eager):
    """Test object caching.

    - After init, no objects are cached (the max per key is 0), except with
      eager caching
    - After increasing the max to 2, up to 2 objects are cached
    - Decreasing the max evicts surplus objects on flush
    """
    cache = _ObjectCache(may_keep_one=eager)

    # max(A) = 0, so we only cache when eager caching is enabled
    assert cache.cache_object("A", 1) == eager
    assert cache.num_cached_objects == int(eager)

    # Set max(A) = 2
    cache.increase_max("A", 2)

    # max(A) = 2, so we cache up to two objects
    if not eager:
        assert cache.cache_object("A", 1)

    assert cache.cache_object("A", 2)
    assert not cache.cache_object("A", 3)

    assert cache.num_cached_objects == 2

    # Nothing has to be evicted
    assert not list(cache.flush_cached_objects())

    # Set max(A) = 1, so we have one object too many
    cache.decrease_max("A", 1)

    # First cached object is evicted
    assert list(cache.flush_cached_objects()) == [1]
    assert cache.num_cached_objects == 1

    # Set max(A) = 0
    cache.decrease_max("A", 1)

    # Second cached object is evicted if not eager caching
    assert list(cache.flush_cached_objects()) == ([2] if not eager else [])
    assert cache.num_cached_objects == (0 if not eager else 1)


@pytest.mark.parametrize("eager", [False, True])
def test_multi(eager):
    """Test caching with multiple objects."""
    cache = _ObjectCache(may_keep_one=eager)

    # max(A) = 0, so we only cache when eager caching is enabled
    assert cache.cache_object("A", 1) == eager
    assert cache.num_cached_objects == int(eager)

    # max(B) = 0, so no caching
    assert not cache.cache_object("B", 5)
    assert cache.num_cached_objects == int(eager)

    # Increase maximum levels
    cache.increase_max("A", 1)
    cache.increase_max("B", 1)

    # Cache objects (A is already cached if eager)
    assert cache.cache_object("A", 1) != eager
    assert cache.cache_object("B", 5)

    # No further objects can be cached
    assert not cache.cache_object("A", 2)
    assert not cache.cache_object("B", 6)

    assert cache.num_cached_objects == 2

    # Decrease
    cache.decrease_max("A", 1)

    # Evict the A object
    assert list(cache.flush_cached_objects()) == [1]

    cache.decrease_max("B", 1)

    # If eager, keep the B object; otherwise, evict it
    assert list(cache.flush_cached_objects()) == ([5] if not eager else [])
    assert cache.num_cached_objects == (0 if not eager else 1)


def test_multi_eager_other():
    """On eager caching, only cache an object if no other object is expected.

    - Expect up to one cached A object
    - Try to cache object B --> doesn't get cached
    - Remove the expectation for the A object
    - Try to cache object B --> gets cached
    """
    cache = _ObjectCache(may_keep_one=True)

    cache.increase_max("A", 1)
    assert not cache.cache_object("B", 2)

    cache.decrease_max("A", 1)
    assert cache.cache_object("B", 3)


@pytest.mark.parametrize("eager", [False, True])
def test_force_all(eager):
    """Assert that force_all=True always evicts all objects."""
    cache = _ObjectCache(may_keep_one=eager)

    cache.increase_max("A", 2)

    assert cache.cache_object("A", 1)
    assert cache.cache_object("A", 2)

    assert list(cache.flush_cached_objects(force_all=True)) == [1, 2]
    assert cache.num_cached_objects == 0


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))
169 changes: 169 additions & 0 deletions python/ray/tune/utils/object_cache.py (diff did not render in this view)
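Since the implementation diff is missing from this view, here is a minimal sketch of an _ObjectCache that is consistent with its call sites in ray_trial_executor.py and with the unit tests above. It is a reconstruction under those constraints, not the committed code; internal structure and edge-case handling may differ in the actual file.

```python
from collections import Counter, defaultdict
from typing import Any, Dict, Generator, Hashable, List


class _ObjectCache:
    """Sketch: cache objects per grouping key, up to an expected maximum.

    Reconstructed from this commit's call sites and tests; not the
    committed implementation.
    """

    def __init__(self, may_keep_one: bool = False):
        self._max_num_objects: Counter = Counter()
        self._cached_objects: Dict[Hashable, List[Any]] = defaultdict(list)
        self._num_cached_objects = 0
        self._may_keep_one = may_keep_one

    @property
    def num_cached_objects(self) -> int:
        return self._num_cached_objects

    def increase_max(self, key: Hashable, by: int = 1) -> None:
        # One more object with this key is expected (e.g. a staged trial).
        self._max_num_objects[key] += by

    def decrease_max(self, key: Hashable, by: int = 1) -> None:
        self._max_num_objects[key] -= by

    def has_cached_object(self, key: Hashable) -> bool:
        return bool(self._cached_objects[key])

    def cache_object(self, key: Hashable, obj: Any) -> bool:
        """Cache obj if demand for key remains; return whether it was kept."""
        if len(self._cached_objects[key]) >= self._max_num_objects[key]:
            # No expected demand for this key. Eagerly keep one spare
            # object only if nothing at all is cached or expected.
            keep_one = (
                self._may_keep_one
                and self._num_cached_objects == 0
                and sum(self._max_num_objects.values()) == 0
            )
            if not keep_one:
                return False
        self._cached_objects[key].append(obj)
        self._num_cached_objects += 1
        return True

    def pop_cached_object(self, key: Hashable) -> Any:
        """Return and remove the oldest cached object for key."""
        obj = self._cached_objects[key].pop(0)
        self._num_cached_objects -= 1
        return obj

    def flush_cached_objects(self, force_all: bool = False) -> Generator[Any, None, None]:
        """Yield (and evict) objects exceeding the per-key maximum, FIFO.

        Honors may_keep_one (one global spare survives a flush when no
        objects are expected) unless force_all=True.
        """
        keep_one = (
            self._may_keep_one
            and not force_all
            and sum(self._max_num_objects.values()) == 0
        )
        for key, objs in self._cached_objects.items():
            target = 0 if force_all else max(self._max_num_objects[key], 0)
            while len(objs) > target:
                if keep_one and self._num_cached_objects <= 1:
                    break
                self._num_cached_objects -= 1
                yield objs.pop(0)
```

Note that eviction is driven by consuming the generator, which matches the executor's for-loop in _cleanup_cached_actors and the list(...) calls in the tests.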
