forked from asyml/texar-pytorch
Commit
Merge pull request asyml#10 from ZhitingHu/losses
Losses
Showing 11 changed files with 1,591 additions and 0 deletions.
texar/losses/__init__.py
@@ -0,0 +1,24 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Modules of texar losses.
"""

# pylint: disable=wildcard-import

from texar.losses.losses_utils import *
from texar.losses.entropy import *
from texar.losses.mle_losses import *
from texar.losses.adv_losses import *
from texar.losses.pg_losses import *
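The wildcard imports above re-export each submodule's public names at the package level, so callers can reach every loss function directly under texar.losses. A minimal sketch of what that buys, assuming texar-pytorch is installed and the top-level texar package imports its losses subpackage:

import texar as tx

# Both spellings resolve to the same function, because __init__.py
# re-exports the submodule contents at the package level.
assert tx.losses.binary_adversarial_losses is \
    tx.losses.adv_losses.binary_adversarial_losses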
texar/losses/adv_losses.py
@@ -0,0 +1,84 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adversarial losses.
"""

from typing import Callable, Tuple

import torch
import torch.nn.functional as F

from texar.utils.types import MaybeTuple


def binary_adversarial_losses(
        real_data: torch.Tensor,
        fake_data: torch.Tensor,
        discriminator_fn: Callable[[torch.Tensor], MaybeTuple[torch.Tensor]],
        mode: str = "max_real") -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes adversarial losses of a real/fake binary discrimination game.

    .. role:: python(code)
        :language: python

    Args:
        real_data (Tensor or array): Real data of shape
            `[num_real_examples, ...]`.
        fake_data (Tensor or array): Fake data of shape
            `[num_fake_examples, ...]`. `num_real_examples` does not
            necessarily equal `num_fake_examples`.
        discriminator_fn: A callable that takes data (e.g., :attr:`real_data`
            and :attr:`fake_data`) and returns the logits of being real. The
            signature of `discriminator_fn` must be:
            :python:`logits, ... = discriminator_fn(data)`.
            The return value of `discriminator_fn` can be the logits, or
            a tuple where the logits are the first element.
        mode (str): Mode of the generator loss. Either "max_real" or
            "min_fake".

            - **"max_real"** (default): minimizing the generator loss is to
              maximize the probability of fake data being classified as real.
            - **"min_fake"**: minimizing the generator loss is to minimize
              the probability of fake data being classified as fake.

    Returns:
        A tuple `(generator_loss, discriminator_loss)`, each of which is a
        scalar Tensor to be minimized.
    """
    # The discriminator should classify real data as real (target 1).
    real_logits = discriminator_fn(real_data)
    if isinstance(real_logits, (list, tuple)):
        real_logits = real_logits[0]
    real_loss = F.binary_cross_entropy_with_logits(
        real_logits, torch.ones_like(real_logits))

    # ... and fake data as fake (target 0).
    fake_logits = discriminator_fn(fake_data)
    if isinstance(fake_logits, (list, tuple)):
        fake_logits = fake_logits[0]
    fake_loss = F.binary_cross_entropy_with_logits(
        fake_logits, torch.zeros_like(fake_logits))

    d_loss = real_loss + fake_loss

    if mode == "min_fake":
        # Minimize the probability of fake data being classified as fake.
        g_loss = -fake_loss
    elif mode == "max_real":
        # Maximize the probability of fake data being classified as real.
        g_loss = F.binary_cross_entropy_with_logits(
            fake_logits, torch.ones_like(fake_logits))
    else:
        raise ValueError("Unknown mode: %s. Only 'min_fake' and 'max_real' "
                         "are allowed." % mode)

    return g_loss, d_loss
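A minimal usage sketch for binary_adversarial_losses, assuming texar-pytorch is importable; the toy discriminator, batch size, and feature width below are illustrative, not part of the module:

import torch
import torch.nn as nn
import texar as tx

torch.manual_seed(0)
real = torch.randn(8, 16)   # stand-in for a batch of real data
fake = torch.randn(8, 16)   # stand-in for generator output
disc = nn.Linear(16, 1)     # toy discriminator: one logit per example

g_loss, d_loss = tx.losses.binary_adversarial_losses(
    real, fake, lambda x: disc(x).squeeze(-1))
d_loss.backward()           # populates the discriminator's gradients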
texar/losses/adv_losses_test.py
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
#
"""
Unit tests for adv_losses.
"""

# pylint: disable=invalid-name

import unittest

import torch

import texar as tx


class AdvLossesTest(unittest.TestCase):
    """Tests adversarial losses.
    """

    def test_binary_adversarial_losses(self):
        """Tests :meth:`~texar.losses.adv_losses.binary_adversarial_losses`.
        """
        batch_size = 16
        data_dim = 64
        real_data = torch.zeros(size=(batch_size, data_dim),
                                dtype=torch.float32)
        fake_data = torch.ones(size=(batch_size, data_dim),
                               dtype=torch.float32)
        const_logits = torch.zeros(size=(batch_size,), dtype=torch.float32)
        # Use a dumb discriminator that always outputs logits=0.
        gen_loss, disc_loss = tx.losses.binary_adversarial_losses(
            real_data, fake_data, lambda x: const_logits)
        gen_loss_2, disc_loss_2 = tx.losses.binary_adversarial_losses(
            real_data, fake_data, lambda x: const_logits, mode="min_fake")

        # With constant zero logits, the BCE toward target 1 and toward
        # target 0 are both log(2), so the "max_real" generator loss is
        # log(2) while the "min_fake" loss is -log(2).
        self.assertAlmostEqual(gen_loss.item(), -gen_loss_2.item())
        self.assertAlmostEqual(disc_loss.item(), disc_loss_2.item())


if __name__ == "__main__":
    unittest.main()
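The symmetry the test asserts can be checked without texar at all; a short sketch in plain PyTorch (the tensor size is illustrative):

import math

import torch
import torch.nn.functional as F

logits = torch.zeros(16)
to_real = F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))
to_fake = F.binary_cross_entropy_with_logits(logits, torch.zeros_like(logits))
# sigmoid(0) = 0.5, so both losses equal -log(0.5) = log(2); hence the
# "max_real" generator loss log(2) is the negative of the "min_fake"
# generator loss -log(2).
assert abs(to_real.item() - math.log(2)) < 1e-6
assert abs(to_real.item() - to_fake.item()) < 1e-6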
texar/losses/entropy.py
@@ -0,0 +1,206 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various entropies.
"""

from typing import Optional

import torch
import torch.nn.functional as F

from texar.losses.losses_utils import mask_and_reduce, reduce_dimensions
from texar.utils.shapes import get_rank

# pylint: disable=too-many-arguments

__all__ = [
    "entropy_with_logits",
    "sequence_entropy_with_logits"
]


def _get_entropy(logits: torch.Tensor) -> torch.Tensor:
    # Smooth the probabilities so that log(probs) stays finite even when a
    # probability underflows to exactly zero.
    probs = F.softmax(logits, -1) + 1e-8
    # Shannon entropy: H = -sum_i p_i * log(p_i), summed over the last
    # (distribution) dimension.
    entropy = -probs * torch.log(probs)
    entropy = torch.sum(entropy, -1)
    return entropy

def entropy_with_logits(logits: torch.Tensor,
                        rank: Optional[int] = None,
                        average_across_batch: bool = True,
                        average_across_remaining: bool = False,
                        sum_over_batch: bool = False,
                        sum_over_remaining: bool = True) -> torch.Tensor:
    """Shannon entropy given logits.

    Args:
        logits: Unscaled log probabilities of shape
            `[batch_size, d_2, ..., d_{rank-1}, distribution_dim]`
            and of dtype `float32` or `float64`.

            The rank of the tensor is optionally specified by the argument
            :attr:`rank`.

            The tensor is considered as having `[batch_size, ..., d_{rank-1}]`
            elements, each of which has a distribution of length `d_rank`
            (i.e., `distribution_dim`). So the last dimension is always
            summed out to compute the entropy.
        rank (int, optional): The rank of :attr:`logits`.
            If `None` (default), `rank` is inferred automatically from
            `logits`. If the inference fails, `rank` is set to 2, i.e.,
            assuming :attr:`logits` is of shape
            `[batch_size, distribution_dim]`.
        average_across_batch (bool): If set, average the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_remaining (bool): If set, average the entropy across
            the remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 3.
        sum_over_batch (bool): If set, sum the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_remaining (bool): If set, sum the entropy across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 3.

    Returns:
        A Tensor containing the Shannon entropy. The dimensionality of the
        Tensor depends on the configuration of reduction arguments. For
        example, if both batch and remaining dimensions are reduced (by
        either sum or average), the returned Tensor is a scalar Tensor.
    """
    entropy = _get_entropy(logits)

    if rank is None:
        rank = get_rank(logits)
    if rank is None:
        rank = 2
    # The distribution dimension has been summed out; work with the
    # remaining rank.
    rank -= 1

    if average_across_batch and sum_over_batch:
        raise ValueError("Only one of `average_across_batch` and "
                         "`sum_over_batch` can be set.")
    if average_across_remaining and sum_over_remaining:
        raise ValueError("Only one of `average_across_remaining` and "
                         "`sum_over_remaining` can be set.")
    sum_axes, average_axes = [], []
    if sum_over_batch:
        sum_axes.append(0)
    if average_across_batch:
        average_axes.append(0)
    if sum_over_remaining and rank >= 2:
        sum_axes += list(range(1, rank))
    if average_across_remaining and rank >= 2:
        average_axes += list(range(1, rank))

    entropy = reduce_dimensions(
        entropy, average_axes=average_axes, sum_axes=sum_axes)

    return entropy

def sequence_entropy_with_logits(
        logits: torch.Tensor,
        rank: Optional[int] = None,
        sequence_length: Optional[torch.LongTensor] = None,
        average_across_batch: bool = True,
        average_across_timesteps: bool = False,
        average_across_remaining: bool = False,
        sum_over_batch: bool = False,
        sum_over_timesteps: bool = True,
        sum_over_remaining: bool = True,
        time_major: bool = False) -> torch.Tensor:
    """Shannon entropy given logits.

    Args:
        logits: Unscaled log probabilities of shape
            `[batch_size, max_time, d_3, ..., d_{rank-1}, distribution_dim]`
            and of dtype `float32` or `float64`.

            The rank of the tensor is optionally specified by the argument
            :attr:`rank`.

            The tensor is considered as having `[batch_size, ..., d_{rank-1}]`
            elements, each of which has a distribution of length `d_rank`
            (i.e., `distribution_dim`). So the last dimension is always
            summed out to compute the entropy.

            The batch and time dimensions are exchanged if :attr:`time_major`
            is `True`.
        rank (int, optional): The rank of :attr:`logits`.
            If `None` (default), `rank` is inferred automatically from
            `logits`. If the inference fails, `rank` is set to 3, i.e.,
            assuming `logits` is of shape
            `[batch_size, max_time, distribution_dim]`.
        sequence_length (optional): A Tensor of shape `[batch_size]`.
            Time steps beyond the respective sequence lengths are masked
            out and do not contribute to the entropy.
        average_across_timesteps (bool): If set, average the entropy across
            the time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_remaining (bool): If set, average the entropy across
            the remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 4.
        sum_over_timesteps (bool): If set, sum the entropy across the
            time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_remaining (bool): If set, sum the entropy across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 4.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`logits` must have shape `[max_time, batch_size, ...]`.
            If `False` (default), it must have shape
            `[batch_size, max_time, ...]`.

    Returns:
        A Tensor containing the Shannon entropy. The dimensionality of the
        Tensor depends on the configuration of reduction arguments. For
        example, if batch, time, and remaining dimensions are all reduced (by
        either sum or average), the returned Tensor is a scalar Tensor.
    """
    entropy = _get_entropy(logits)

    if rank is None:
        rank = get_rank(logits)
    if rank is None:
        rank = 3
    # The distribution dimension has been summed out; work with the
    # remaining rank.
    rank -= 1

    # mask_and_reduce zeroes out time steps beyond sequence_length and then
    # applies the requested batch/time/remaining reductions.
    entropy = mask_and_reduce(
        entropy,
        sequence_length,
        rank=rank,
        average_across_batch=average_across_batch,
        average_across_timesteps=average_across_timesteps,
        average_across_remaining=average_across_remaining,
        sum_over_batch=sum_over_batch,
        sum_over_timesteps=sum_over_timesteps,
        sum_over_remaining=sum_over_remaining,
        time_major=time_major)

    return entropy
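A usage sketch for the two entropy functions, assuming texar-pytorch is importable; the shapes and lengths are illustrative. Uniform logits make the expected values easy to compute by hand:

import math

import torch
import texar as tx

batch, steps, vocab = 4, 5, 10
logits = torch.zeros(batch, steps, vocab)  # uniform distribution per step
lengths = torch.tensor([5, 4, 3, 2])

# Per-step entropy of a uniform distribution over `vocab` items is
# log(vocab). With the defaults (sum over time, average over batch), the
# result is approximately mean(lengths) * log(vocab); "approximately"
# because of the 1e-8 smoothing inside _get_entropy.
h = tx.losses.sequence_entropy_with_logits(logits, sequence_length=lengths)
print(h.item(), lengths.float().mean().item() * math.log(vocab))

# The non-sequence variant reduces the same way, minus the time axis.
h2 = tx.losses.entropy_with_logits(torch.zeros(4, 10))
print(h2.item())  # approximately log(10)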