
Commit 8ee240f
[rllib] Use 64-byte aligned memory when concatenating arrays (ray-pro…
ericl authored Mar 26, 2019
1 parent c68eea6 commit 8ee240f
Showing 4 changed files with 59 additions and 3 deletions.
ci/jenkins_tests/run_rllib_tests.sh (2 changes: 1 addition & 1 deletion)

@@ -405,7 +405,7 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     --run IMPALA \
     --stop='{"timesteps_total": 40000}' \
     --ray-object-store-memory=500000000 \
-    --config '{"num_workers": 1, "num_gpus": 0, "num_envs_per_worker": 64, "sample_batch_size": 50, "train_batch_size": 50, "learner_queue_size": 1}'
+    --config '{"num_workers": 1, "num_gpus": 0, "num_envs_per_worker": 32, "sample_batch_size": 50, "train_batch_size": 50, "learner_queue_size": 1}'

 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     /ray/ci/suppress_output python /ray/python/ray/rllib/agents/impala/vtrace_test.py
python/ray/rllib/evaluation/metrics.py (4 changes: 4 additions & 0 deletions)

@@ -38,6 +38,10 @@ def collect_episodes(local_evaluator,
     collected, _ = ray.wait(
         pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
     num_metric_batches_dropped = len(pending) - len(collected)
+    if pending and len(collected) == 0:
+        raise ValueError(
+            "Timed out waiting for metrics from workers. You can configure "
+            "this timeout with `collect_metrics_timeout`.")

     metric_lists = ray.get(collected)
     metric_lists.append(local_evaluator.get_metrics())
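The new check turns a silent loss of all worker metrics into an explicit error. As a minimal sketch of where that timeout is configured (assuming the agent API of this release, in which `collect_metrics_timeout` defaulted to 180 seconds; the trainer choice and values here are illustrative):

import ray
from ray.rllib.agents.ppo import PPOAgent

ray.init()

# collect_metrics_timeout (seconds) bounds how long collect_episodes()
# waits on remote evaluators. With this commit, timing out on *every*
# worker raises ValueError instead of silently returning no metrics.
agent = PPOAgent(
    env="CartPole-v0",
    config={
        "num_workers": 2,
        "collect_metrics_timeout": 180,  # raise this for slow environments
    })
result = agent.train()  # metrics are collected as part of each train() call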
python/ray/rllib/evaluation/sample_batch.py (5 changes: 3 additions & 2 deletions)

@@ -7,6 +7,7 @@
 import numpy as np

 from ray.rllib.utils.annotations import PublicAPI
+from ray.rllib.utils.memory import concat_aligned

 # Default policy ID for single-agent environments
 DEFAULT_POLICY_ID = "default"
@@ -104,7 +105,7 @@ def concat_samples(samples):
         out = {}
         samples = [s for s in samples if s.count > 0]
         for k in samples[0].keys():
-            out[k] = np.concatenate([s[k] for s in samples])
+            out[k] = concat_aligned([s[k] for s in samples])
         return SampleBatch(out)

     @PublicAPI
@@ -121,7 +122,7 @@ def concat(self, other):
         assert self.keys() == other.keys(), "must have same columns"
         out = {}
         for k in self.keys():
-            out[k] = np.concatenate([self[k], other[k]])
+            out[k] = concat_aligned([self[k], other[k]])
         return SampleBatch(out)

     @PublicAPI
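With both call sites switched over, any batch produced by concatenation comes back 64-byte aligned. A quick usage sketch of the two affected paths (assuming the SampleBatch dict constructor of this era; shapes are illustrative):

import numpy as np
from ray.rllib.evaluation.sample_batch import SampleBatch

a = SampleBatch({"obs": np.ones((10, 4), dtype=np.float32)})
b = SampleBatch({"obs": np.zeros((5, 4), dtype=np.float32)})

merged = SampleBatch.concat_samples([a, b])  # static path
merged2 = a.concat(b)                        # instance path

assert merged["obs"].shape == (15, 4)
# The aligned buffer lets TensorFlow copy the batch to GPU memory
# without an intermediate realignment copy.
assert merged["obs"].ctypes.data % 64 == 0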
python/ray/rllib/utils/memory.py (51 changes: 51 additions & 0 deletions)

@@ -0,0 +1,51 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


def aligned_array(size, dtype, align=64):
    """Returns an array of a given size that is 64-byte aligned.

    The returned array can be efficiently copied into GPU memory by
    TensorFlow.
    """

    # Over-allocate by (align - 1) bytes, then slice off whatever offset is
    # needed to land the data pointer on an `align`-byte boundary.
    n = size * dtype.itemsize
    empty = np.empty(n + (align - 1), dtype=np.uint8)
    data_align = empty.ctypes.data % align
    offset = 0 if data_align == 0 else (align - data_align)
    output = empty[offset:offset + n].view(dtype)

    assert len(output) == size, len(output)
    assert output.ctypes.data % align == 0, output.ctypes.data
    return output


def concat_aligned(items):
    """Concatenate arrays, ensuring the output is 64-byte aligned.

    Only float32, float64, and uint8 arrays are aligned; other arrays are
    concatenated as normal.

    This should be used instead of np.concatenate() to improve performance
    when the output array is likely to be fed into TensorFlow.
    """

    if len(items) == 0:
        return []
    elif len(items) == 1:
        # We assume the input is aligned. In any case, it doesn't help
        # performance to force-align it, since that incurs a needless copy.
        return items[0]
    elif (isinstance(items[0], np.ndarray)
          and items[0].dtype in [np.float32, np.float64, np.uint8]):
        # Allocate an aligned flat buffer of the right total size, view it
        # with the concatenated shape, and let np.concatenate() fill it
        # in place via the `out` argument.
        dtype = items[0].dtype
        flat = aligned_array(sum(s.size for s in items), dtype)
        batch_dim = sum(s.shape[0] for s in items)
        new_shape = (batch_dim, ) + items[0].shape[1:]
        output = flat.reshape(new_shape)
        assert output.ctypes.data % 64 == 0, output.ctypes.data
        np.concatenate(items, out=output)
        return output
    else:
        return np.concatenate(items)
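A standalone usage sketch of the two new helpers (pure NumPy; the shapes and dtypes are illustrative):

import numpy as np
from ray.rllib.utils.memory import aligned_array, concat_aligned

# A flat buffer of 1000 float32s whose data pointer sits on a 64-byte boundary.
buf = aligned_array(1000, np.dtype(np.float32))
assert buf.ctypes.data % 64 == 0

# float32 inputs take the aligned path.
parts = [np.random.randn(8, 84, 84).astype(np.float32) for _ in range(4)]
out = concat_aligned(parts)
assert out.shape == (32, 84, 84)
assert out.ctypes.data % 64 == 0

# Other dtypes, e.g. int64, fall back to plain np.concatenate().
ints = concat_aligned([np.arange(3), np.arange(3)])
assert ints.shape == (6,)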
