From b970e2d4d6d25899fa0dce6dc88d99e24f642fc2 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Mon, 13 May 2019 15:21:02 -0400 Subject: [PATCH] fix bug in pg_loss --- texar/losses/__init__.py | 1 + texar/losses/losses_utils.py | 4 +- texar/losses/mle_losses_test.py | 2 +- texar/losses/pg_losses.py | 4 +- texar/losses/pg_losses_test.py | 121 ++++++++++++++++++++++++++++++++ 5 files changed, 127 insertions(+), 5 deletions(-) create mode 100644 texar/losses/pg_losses_test.py diff --git a/texar/losses/__init__.py b/texar/losses/__init__.py index 1f4a52335..ae0ceb1fa 100644 --- a/texar/losses/__init__.py +++ b/texar/losses/__init__.py @@ -25,3 +25,4 @@ from texar.losses.entropy import * from texar.losses.mle_losses import * from texar.losses.adv_losses import * +from texar.losses.pg_losses import * diff --git a/texar/losses/losses_utils.py b/texar/losses/losses_utils.py index dea4bc10d..690f13570 100644 --- a/texar/losses/losses_utils.py +++ b/texar/losses/losses_utils.py @@ -108,10 +108,10 @@ def mask_and_reduce(sequence, raise ValueError("Only one of `average_across_remaining` and " "`sum_over_remaining` can be set.") if average_across_remaining: - for axis in range(2, rank): + for axis in sorted(list(range(2, rank)), reverse=True): sequence = torch.mean(sequence, dim=axis) elif sum_over_remaining: - for axis in range(2, rank): + for axis in sorted(list(range(2, rank)), reverse=True): sequence = torch.sum(sequence, dim=axis) sequence = reduce_batch_time(sequence, diff --git a/texar/losses/mle_losses_test.py b/texar/losses/mle_losses_test.py index 736af4085..0c5b71b09 100644 --- a/texar/losses/mle_losses_test.py +++ b/texar/losses/mle_losses_test.py @@ -86,7 +86,7 @@ def test_sequence_sparse_softmax_cross_entropy(self): self._labels, self._logits, self._sequence_length) def test_sequence_sigmoid_cross_entropy(self): - """Tests `texar.losses.test_sequence_sigmoid_cross_entropy`. + """Tests `texar.losses.sequence_sigmoid_cross_entropy`. """ self._test_sequence_loss( tx.losses.sequence_sigmoid_cross_entropy, diff --git a/texar/losses/pg_losses.py b/texar/losses/pg_losses.py index 0e6f1fcc4..7a2928b16 100644 --- a/texar/losses/pg_losses.py +++ b/texar/losses/pg_losses.py @@ -243,10 +243,10 @@ def pg_loss_with_log_probs(log_probs, raise ValueError("Only one of `average_across_remaining` and " "`sum_over_remaining` can be set.") if average_across_remaining: - for average_axis in range(1, rank): + for average_axis in sorted(list(range(1, rank)), reverse=True): losses = torch.mean(losses, dim=average_axis) elif sum_over_remaining: - for sum_axis in range(1, rank): + for sum_axis in sorted(list(range(1, rank)), reverse=True): losses = torch.sum(losses, dim=sum_axis) if not batched: diff --git a/texar/losses/pg_losses_test.py b/texar/losses/pg_losses_test.py new file mode 100644 index 000000000..1a2a468d4 --- /dev/null +++ b/texar/losses/pg_losses_test.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +# +""" +Unit tests for pg losses. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +# pylint: disable=invalid-name + +import unittest +import torch +import texar as tx + +from texar.utils.shapes import get_rank + + +class PGLossesTest(unittest.TestCase): + """Tests pg losses + """ + + def setUp(self): + self._batch_size = 64 + self._max_time = 16 + self._d1 = 32 + self._d2 = 32 + self._d3 = 32 + self._num_classes = 10 + self._actions_batch = torch.ones(self._batch_size, self._max_time, + self._d1, self._d2, self._d3, + dtype=torch.int64) + self._logits_batch = torch.rand(self._batch_size, self._max_time, + self._d1, self._d2, self._d3, + self._num_classes) + self._advantages_batch = torch.rand(self._batch_size, self._max_time, + self._d1, self._d2, self._d3) + self._actions_no_batch = torch.ones(self._max_time, self._d1, self._d2, + self._d3, dtype=torch.int64) + self._logits_no_batch = torch.rand(self._max_time, self._d1, self._d2, + self._d3, self._num_classes) + self._advantages_no_batch = torch.rand(self._max_time, self._d1, + self._d2, self._d3) + self._sequence_length = torch.randint(size=(self._batch_size,), + high=self._max_time) + + def _test_sequence_loss(self, loss_fn, actions, logits, advantages, batched, + sequence_length): + loss = loss_fn(actions, logits, advantages, batched=batched, + sequence_length=sequence_length) + rank = get_rank(loss) + self.assertEqual(rank, 0) + + loss = loss_fn(actions, logits, advantages, batched=batched, + sequence_length=sequence_length, + sum_over_timesteps=False) + rank = get_rank(loss) + self.assertEqual(rank, 1) + self.assertEqual(loss.shape, torch.Size([self._max_time])) + + loss = loss_fn(actions, logits, advantages, batched=batched, + sequence_length=sequence_length, + sum_over_timesteps=False, + average_across_timesteps=True, + average_across_batch=False) + rank = get_rank(loss) + if batched: + self.assertEqual(rank, 1) + self.assertEqual(loss.shape, torch.Size([self._batch_size])) + else: + self.assertEqual(rank, 0) + + loss = loss_fn(actions, logits, advantages, batched=batched, + sequence_length=sequence_length, + sum_over_timesteps=False, + average_across_batch=False) + rank = get_rank(loss) + if batched: + self.assertEqual(rank, 2) + self.assertEqual(loss.shape, + torch.Size([self._batch_size, self._max_time])) + else: + self.assertEqual(rank, 1) + self.assertEqual(loss.shape, + torch.Size([self._max_time])) + + sequence_length_time = torch.randint(size=(self._max_time,), + high=self._batch_size) + loss = loss_fn(actions, logits, advantages, batched=batched, + sequence_length=sequence_length_time, + sum_over_timesteps=False, + average_across_batch=False, + time_major=True) + if batched: + self.assertEqual(loss.shape, torch.Size([self._batch_size, + self._max_time])) + else: + self.assertEqual(loss.shape, torch.Size([self._max_time])) + + def test_pg_loss_with_logits(self): + """Tests `texar.losses.pg_loss_with_logits`. + """ + self._test_sequence_loss(tx.losses.pg_loss_with_logits, + self._actions_batch, + self._logits_batch, + self._advantages_batch, + True, + self._sequence_length) + + self._test_sequence_loss(tx.losses.pg_loss_with_logits, + self._actions_no_batch, + self._logits_no_batch, + self._advantages_no_batch, + False, + self._sequence_length) + + +if __name__ == "__main__": + unittest.main()