Rename async_queue_depth -> num_async (ray-project#8207)
* rename

* lint
ericl authored May 5, 2020
1 parent f48da50 commit ee0eb44
Showing 5 changed files with 15 additions and 17 deletions.
2 changes: 1 addition & 1 deletion python/ray/tests/test_iter.py
@@ -292,7 +292,7 @@ def test_gather_async(ray_start_regular_shared):

def test_gather_async_queue(ray_start_regular_shared):
it = from_range(100)
-it = it.gather_async(async_queue_depth=4)
+it = it.gather_async(num_async=4)
assert sorted(it) == list(range(100))


8 changes: 4 additions & 4 deletions python/ray/util/iter.py
@@ -415,14 +415,14 @@ def base_iterator(timeout=None):
name = "{}.batch_across_shards()".format(self)
return LocalIterator(base_iterator, SharedMetrics(), name=name)

-def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]":
+def gather_async(self, num_async=1) -> "LocalIterator[T]":
"""Returns a local iterable for asynchronous iteration.
New items will be fetched from the shards asynchronously as soon as
the previous one is computed. Items arrive in non-deterministic order.
Arguments:
-async_queue_depth (int): The max number of async requests in flight
+num_async (int): The max number of async requests in flight
per actor. Increasing this improves the amount of pipeline
parallelism in the iterator.
@@ -436,7 +436,7 @@ def gather_async(self, num_async=1) -> "LocalIterator[T]":
... 1
"""

-if async_queue_depth < 1:
+if num_async < 1:
raise ValueError("queue depth must be positive")

# Forward reference to the returned iterator.
@@ -448,7 +448,7 @@ def base_iterator(timeout=None):
actor_set.init_actors()
all_actors.extend(actor_set.actors)
futures = {}
-for _ in range(async_queue_depth):
+for _ in range(num_async):
for a in all_actors:
futures[a.par_iter_next.remote()] = a
while futures:
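
For reference, a minimal usage sketch of the renamed argument, mirroring test_gather_async_queue above; it assumes a local Ray instance started with ray.init().

import ray
from ray.util.iter import from_range

ray.init()

# Keep up to 4 par_iter_next() requests in flight per shard actor.
it = from_range(100).gather_async(num_async=4)

# Items arrive in non-deterministic order, but none are lost or duplicated.
assert sorted(it) == list(range(100))
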
4 changes: 2 additions & 2 deletions rllib/agents/dqn/apex.py
@@ -143,7 +143,7 @@ def __call__(self, item):
# We execute the following steps concurrently:
# (1) Generate rollouts and store them in our replay buffer actors. Update
# the weights of the worker that generated the batch.
-rollouts = ParallelRollouts(workers, mode="async", async_queue_depth=2)
+rollouts = ParallelRollouts(workers, mode="async", num_async=2)
store_op = rollouts \
.for_each(StoreToReplayBuffer(actors=replay_actors)) \
.zip_with_source_actor() \
@@ -154,7 +154,7 @@ def __call__(self, item):

# (2) Read experiences from the replay buffer actors and send to the
# learner thread via its in-queue.
-replay_op = Replay(actors=replay_actors, async_queue_depth=4) \
+replay_op = Replay(actors=replay_actors, num_async=4) \
.zip_with_source_actor() \
.for_each(Enqueue(learner_thread.inqueue))

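
A toy sketch of the concurrency pattern described above, not the actual Ape-X operators: two independent streams are gathered asynchronously with the same queue depths as the plan above (2 for rollouts, 4 for replay) and consumed as one interleaved local iterator. from_items and LocalIterator.union are used here purely for illustration; the real store_op and replay_op are wired together through RLlib's concurrency operators.

import ray
from ray.util.iter import from_items

ray.init()

# Stand-ins for the rollout and replay streams; each from_items() call
# creates its own set of shard actors.
rollouts = from_items([("rollout", i) for i in range(6)], num_shards=2) \
    .gather_async(num_async=2)
replays = from_items([("replay", i) for i in range(6)], num_shards=2) \
    .gather_async(num_async=4)

# Interleave both async streams into a single local iterator.
for item in rollouts.union(replays):
    print(item)
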
8 changes: 4 additions & 4 deletions rllib/execution/replay_ops.py
@@ -52,7 +52,7 @@ def __call__(self, batch: SampleBatchType):
def Replay(*,
local_buffer: LocalReplayBuffer = None,
actors: List["ActorHandle"] = None,
-async_queue_depth=4):
+num_async=4):
"""Replay experiences from the given buffer or actors.
This should be combined with the StoreToReplayActors operation using the
@@ -63,7 +63,7 @@ def Replay(*,
and replay_actors can be specified.
actors (list): List of replay actors. Only one of this and
local_buffer can be specified.
-async_queue_depth (int): In async mode, the max number of async
+num_async (int): In async mode, the max number of async
requests in flight per actor.
Examples:
@@ -79,8 +79,8 @@ def Replay(*,

if actors:
replay = from_actors(actors)
-return replay.gather_async(async_queue_depth=async_queue_depth).filter(
-lambda x: x is not None)
+return replay.gather_async(
+num_async=num_async).filter(lambda x: x is not None)

def gen_replay(_):
while True:
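
Because Replay() needs replay actors, here is a hypothetical, self-contained toy showing the pattern it builds on: actors subclass ParallelIteratorWorker, from_actors() turns them into a parallel iterator, and gather_async(num_async=4) keeps four requests in flight per actor while filter() drops empty results. ToyReplayActor and its contents are invented for illustration only.

import random
import ray
from ray.util.iter import ParallelIteratorWorker, from_actors

ray.init()

@ray.remote
class ToyReplayActor(ParallelIteratorWorker):
    def __init__(self):
        # Pretend some sampling attempts return nothing, as an empty buffer would.
        items = [i if random.random() > 0.3 else None for i in range(20)]
        ParallelIteratorWorker.__init__(self, items, repeat=False)

actors = [ToyReplayActor.remote() for _ in range(2)]
replay = from_actors(actors) \
    .gather_async(num_async=4) \
    .filter(lambda x: x is not None)
print(list(replay))
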
10 changes: 4 additions & 6 deletions rllib/execution/rollout_ops.py
@@ -17,10 +17,8 @@
logger = logging.getLogger(__name__)


-def ParallelRollouts(workers: WorkerSet,
-*,
-mode="bulk_sync",
-async_queue_depth=1) -> LocalIterator[SampleBatch]:
+def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
+num_async=1) -> LocalIterator[SampleBatch]:
"""Operator to collect experiences in parallel from rollout workers.
If there are no remote workers, experiences will be collected serially from
@@ -36,7 +34,7 @@ def ParallelRollouts(workers: WorkerSet,
- In 'raw' mode, the ParallelIterator object is returned directly
and the caller is responsible for implementing gather and
updating the timesteps counter.
-async_queue_depth (int): In async mode, the max number of async
+num_async (int): In async mode, the max number of async
requests in flight per actor.
Returns:
@@ -83,7 +81,7 @@ def sampler(_):
.for_each(report_timesteps)
elif mode == "async":
return rollouts.gather_async(
-async_queue_depth=async_queue_depth).for_each(report_timesteps)
+num_async=num_async).for_each(report_timesteps)
elif mode == "raw":
return rollouts
else:
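
A rough, self-contained analogue of the two main modes, using plain ray.util.iter instead of a real WorkerSet; the integers stand in for sample batches. "bulk_sync" corresponds to waiting for one item from every shard per step, while "async" yields items in arrival order with num_async requests in flight per actor.

import ray
from ray.util.iter import from_range

ray.init()

# "bulk_sync"-style: one item from each of the 2 shards per step.
sync_batches = from_range(8, num_shards=2).batch_across_shards()
print(list(sync_batches))   # e.g. [[0, 4], [1, 5], [2, 6], [3, 7]]

# "async"-style: items yielded as soon as any shard produces one.
async_items = from_range(8, num_shards=2).gather_async(num_async=2)
print(sorted(async_items))  # [0, 1, 2, 3, 4, 5, 6, 7]
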
