Commit

Add MSE loss to dqn_agent, and add loss tests.
PiperOrigin-RevId: 370113976
psc-g authored and joshgreaves committed Apr 30, 2021
1 parent fb95230 commit ad4732e
Showing 2 changed files with 61 additions and 7 deletions.
22 changes: 17 additions & 5 deletions dopamine/jax/agents/dqn/dqn_agent.py
@@ -79,16 +79,21 @@ def create_optimizer(name='adam', learning_rate=6.25e-5, beta1=0.9, beta2=0.999,
    raise ValueError('Unsupported optimizer {}'.format(name))


# TODO(psc): Refactor into a separate losses module.
def huber_loss(targets, predictions, delta=1.0):
  x = jnp.abs(targets - predictions)
  return jnp.where(x <= delta,
                   0.5 * x**2,
                   0.5 * delta**2 + delta * (x - delta))


@functools.partial(jax.jit, static_argnums=(0, 8))
def mse_loss(targets, predictions):
  return jnp.power((targets - predictions), 2)


@functools.partial(jax.jit, static_argnums=(0, 8, 9))
def train(network_def, target_params, optimizer, states, actions, next_states,
          rewards, terminals, cumulative_gamma):
          rewards, terminals, cumulative_gamma, loss_type='huber'):
"""Run the training step."""
online_params = optimizer.target
def loss_fn(params, target):
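For reference, huber_loss and mse_loss above are plain elementwise functions of (targets, predictions). A minimal sketch of calling them directly (illustrative, not part of the diff), with expected values taken from the unit tests added below:

# Illustrative only; huber_loss and mse_loss are the module-level functions
# defined in the hunk above.
from dopamine.jax.agents.dqn import dqn_agent

print(dqn_agent.huber_loss(1.0, 0.0))             # 0.5   (|error| <= delta: 0.5 * error**2)
print(dqn_agent.huber_loss(1.0, 0.0, delta=0.5))  # 0.375 (|error| > delta: 0.5 * delta**2 + delta * (|error| - delta))
print(dqn_agent.mse_loss(2.0, 0.0))               # 4.0   (error**2)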
@@ -98,7 +103,9 @@ def q_online(state):
    q_values = jax.vmap(q_online)(states).q_values
    q_values = jnp.squeeze(q_values)
    replay_chosen_q = jax.vmap(lambda x, y: x[y])(q_values, actions)
    loss = jnp.mean(jax.vmap(huber_loss)(target, replay_chosen_q))
    loss = jnp.where(loss_type == 'huber',
                     jnp.mean(jax.vmap(huber_loss)(target, replay_chosen_q)),
                     jnp.mean(jax.vmap(mse_loss)(target, replay_chosen_q)))
    return loss

  def q_target(state):
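Because loss_type is a static argument of the jitted train function (static_argnums now includes position 9), the loss_type == 'huber' comparison inside jnp.where resolves against a concrete Python string at trace time. A minimal sketch of the same selection pattern in isolation (mean_loss and its inputs are illustrative, not part of the diff; the array values come from the tests below):

import functools

import jax
import jax.numpy as jnp
from dopamine.jax.agents.dqn import dqn_agent


# Illustrative helper, not library code: selects between the two mean losses
# the same way loss_fn does above, with loss_type kept static under jit.
@functools.partial(jax.jit, static_argnums=(2,))
def mean_loss(targets, predictions, loss_type='huber'):
  return jnp.where(loss_type == 'huber',
                   jnp.mean(jax.vmap(dqn_agent.huber_loss)(targets, predictions)),
                   jnp.mean(jax.vmap(dqn_agent.mse_loss)(targets, predictions)))


targets = jnp.ones(5)
predictions = jnp.array([0., 1., 2., 3., 4.])
print(mean_loss(targets, predictions, 'huber'))  # mean([0.5, 0., 0.5, 1.5, 2.5]) = 1.0
print(mean_loss(targets, predictions, 'mse'))    # mean([1., 0., 1., 4., 9.]) = 3.0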
@@ -225,7 +232,8 @@ def __init__(self,
               summary_writer=None,
               summary_writing_frequency=500,
               allow_partial_reload=False,
               seed=None):
               seed=None,
               loss_type='huber'):
    """Initializes the agent and constructs the necessary components.

    Note: We are using the Adam optimizer by default for JaxDQN, which differs
@@ -264,6 +272,7 @@ def __init__(self,
        (for instance, only the network parameters).
      seed: int, a seed for DQN's internal RNG, used for initialization and
        sampling actions. If None, will use the current time in nanoseconds.
      loss_type: str, whether to use Huber or MSE loss during training.
    """
    assert isinstance(observation_shape, tuple)
    seed = int(time.time() * 1e6) if seed is None else seed
@@ -281,6 +290,7 @@ def __init__(self,
    logging.info('\t max_tf_checkpoints_to_keep: %d',
                 max_tf_checkpoints_to_keep)
    logging.info('\t seed: %d', seed)
    logging.info('\t loss_type: %s', loss_type)

    self.num_actions = num_actions
    self.observation_shape = tuple(observation_shape)
@@ -302,6 +312,7 @@ def __init__(self,
    self.summary_writer = summary_writer
    self.summary_writing_frequency = summary_writing_frequency
    self.allow_partial_reload = allow_partial_reload
    self._loss_type = loss_type

    self._rng = jax.random.PRNGKey(seed)
    state_shape = self.observation_shape + (stack_size,)
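With self._loss_type stored on the agent, switching to the MSE loss is a single constructor argument; 'huber' remains the default. A hypothetical usage sketch, not part of the diff (num_actions=4 is an arbitrary example value, all other arguments keep their defaults); gin-driven setups could equivalently bind JaxDQNAgent.loss_type = 'mse', assuming the agent keeps its usual gin registration:

# Hypothetical construction of an agent that trains with the MSE loss.
from dopamine.jax.agents.dqn import dqn_agent

mse_agent = dqn_agent.JaxDQNAgent(num_actions=4, loss_type='mse')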
@@ -476,7 +487,8 @@ def _train_step(self):
            self.replay_elements['next_state'],
            self.replay_elements['reward'],
            self.replay_elements['terminal'],
            self.cumulative_gamma)
            self.cumulative_gamma,
            self._loss_type)
        if (self.summary_writer is not None and
            self.training_steps > 0 and
            self.training_steps % self.summary_writing_frequency == 0):
46 changes: 44 additions & 2 deletions tests/dopamine/jax/agents/dqn/dqn_agent_test.py
@@ -19,11 +19,11 @@

import os
import shutil


from typing import Optional, Union

from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
from dopamine.discrete_domains import atari_lib
from dopamine.jax.agents.dqn import dqn_agent
from dopamine.utils import test_utils
@@ -36,6 +36,48 @@
FLAGS = flags.FLAGS


class LossesTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(testcase_name='BelowDelta1d',
           target=1.0, prediction=0.0, delta=1.0,
           expected=onp.array(0.5)),
      dict(testcase_name='AboveDelta1d',
           target=1.0, prediction=0.0, delta=0.5,
           expected=onp.array(0.375)),
      dict(testcase_name='MixedArraysDefaultDelta',
           target=onp.ones(5), prediction=onp.array([0., 1., 2., 3., 4.]),
           delta=None,
           expected=onp.array([0.5, 0., 0.5, 1.5, 2.5])),
      dict(testcase_name='MixedArraysSetDelta',
           target=onp.ones(5), prediction=onp.array([0., 1., 2., 3., 4.]),
           delta=2.0,
           expected=onp.array([0.5, 0., 0.5, 2.0, 4.0])))
  def testHuberLoss(self,
                    target: Union[float, onp.array],
                    prediction: Union[float, onp.array],
                    delta: Optional[float],
                    expected: Union[float, onp.array]):
    if delta is None:
      actual = dqn_agent.huber_loss(target, prediction)
    else:
      actual = dqn_agent.huber_loss(target, prediction, delta=delta)
    onp.testing.assert_equal(actual, expected)

  @parameterized.named_parameters(
      dict(testcase_name='1DParameters',
           target=2.0, prediction=0.0, expected=onp.array(4.0)),
      dict(testcase_name='ArrayParameters',
           target=onp.ones(5), prediction=onp.array([0., 1., 2., 3., 4.]),
           expected=onp.array([1.0, 0.0, 1.0, 4.0, 9.0])))
  def testMSELoss(self,
                  target: Union[float, onp.array],
                  prediction: Union[float, onp.array],
                  expected: Union[float, onp.array]):
    actual = dqn_agent.mse_loss(target, prediction)
    onp.testing.assert_equal(actual, expected)
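As a hand check of the expected values: in AboveDelta1d the absolute error is 1.0 > delta = 0.5, so the Huber loss is 0.5 * 0.5**2 + 0.5 * (1.0 - 0.5) = 0.125 + 0.25 = 0.375, and in MixedArraysSetDelta the last element has error 3.0 > delta = 2.0, giving 0.5 * 2.0**2 + 2.0 * (3.0 - 2.0) = 4.0, matching the expected arrays above.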


class DQNAgentTest(absltest.TestCase):

  def setUp(self):
