Fix GPU tests of world model
Summary: Fix https://ci.pytorch.org/jenkins/job/horizon-builds/job/horizon-xenial-cuda9-cudnn7-py3-build-test/417//console

Reviewed By: MisterTea

Differential Revision: D15163755

fbshipit-source-id: f619fee01166ba99bea3dcc05e740da0cf87efaa
czxttkl authored and facebook-github-bot committed May 1, 2019
1 parent ade25da commit c9b4c94
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions ml/rl/models/mdn_rnn.py
@@ -121,20 +121,22 @@ def deque_sample(self, indices):
             s = self.replay_memory[i]
             yield s.state, s.action, s.next_state, s.reward, s.not_terminal

-    def sample_memories(self, batch_size, batch_first=False):
+    def sample_memories(self, batch_size, use_gpu=False, batch_first=False):
         """
         :param batch_size: number of samples to return
+        :param use_gpu: whether to put samples on gpu
         :param batch_first: If True, the first dimension of data is batch_size.
             If False (default), the first dimension is SEQ_LEN. Therefore,
             state's shape is SEQ_LEN x BATCH_SIZE x STATE_DIM, for example. By default,
             MDN-RNN consumes data with SEQ_LEN as the first dimension.
         """
         sample_indices = np.random.randint(self.memory_size, size=batch_size)
+        device = torch.device("cuda") if use_gpu else torch.device("cpu")
         # state/next state shape: batch_size x seq_len x state_dim
         # action shape: # state shape: batch_size x seq_len x action_dim
         # reward/not_terminal shape: batch_size x seq_len
         state, action, next_state, reward, not_terminal = map(
-            lambda x: torch.tensor(x, dtype=torch.float),
+            lambda x: torch.tensor(x, dtype=torch.float, device=device),
             zip(*self.deque_sample(sample_indices)),
         )

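The substance of the fix is that sample_memories now takes a use_gpu flag, resolves it to a torch.device once, and builds every sampled tensor directly on that device. A minimal, self-contained sketch of the same pattern, with illustrative names rather than the commit's code:

import numpy as np
import torch


def to_device_tensors(arrays, use_gpu=False):
    # Mirror the device-selection pattern added to sample_memories: pick the
    # device once, then construct every float tensor directly on it.
    device = torch.device("cuda") if use_gpu else torch.device("cpu")
    return [torch.tensor(a, dtype=torch.float, device=device) for a in arrays]


# Request GPU placement only when CUDA is actually available.
batch = to_device_tensors(
    [np.zeros((4, 3)), np.ones((4, 1))], use_gpu=torch.cuda.is_available()
)
print(batch[0].device)  # "cuda:0" when a GPU is present, otherwise "cpu"

Creating the tensors on the target device up front keeps the sampled batch on the same device as the model it feeds, instead of producing CPU tensors that have to be moved later.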
2 changes: 1 addition & 1 deletion ml/rl/test/world_model/test_mdnrnn.py
@@ -151,7 +151,7 @@ def _test_mdnrnn_simulate_world(self, use_gpu=False):
         for e in range(num_epochs):
             for i in range(num_batch):
                 training_batch = replay_buffer.sample_memories(
-                    batch_size, batch_first=use_gpu
+                    batch_size, use_gpu=use_gpu, batch_first=use_gpu
                 )
                 losses = trainer.train(training_batch, batch_first=use_gpu)
                 logger.info(
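Since the batch is now sampled on the requested device, a GPU test of this kind could also assert the placement explicitly. A small illustrative check, not part of the commit (the helper name is made up):

import torch


def assert_expected_device(tensor, use_gpu):
    # Hypothetical helper: check that a sampled tensor lives on the device
    # implied by the use_gpu flag.
    expected = "cuda" if use_gpu else "cpu"
    assert tensor.device.type == expected, (
        "expected a %s tensor, got %s" % (expected, tensor.device)
    )


# Example with a CPU tensor:
assert_expected_device(torch.zeros(2, 3), use_gpu=False)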
