Skip to content

Commit

Permalink
Update docker base image, fix OOM issue.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 398320069
  • Loading branch information
Joshua Greaves authored and joshgreaves committed Sep 22, 2021
1 parent 8d6d476 commit 9687002
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ environments you intend to use before you install Dopamine:

1. Install the atari roms following the instructions from
[atari-py](https://github.com/openai/atari-py#roms).
2. `pip install atari-py` (we recommend using a [virtual environment](virtualenv)):
2. `pip install ale-py` (we recommend using a [virtual environment](virtualenv)):
3. `unzip $ROM_DIR/ROMS.zip -d $ROM_DIR && ale-import-roms $ROM_DIR/ROMS`
(replace $ROM_DIR with the directory you extracted the ROMs to).

Expand Down
6 changes: 4 additions & 2 deletions docker/core/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# If you want to use a different version of CUDA, view the available
# images here: https://hub.docker.com/r/nvidia/cuda
# Note: Jax currently supports cuda versions up to 11.3.
ARG cuda_docker_tag="11.1.1-cudnn8-devel-ubuntu20.04"
# Note:
# - Jax currently supports CUDA versions up to 11.3.
# - Tensorflow required CUDA versions after 11.2.
ARG cuda_docker_tag="11.2.2-cudnn8-devel-ubuntu20.04"
FROM nvidia/cuda:${cuda_docker_tag}

COPY . /root/dopamine/
Expand Down
6 changes: 6 additions & 0 deletions dopamine/jax/agents/sac/sac_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@
import tensorflow as tf


logging.warning(
('Setting tf to CPU only, to avoid OOM. '
'See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html '
'for more information.'))
tf.config.set_visible_devices([], 'GPU')


gin.constant('sac_agent.IMAGE_DTYPE', onp.uint8)
gin.constant('sac_agent.STATE_DTYPE', onp.float32)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
'pygame >= 1.9.2',
'pandas >= 0.24.2',
'tf_slim >= 1.0',
'tensorflow-probability >= 0.13.0',
]

dopamine_description = (
Expand Down

0 comments on commit 9687002

Please sign in to comment.