Skip to content

Commit

Permalink
fix bug in pg_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
gpengzhi committed May 13, 2019
1 parent c1a3b19 commit b970e2d
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 5 deletions.
1 change: 1 addition & 0 deletions texar/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from texar.losses.entropy import *
from texar.losses.mle_losses import *
from texar.losses.adv_losses import *
from texar.losses.pg_losses import *
4 changes: 2 additions & 2 deletions texar/losses/losses_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ def mask_and_reduce(sequence,
raise ValueError("Only one of `average_across_remaining` and "
"`sum_over_remaining` can be set.")
if average_across_remaining:
for axis in range(2, rank):
for axis in sorted(list(range(2, rank)), reverse=True):
sequence = torch.mean(sequence, dim=axis)
elif sum_over_remaining:
for axis in range(2, rank):
for axis in sorted(list(range(2, rank)), reverse=True):
sequence = torch.sum(sequence, dim=axis)

sequence = reduce_batch_time(sequence,
Expand Down
2 changes: 1 addition & 1 deletion texar/losses/mle_losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_sequence_sparse_softmax_cross_entropy(self):
self._labels, self._logits, self._sequence_length)

def test_sequence_sigmoid_cross_entropy(self):
"""Tests `texar.losses.test_sequence_sigmoid_cross_entropy`.
"""Tests `texar.losses.sequence_sigmoid_cross_entropy`.
"""
self._test_sequence_loss(
tx.losses.sequence_sigmoid_cross_entropy,
Expand Down
4 changes: 2 additions & 2 deletions texar/losses/pg_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,10 @@ def pg_loss_with_log_probs(log_probs,
raise ValueError("Only one of `average_across_remaining` and "
"`sum_over_remaining` can be set.")
if average_across_remaining:
for average_axis in range(1, rank):
for average_axis in sorted(list(range(1, rank)), reverse=True):
losses = torch.mean(losses, dim=average_axis)
elif sum_over_remaining:
for sum_axis in range(1, rank):
for sum_axis in sorted(list(range(1, rank)), reverse=True):
losses = torch.sum(losses, dim=sum_axis)

if not batched:
Expand Down
121 changes: 121 additions & 0 deletions texar/losses/pg_losses_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
#
"""
Unit tests for pg losses.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

# pylint: disable=invalid-name

import unittest
import torch
import texar as tx

from texar.utils.shapes import get_rank


class PGLossesTest(unittest.TestCase):
    """Unit tests for the policy-gradient loss functions in
    :mod:`texar.losses`.
    """

    def setUp(self):
        # Dimensions shared by every fixture tensor.
        self._batch_size = 64
        self._max_time = 16
        self._d1 = 32
        self._d2 = 32
        self._d3 = 32
        self._num_classes = 10

        batched_dims = (self._batch_size, self._max_time,
                        self._d1, self._d2, self._d3)
        unbatched_dims = (self._max_time, self._d1, self._d2, self._d3)

        # Batched fixtures: [batch, time, d1, d2, d3 (, num_classes)].
        self._actions_batch = torch.ones(*batched_dims, dtype=torch.int64)
        self._logits_batch = torch.rand(*batched_dims, self._num_classes)
        self._advantages_batch = torch.rand(*batched_dims)

        # Un-batched fixtures: [time, d1, d2, d3 (, num_classes)].
        self._actions_no_batch = torch.ones(*unbatched_dims,
                                            dtype=torch.int64)
        self._logits_no_batch = torch.rand(*unbatched_dims,
                                           self._num_classes)
        self._advantages_no_batch = torch.rand(*unbatched_dims)

        # One (random) sequence length per batch element, in [0, max_time).
        self._sequence_length = torch.randint(size=(self._batch_size,),
                                              high=self._max_time)

    def _test_sequence_loss(self, loss_fn, actions, logits, advantages,
                            batched, sequence_length):
        """Runs ``loss_fn`` under several reduction settings and checks the
        rank/shape of the returned loss tensor.
        """
        # Default reduction collapses the loss to a scalar.
        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length)
        self.assertEqual(get_rank(loss), 0)

        # Keeping the time dimension yields a [max_time] vector.
        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length,
                       sum_over_timesteps=False)
        self.assertEqual(get_rank(loss), 1)
        self.assertEqual(loss.shape, torch.Size([self._max_time]))

        # Averaging over time (not batch) keeps the batch dimension only
        # when the inputs are batched.
        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length,
                       sum_over_timesteps=False,
                       average_across_timesteps=True,
                       average_across_batch=False)
        if batched:
            self.assertEqual(get_rank(loss), 1)
            self.assertEqual(loss.shape, torch.Size([self._batch_size]))
        else:
            self.assertEqual(get_rank(loss), 0)

        # No reduction over batch or time preserves both dimensions.
        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length,
                       sum_over_timesteps=False,
                       average_across_batch=False)
        if batched:
            self.assertEqual(get_rank(loss), 2)
            self.assertEqual(
                loss.shape,
                torch.Size([self._batch_size, self._max_time]))
        else:
            self.assertEqual(get_rank(loss), 1)
            self.assertEqual(loss.shape, torch.Size([self._max_time]))

        # Time-major layout with per-step lengths; only shapes are checked.
        sequence_length_time = torch.randint(size=(self._max_time,),
                                             high=self._batch_size)
        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length_time,
                       sum_over_timesteps=False,
                       average_across_batch=False,
                       time_major=True)
        if batched:
            expected_shape = torch.Size([self._batch_size, self._max_time])
        else:
            expected_shape = torch.Size([self._max_time])
        self.assertEqual(loss.shape, expected_shape)

    def test_pg_loss_with_logits(self):
        """Tests `texar.losses.pg_loss_with_logits`.
        """
        # Batched inputs.
        self._test_sequence_loss(tx.losses.pg_loss_with_logits,
                                 self._actions_batch,
                                 self._logits_batch,
                                 self._advantages_batch,
                                 True,
                                 self._sequence_length)

        # Un-batched inputs.
        self._test_sequence_loss(tx.losses.pg_loss_with_logits,
                                 self._actions_no_batch,
                                 self._logits_no_batch,
                                 self._advantages_no_batch,
                                 False,
                                 self._sequence_length)


# Run the test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()

0 comments on commit b970e2d

Please sign in to comment.