Skip to content

Commit

Permalink
Fix input-processing when preprocess_fn is explicitly passed.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 408423178
  • Loading branch information
agarwl authored and psc-g committed Nov 9, 2021
1 parent ff80e75 commit 95308e7
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def step(self, reward, observation):

self._rng, self.action = select_action(self.network_def,
self.online_params,
- self.state,
+ self.preprocess_fn(self.state),
self._rng,
self.num_quantile_samples,
self.num_actions,
Expand Down Expand Up @@ -419,9 +419,9 @@ def _train_step(self):
self.target_network_params,
self.optimizer,
self.optimizer_state,
- self.replay_elements['state'],
+ self.preprocess_fn(self.replay_elements['state']),
self.replay_elements['action'],
- self.replay_elements['next_state'],
+ self.preprocess_fn(self.replay_elements['next_state']),
self.replay_elements['reward'],
self.replay_elements['terminal'],
self.num_tau_samples,
Expand Down
4 changes: 2 additions & 2 deletions dopamine/jax/agents/quantile/quantile_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,9 @@ def _train_step(self):
self.target_network_params,
self.optimizer,
self.optimizer_state,
- self.replay_elements['state'],
+ self.preprocess_fn(self.replay_elements['state']),
self.replay_elements['action'],
- self.replay_elements['next_state'],
+ self.preprocess_fn(self.replay_elements['next_state']),
self.replay_elements['reward'],
self.replay_elements['terminal'],
self._kappa,
Expand Down
6 changes: 3 additions & 3 deletions dopamine/jax/agents/rainbow/rainbow_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def step(self, reward, observation):

self._rng, self.action = select_action(self.network_def,
self.online_params,
- self.state,
+ self.preprocess_fn(self.state),
self._rng,
self.num_actions,
self.eval_mode,
Expand Down Expand Up @@ -398,9 +398,9 @@ def _train_step(self):
self.target_network_params,
self.optimizer,
self.optimizer_state,
- self.replay_elements['state'],
+ self.preprocess_fn(self.replay_elements['state']),
self.replay_elements['action'],
- self.replay_elements['next_state'],
+ self.preprocess_fn(self.replay_elements['next_state']),
self.replay_elements['reward'],
self.replay_elements['terminal'],
loss_weights,
Expand Down

0 comments on commit 95308e7

Please sign in to comment.