Commit

update
gpengzhi committed May 3, 2019
1 parent 91106b2 commit e15cf33
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions texar/losses/mle_losses_test.py
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
#
"""
Unit tests for MLE losses.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

# pylint: disable=invalid-name

import unittest
import torch
import torch.nn.functional as F

import texar as tx


class MLELossesTest(unittest.TestCase):
"""Tests mle losses.
"""

def setUp(self):
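        """Sets up random labels, one-hot labels, logits, and lengths."""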
self._batch_size = 64
self._max_time = 16
self._num_classes = 100
        # `F.one_hot` requires int64 (long) class indices.
        self._labels = torch.ones(self._batch_size, self._max_time,
                                  dtype=torch.int64)
        # One-hot labels of shape [batch_size, max_time, num_classes],
        # cast to float so they act as a soft label distribution.
        one_hot_labels = F.one_hot(self._labels, self._num_classes).float()
        self._one_hot_labels = torch.reshape(
            one_hot_labels, [self._batch_size, self._max_time, -1])
        self._logits = torch.rand(self._batch_size, self._max_time,
                                  self._num_classes)
        # Integer sequence lengths, one per batch example, in [0, max_time).
        self._sequence_length = (torch.rand(self._batch_size) *
                                 self._max_time).long()

def _test_sequence_loss(self, loss_fn, labels, logits, sequence_length):
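        """Checks the loss shape under each reduction setting of `loss_fn`."""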
        # Default reduction (sum over timesteps, average across the batch)
        # yields a scalar.
        loss = loss_fn(labels, logits, sequence_length)
        rank = len(loss.shape)
        self.assertEqual(rank, 0)

        # Disabling the timestep reduction keeps the time axis; the loss is
        # still averaged across the batch, giving shape [max_time].
        loss = loss_fn(labels, logits, sequence_length,
                       sum_over_timesteps=False)
        rank = len(loss.shape)
        self.assertEqual(rank, 1)
        self.assertEqual(loss.shape, torch.Size([self._max_time]))

        # Average across timesteps instead of across the batch, giving shape
        # [batch_size].
        loss = loss_fn(
            labels, logits, sequence_length, sum_over_timesteps=False,
            average_across_timesteps=True, average_across_batch=False)
        rank = len(loss.shape)
        self.assertEqual(rank, 1)
        self.assertEqual(loss.shape, torch.Size([self._batch_size]))

        # No reduction over either axis: the loss keeps its full
        # [batch_size, max_time] shape.
        loss = loss_fn(
            labels, logits, sequence_length, sum_over_timesteps=False,
            average_across_batch=False)
        rank = len(loss.shape)
        self.assertEqual(rank, 2)
        self.assertEqual(loss.shape, torch.Size([self._batch_size,
                                                 self._max_time]))

        # With `time_major=True` the inputs are read as [time, batch, ...],
        # so the length tensor must match the batch axis, which is
        # `max_time` under that interpretation.
        sequence_length_time = (torch.rand(self._max_time) *
                                self._max_time).long()
        loss = loss_fn(
            labels, logits, sequence_length_time, sum_over_timesteps=False,
            average_across_batch=False, time_major=True)
        self.assertEqual(loss.shape, torch.Size([self._batch_size,
                                                 self._max_time]))

def test_sequence_softmax_cross_entropy(self):
"""Tests `sequence_softmax_cross_entropy`
"""
self._test_sequence_loss(
tx.losses.sequence_softmax_cross_entropy,
self._one_hot_labels, self._logits, self._sequence_length)


if __name__ == "__main__":
unittest.main()
