From e8e9e1088374c93471d9ed3bc1342687fb363e3e Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Tue, 30 Apr 2019 11:41:14 -0400 Subject: [PATCH 01/24] add reduce_dimensions --- texar/losses/__init__.py | 22 +++++++++ texar/losses/losses_utils.py | 72 ++++++++++++++++++++++++++++++ texar/losses/mle_losses.py | 86 ++++++++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+) create mode 100644 texar/losses/__init__.py create mode 100644 texar/losses/losses_utils.py create mode 100644 texar/losses/mle_losses.py diff --git a/texar/losses/__init__.py b/texar/losses/__init__.py new file mode 100644 index 000000000..26b591f9b --- /dev/null +++ b/texar/losses/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Modules of texar losses. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import diff --git a/texar/losses/losses_utils.py b/texar/losses/losses_utils.py new file mode 100644 index 000000000..08d22b3b1 --- /dev/null +++ b/texar/losses/losses_utils.py @@ -0,0 +1,72 @@ +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Various utilities for losses. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import torch + + + + +def reduce_dimensions(tensor, average_axes=None, sum_axes=None, keepdims=None): + """Average or sum over dimensions of :attr:`tensor`. + + :attr:`average_axes` and :attr:`sum_axes` must be mutually exclusive. That + is, elements in `average_axes` must not be contained in + `sum_axes`, and vice versa. + + Args: + tensor: A tensor to reduce. + average_axes (optional): A (list of) `int` that indicates the + dimensions to reduce by taking average. + sum_axes (optional): A (list of) `int` that indicates the + dimensions to reduce by taking sum. + keepdims (optional): If `True`, retains reduced dimensions with + length 1. 
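+
+    Example (an illustrative sketch; the tensor values here are made up):
+
+        .. code-block:: python
+
+            x = torch.ones(2, 3, 4)
+            # Average out dim 1 and sum out dim 2; with `keepdims` left
+            # unset, the reduced dims are squeezed away afterwards.
+            y = reduce_dimensions(x, average_axes=1, sum_axes=2)
+            # y.shape == torch.Size([2]); each entry equals 4.0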
+ """ + reduced_axes = set() + if average_axes is not None: + if not isinstance(average_axes, (list, tuple)): + average_axes = [average_axes] + if len(average_axes) > 0: + for average_axis in average_axes: + tensor = torch.mean(tensor, dim=average_axis, keepdim=True) + reduced_axes.update(average_axes) + + if sum_axes is not None: + if not isinstance(sum_axes, (list, tuple)): + sum_axes = [sum_axes] + if len(sum_axes) > 0: + for sum_axis in sum_axes: + tensor = torch.sum(tensor, dim=sum_axis, keepdim=True) + reduced_axes.update(sum_axes) + + if average_axes is not None: + if len(reduced_axes) != len(average_axes) + len(sum_axes): + raise ValueError('`average_axes` and `sum_axes` must not ' + 'have overlapped elements.') + if not keepdims: + tensor = torch.squeeze(tensor, dim=list(reduced_axes)) + + return tensor + + + + diff --git a/texar/losses/mle_losses.py b/texar/losses/mle_losses.py new file mode 100644 index 000000000..9c71ddfe0 --- /dev/null +++ b/texar/losses/mle_losses.py @@ -0,0 +1,86 @@ +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Various losses +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + + + +def sequence_softmax_cross_entropy(labels, + logits, + sequence_length, + average_across_batch=True, + average_across_timesteps=False, + sum_over_batch=False, + sum_over_timesteps=True, + time_major=False, + stop_gradient_to_label=False, + name=None): + """Computes softmax cross entropy for each time step of sequence + predictions. + + Args: + labels: Target class distributions. + + - If :attr:`time_major` is `False` (default), this must be a\ + Tensor of shape `[batch_size, max_time, num_classes]`. + + - If `time_major` is `True`, this must be a Tensor of shape\ + `[max_time, batch_size, num_classes]`. + + Each row of `labels` should be a valid probability + distribution, otherwise, the computation of the gradient will be + incorrect. + logits: Unscaled log probabilities. This must have the shape of + `[max_time, batch_size, num_classes]` or + `[batch_size, max_time, num_classes]` according to + the value of `time_major`. + sequence_length: A Tensor of shape `[batch_size]`. Time steps beyond + the respective sequence lengths will have zero losses. + average_across_timesteps (bool): If set, average the loss across + the time dimension. Must not set `average_across_timesteps` + and `sum_over_timesteps` at the same time. + average_across_batch (bool): If set, average the loss across the + batch dimension. Must not set `average_across_batch`' + and `sum_over_batch` at the same time. + sum_over_timesteps (bool): If set, sum the loss across the + time dimension. Must not set `average_across_timesteps` + and `sum_over_timesteps` at the same time. + sum_over_batch (bool): If set, sum the loss across the + batch dimension. Must not set `average_across_batch` + and `sum_over_batch` at the same time. + time_major (bool): The shape format of the inputs. 
If `True`, + :attr:`labels` and :attr:`logits` must have shape + `[max_time, batch_size, ...]`. If `False` + (default), they must have shape `[batch_size, max_time, ...]`. + stop_gradient_to_label (bool): If set, gradient propagation to + :attr:`labels` will be disabled. + name (str, optional): A name for the operation. + + Returns: + A Tensor containing the loss, of rank 0, 1, or 2 depending on the + arguments :attr:`{average_across}/{sum_over}_{timesteps}/{batch}`. + For example: + + - If :attr:`sum_over_timesteps` and :attr:`average_across_batch` \ + are `True` (default), the return Tensor is of rank 0. + + - If :attr:`average_across_batch` is `True` and other arguments are \ + `False`, the return Tensor is of shape `[max_time]`. + """ \ No newline at end of file From 770e64fe4063ff2af23e1efc804f0c8ddd0dd549 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Tue, 30 Apr 2019 13:19:32 -0400 Subject: [PATCH 02/24] update losses_utils.py --- texar/losses/losses_utils.py | 104 +++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/texar/losses/losses_utils.py b/texar/losses/losses_utils.py index 08d22b3b1..5caf0564d 100644 --- a/texar/losses/losses_utils.py +++ b/texar/losses/losses_utils.py @@ -23,6 +23,110 @@ import torch +# pylint: disable=invalid-name, not-context-manager, protected-access, +# pylint: disable=too-many-arguments + +__all__ = [ + "mask_and_reduce", + "reduce_batch_time", + "reduce_dimensions" +] + + +def mask_and_reduce(sequence, + sequence_length, + rank=2, + average_across_batch=True, + average_across_timesteps=False, + average_across_remaining=False, + sum_over_batch=False, + sum_over_timesteps=True, + sum_over_remaining=True, + dtype=None, + time_major=False): + """Masks out sequence entries that are beyond the respective sequence + lengths, and reduces (average or sum) away dimensions. + + This is a combination of :func:`~texar.utils.shapes.mask_sequences` + and :func:`~texar.losses.losses_utils.reduce_batch_time`. + + Args: + sequence: A Tensor of sequence values. + If `time_major=False` (default), this must be a Tensor of shape + `[batch_size, max_time, d_2, ..., d_rank]`, where the rank of + the Tensor is specified with :attr:`rank`. + The batch and time dimensions are exchanged if `time_major` is True. + sequence_length: A Tensor of shape `[batch_size]`. Time steps beyond + the respective sequence lengths will be made zero. If `None`, + not masking is performed. + rank (int): The rank of :attr:`sequence`. Must be >= 2. Default is 2, + i.e., `sequence` is a 2D Tensor consisting of batch and time + dimensions. + average_across_timesteps (bool): If set, average the sequence across + the time dimension. Must not set `average_across_timesteps` + and `sum_over_timesteps` at the same time. + average_across_batch (bool): If set, average the sequence across the + batch dimension. Must not set `average_across_batch`' + and `sum_over_batch` at the same time. + average_across_remaining (bool): If set, average the sequence across the + remaining dimensions. Must not set `average_across_remaining`' + and `sum_over_remaining` at the same time. + sum_over_timesteps (bool): If set, sum the loss across the + time dimension. Must not set `average_across_timesteps` + and `sum_over_timesteps` at the same time. + sum_over_batch (bool): If set, sum the loss across the + batch dimension. Must not set `average_across_batch` + and `sum_over_batch` at the same time. + sum_over_remaining (bool): If set, sum the loss across the + remaining dimension. 
Must not set `average_across_remaining` + and `sum_over_remaining` at the same time. + time_major (bool): The shape format of the inputs. If `True`, + :attr:`sequence` must have shape `[max_time, batch_size, ...]`. + If `False` (default), `sequence` must have + shape `[batch_size, max_time, ...]`. + dtype (dtype): Type of :attr:`sequence`. If `None`, infer from + :attr:`sequence` automatically. + + Returns + A Tensor containing the masked and reduced sequence. + """ + + + + +def reduce_batch_time(sequence, + sequence_length, + average_across_batch=True, + average_across_timesteps=False, + sum_over_batch=False, + sum_over_timesteps=True): + """Average or sum over the respective dimensions of :attr:`sequence`, which + is of shape `[batch_size, max_time, ...]`. + + Assumes :attr:`sequence` has been properly masked according to + :attr:`sequence_length`. + """ + if average_across_timesteps and sum_over_timesteps: + raise ValueError("Only one of `average_across_timesteps` and " + "`sum_over_timesteps` can be set.") + if average_across_batch and sum_over_batch: + raise ValueError("Only one of `average_across_batch` and " + "`sum_over_batch` can be set.") + + if sum_over_timesteps: + sequence = torch.sum(sequence, dim=1) + elif average_across_timesteps: + if sequence_length is None: + sequence = torch.mean(sequence, dim=1) + else: + sequence = torch.sum(sequence, dim=1)/sequence_length.float() + + if sum_over_batch: + sequence = torch.sum(sequence, dim=0) + elif average_across_batch: + sequence = torch.mean(sequence, dim=0) + + return sequence def reduce_dimensions(tensor, average_axes=None, sum_axes=None, keepdims=None): From 054c08ba3e6fcf5b0575f6dde14eea0b83df5417 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Tue, 30 Apr 2019 13:35:20 -0400 Subject: [PATCH 03/24] update losses_utils.py --- texar/losses/losses_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/texar/losses/losses_utils.py b/texar/losses/losses_utils.py index 5caf0564d..ee3085789 100644 --- a/texar/losses/losses_utils.py +++ b/texar/losses/losses_utils.py @@ -22,6 +22,7 @@ import torch +from texar.utils.shapes import mask_sequences # pylint: disable=invalid-name, not-context-manager, protected-access, # pylint: disable=too-many-arguments From 6b513d97138375e05b65966e15d7830a6037932b Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Tue, 30 Apr 2019 15:24:37 -0400 Subject: [PATCH 04/24] update reduce_batch_time --- texar/losses/losses_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/texar/losses/losses_utils.py b/texar/losses/losses_utils.py index ee3085789..666207e6b 100644 --- a/texar/losses/losses_utils.py +++ b/texar/losses/losses_utils.py @@ -120,7 +120,8 @@ def reduce_batch_time(sequence, if sequence_length is None: sequence = torch.mean(sequence, dim=1) else: - sequence = torch.sum(sequence, dim=1)/sequence_length.float() + sequence = torch.sum(sequence, dim=1).float() / \ + sequence_length.float() if sum_over_batch: sequence = torch.sum(sequence, dim=0) From 0cdac6d8084a088261c23875fb3fe21a13af117c Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Tue, 30 Apr 2019 16:11:38 -0400 Subject: [PATCH 05/24] update losses_utils.py --- texar/losses/losses_utils.py | 42 ++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/texar/losses/losses_utils.py b/texar/losses/losses_utils.py index 666207e6b..079e9e437 100644 --- a/texar/losses/losses_utils.py +++ b/texar/losses/losses_utils.py @@ -22,7 +22,7 @@ import torch -from texar.utils.shapes import 
mask_sequences +from texar.utils.shapes import transpose_batch_time, mask_sequences # pylint: disable=invalid-name, not-context-manager, protected-access, # pylint: disable=too-many-arguments @@ -91,8 +91,42 @@ def mask_and_reduce(sequence, Returns A Tensor containing the masked and reduced sequence. """ + if rank < 2: + raise ValueError('`rank` must be >= 2.') + + if time_major: + sequence = transpose_batch_time(sequence) + + if sequence_length is not None: + sequence = mask_sequences(sequence, + sequence_length, + dtype=dtype, + time_major=False) + + if rank > 2: + if average_across_remaining and sum_over_remaining: + raise ValueError("Only one of `average_across_remaining` and " + "`sum_over_remaining` can be set.") + if average_across_remaining: + for axis in range(2, rank): + sequence = torch.mean(sequence, dim=axis) + elif sum_over_remaining: + for axis in range(2, rank): + sequence = torch.sum(sequence, dim=axis) + + sequence = reduce_batch_time(sequence, + sequence_length, + average_across_batch, + average_across_timesteps, + sum_over_batch, + sum_over_timesteps) + + reduce_time = average_across_timesteps or sum_over_timesteps + reduce_batch = average_across_batch or sum_over_batch + if not reduce_time and not reduce_batch and time_major: + sequence = transpose_batch_time(sequence) - + return sequence def reduce_batch_time(sequence, @@ -172,7 +206,3 @@ def reduce_dimensions(tensor, average_axes=None, sum_axes=None, keepdims=None): tensor = torch.squeeze(tensor, dim=list(reduced_axes)) return tensor - - - - From 8df6b90e11125e49b0356c4637103db038bcc688 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Wed, 1 May 2019 12:02:13 -0400 Subject: [PATCH 06/24] update mle_losses --- texar/losses/mle_losses.py | 140 ++++++++++++++++++++++++++++++++++++- 1 file changed, 139 insertions(+), 1 deletion(-) diff --git a/texar/losses/mle_losses.py b/texar/losses/mle_losses.py index 9c71ddfe0..1559ba24c 100644 --- a/texar/losses/mle_losses.py +++ b/texar/losses/mle_losses.py @@ -19,7 +19,22 @@ from __future__ import division from __future__ import print_function +import torch +import torch.nn.functional as F +from texar.losses.losses_utils import mask_and_reduce, reduce_dimensions +from texar.utils import shapes + +# pylint: disable=invalid-name, not-context-manager, protected-access, +# pylint: disable=too-many-arguments + +__all__ = [ + "sequence_softmax_cross_entropy", + "sequence_sparse_softmax_cross_entropy", + "sequence_sigmoid_cross_entropy", + "binary_sigmoid_cross_entropy", + "binary_sigmoid_cross_entropy_with_clas" +] def sequence_softmax_cross_entropy(labels, @@ -83,4 +98,127 @@ def sequence_softmax_cross_entropy(labels, - If :attr:`average_across_batch` is `True` and other arguments are \ `False`, the return Tensor is of shape `[max_time]`. - """ \ No newline at end of file + """ + if stop_gradient_to_label: + labels = labels.detach() + + losses = torch.sum(- labels * F.log_softmax(logits, -1), -1) + + losses = mask_and_reduce(losses, + sequence_length, + rank=2, + average_across_batch=average_across_batch, + average_across_timesteps=average_across_timesteps, + sum_over_batch=sum_over_batch, + sum_over_timesteps=sum_over_timesteps, + time_major=time_major) + return losses + + +def sequence_sparse_softmax_cross_entropy(labels, + logits, + sequence_length, + average_across_batch=True, + average_across_timesteps=False, + sum_over_batch=False, + sum_over_timesteps=True, + time_major=False, + name=None): + """Computes sparse softmax cross entropy for each time step of sequence + predictions. 
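+
+    Informally, at each step this is the negative log-likelihood of the
+    target index, i.e.,
+    `loss[b, t] = -log_softmax(logits[b, t])[labels[b, t]]` (a sketch of
+    the intended computation, before masking and reduction).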
+ + Args: + labels: Target class indexes. I.e., classes are mutually exclusive + (each entry is in exactly one class). + + - If :attr:`time_major` is `False` (default), this must be\ + a Tensor of shape `[batch_size, max_time]`. + + - If `time_major` is `True`, this must be a Tensor of shape\ + `[max_time, batch_size].` + logits: Unscaled log probabilities. This must have the shape of + `[max_time, batch_size, num_classes]` or + `[batch_size, max_time, num_classes]` according to + the value of `time_major`. + sequence_length: A Tensor of shape `[batch_size]`. Time steps beyond + the respective sequence lengths will have zero losses. + average_across_timesteps (bool): If set, average the loss across + the time dimension. Must not set `average_across_timesteps` + and `sum_over_timesteps` at the same time. + average_across_batch (bool): If set, average the loss across the + batch dimension. Must not set `average_across_batch`' + and `sum_over_batch` at the same time. + sum_over_timesteps (bool): If set, sum the loss across the + time dimension. Must not set `average_across_timesteps` + and `sum_over_timesteps` at the same time. + sum_over_batch (bool): If set, sum the loss across the + batch dimension. Must not set `average_across_batch` + and `sum_over_batch` at the same time. + time_major (bool): The shape format of the inputs. If `True`, + :attr:`labels` and :attr:`logits` must have shape + `[max_time, batch_size, ...]`. If `False` + (default), they must have shape `[batch_size, max_time, ...]`. + name (str, optional): A name for the operation. + + Returns: + A Tensor containing the loss, of rank 0, 1, or 2 depending on the + arguments :attr:`{average_across}/{sum_over}_{timesteps}/{batch}`. + For example: + + - If :attr:`sum_over_timesteps` and :attr:`average_across_batch` \ + are `True` (default), the return Tensor is of rank 0. + + - If :attr:`average_across_batch` is `True` and other arguments are \ + `False`, the return Tensor is of shape `[max_time]`. + + Example: + + .. 
code-block:: python + + embedder = WordEmbedder(vocab_size=data.vocab.size) + decoder = BasicRNNDecoder(vocab_size=data.vocab.size) + outputs, _, _ = decoder( + decoding_strategy='train_greedy', + inputs=embedder(data_batch['text_ids']), + sequence_length=data_batch['length']-1) + + loss = sequence_sparse_softmax_cross_entropy( + labels=data_batch['text_ids'][:, 1:], + logits=outputs.logits, + sequence_length=data_batch['length']-1) + + """ + losses = F.nll_loss(F.log_softmax(logits, dim=1), labels) + + losses = mask_and_reduce(losses, + sequence_length, + rank=2, + average_across_batch=average_across_batch, + average_across_timesteps=average_across_timesteps, + sum_over_batch=sum_over_batch, + sum_over_timesteps=sum_over_timesteps, + time_major=time_major) + return losses + + + + + + + + + + + + + + + + + + + + + + + From d29551bf780b316944cd46ae7e9dcc04c5e79865 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Wed, 1 May 2019 13:04:14 -0400 Subject: [PATCH 07/24] update mle_losses --- texar/losses/mle_losses.py | 72 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/texar/losses/mle_losses.py b/texar/losses/mle_losses.py index 1559ba24c..e56e3ad1b 100644 --- a/texar/losses/mle_losses.py +++ b/texar/losses/mle_losses.py @@ -201,7 +201,79 @@ def sequence_sparse_softmax_cross_entropy(labels, return losses +def sequence_sigmoid_cross_entropy(labels, + logits, + sequence_length, + average_across_batch=True, + average_across_timesteps=False, + average_across_classes=True, + sum_over_batch=False, + sum_over_timesteps=True, + sum_over_classes=False, + time_major=False, + stop_gradient_to_label=False, + name=None): + """Computes sigmoid cross entropy for each time step of sequence + predictions. + Args: + labels: Target class distributions. + + - If :attr:`time_major` is `False` (default), this must be a\ + Tensor of shape `[batch_size, max_time(, num_classes)]`. + + - If `time_major` is `True`, this must be a Tensor of shape\ + `[max_time, batch_size(, num_classes)]`. + + Each row of `labels` should be a valid probability + distribution, otherwise, the computation of the gradient will be + incorrect. + logits: Unscaled log probabilities having the same shape as with + :attr:`labels`. + sequence_length: A Tensor of shape `[batch_size]`. Time steps beyond + the respective sequence lengths will have zero losses. + average_across_timesteps (bool): If set, average the loss across + the time dimension. Must not set `average_across_timesteps` + and `sum_over_timesteps` at the same time. + average_across_batch (bool): If set, average the loss across the + batch dimension. Must not set `average_across_batch`' + and `sum_over_batch` at the same time. + average_across_classes (bool): If set, average the loss across the + class dimension (if exists). Must not set + `average_across_classes`' and `sum_over_classes` at + the same time. Ignored if :attr:`logits` is a 2D Tensor. + sum_over_timesteps (bool): If set, sum the loss across the + time dimension. Must not set `average_across_timesteps` + and `sum_over_timesteps` at the same time. + sum_over_batch (bool): If set, sum the loss across the + batch dimension. Must not set `average_across_batch` + and `sum_over_batch` at the same time. + sum_over_classes (bool): If set, sum the loss across the + class dimension. Must not set `average_across_classes` + and `sum_over_classes` at the same time. Ignored if + :attr:`logits` is a 2D Tensor. + time_major (bool): The shape format of the inputs. 
If `True`, + :attr:`labels` and :attr:`logits` must have shape + `[max_time, batch_size, ...]`. If `False` + (default), they must have shape `[batch_size, max_time, ...]`. + stop_gradient_to_label (bool): If set, gradient propagation to + :attr:`labels` will be disabled. + name (str, optional): A name for the operation. + + Returns: + A Tensor containing the loss, of rank 0, 1, or 2 depending on the + arguments + :attr:`{average_across}/{sum_over}_{timesteps}/{batch}/{classes}`. + For example, if the class dimension does not exist, and + + - If :attr:`sum_over_timesteps` and :attr:`average_across_batch` \ + are `True` (default), the return Tensor is of rank 0. + + - If :attr:`average_across_batch` is `True` and other arguments are \ + `False`, the return Tensor is of shape `[max_time]`. + """ + if stop_gradient_to_label: + labels = labels.detach() From 91106b2cf697330cb29fd35ac68c9b228dd3d8e0 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Wed, 1 May 2019 14:37:39 -0400 Subject: [PATCH 08/24] update mle_losses --- texar/losses/mle_losses.py | 159 +++++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) diff --git a/texar/losses/mle_losses.py b/texar/losses/mle_losses.py index e56e3ad1b..9289610e0 100644 --- a/texar/losses/mle_losses.py +++ b/texar/losses/mle_losses.py @@ -274,23 +274,182 @@ class dimension. Must not set `average_across_classes` """ if stop_gradient_to_label: labels = labels.detach() + losses = torch.nn.BCEWithLogitsLoss(reduction=None) + losses = losses(logits, labels) + rank = shapes.get_rank(logits) or shapes.get_rank(labels) + if rank is None: + raise ValueError( + 'Cannot determine the rank of `logits` or `labels`.') + losses = mask_and_reduce(losses, + sequence_length, + rank=rank, + average_across_batch=average_across_batch, + average_across_timesteps=average_across_timesteps, + average_across_remaining=average_across_classes, + sum_over_batch=sum_over_batch, + sum_over_timesteps=sum_over_timesteps, + sum_over_remaining=sum_over_classes, + time_major=time_major) + + return losses + + +def binary_sigmoid_cross_entropy(pos_logits=None, + neg_logits=None, + average_across_batch=True, + average_across_classes=True, + sum_over_batch=False, + sum_over_classes=False, + return_pos_neg_losses=False, + name=None): + """Computes sigmoid cross entropy of binary predictions. + + Args: + pos_logits: The logits of predicting positive on positive data. A + tensor of shape `[batch_size(, num_classes)]`. + neg_logits: The logits of predicting positive on negative data. A + tensor of shape `[batch_size(, num_classes)]`. + average_across_batch (bool): If set, average the loss across the + batch dimension. Must not set `average_across_batch`' + and `sum_over_batch` at the same time. + average_across_classes (bool): If set, average the loss across the + class dimension (if exists). Must not set + `average_across_classes`' and `sum_over_classes` at + the same time. Ignored if :attr:`logits` is a 1D Tensor. + sum_over_batch (bool): If set, sum the loss across the + batch dimension. Must not set `average_across_batch` + and `sum_over_batch` at the same time. + sum_over_classes (bool): If set, sum the loss across the + class dimension. Must not set `average_across_classes` + and `sum_over_classes` at the same time. Ignored if + :attr:`logits` is a 2D Tensor. + return_pos_neg_losses (bool): If set, additionally returns the losses + on :attr:`pos_logits` and :attr:`neg_logits`, respectively. + name (str, optional): A name for the operation. 
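+
+    Example:
+        A hypothetical usage sketch; the logits below are made up, and
+        under the default flags the loss is intended to reduce to a scalar.
+
+        .. code-block:: python
+
+            pos_logits = torch.randn(64, 1)
+            neg_logits = torch.randn(64, 1)
+            loss = binary_sigmoid_cross_entropy(
+                pos_logits=pos_logits, neg_logits=neg_logits)
+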
+ Returns: + By default, a Tensor containing the loss, of rank 0, 1, or 2 depending + on the arguments :attr:`{average_across}/{sum_over}_{batch}/{classes}`. + For example: + - If :attr:`sum_over_batch` and :attr:`average_across_classes` \ + are `True` (default), the return Tensor is of rank 0. + - If arguments are `False`, the return Tensor is of shape \ + `[batch_size(, num_classes)]`. + If :attr:`return_pos_neg_losses` is `True`, returns a tuple + `(loss, pos_loss, neg_loss)`, where `loss` is the loss above; + `pos_loss` is the loss on `pos_logits` only; and + `neg_loss` is the loss on `neg_logits` only. They have + `loss = pos_loss + neg_loss`. + """ + average_axes, sum_axes = [], [] + average_axes += [0] if average_across_batch else [] + average_axes += [1] if average_across_classes else [] + sum_axes += [0] if sum_over_batch else [] + sum_axes += [1] if sum_over_classes else [] + pos_loss = 0 + if pos_logits is not None: + pos_loss = torch.nn.BCEWithLogitsLoss(reduction=None) + pos_loss = pos_loss(pos_logits, torch.ones_like(pos_logits)) + pos_loss = reduce_dimensions(pos_loss, average_axes, sum_axes) + neg_loss = 0 + if neg_logits is not None: + neg_loss = torch.nn.BCEWithLogitsLoss(reduction=None) + neg_loss = neg_loss(neg_logits, torch.zeros_like(neg_logits)) + neg_loss = reduce_dimensions(neg_loss, average_axes, sum_axes) + loss = pos_loss + neg_loss + if return_pos_neg_losses: + return loss, pos_loss, neg_loss + else: + return loss +def binary_sigmoid_cross_entropy_with_clas(clas_fn, + pos_inputs=None, + neg_inputs=None, + average_across_batch=True, + average_across_classes=True, + sum_over_batch=False, + sum_over_classes=False, + return_pos_neg_losses=False, + name=None): + """Computes sigmoid cross entropy of binary classifier. + .. role:: python(code) + :language: python + Args: + clas_fn: A callable takes data (e.g., :attr:`pos_inputs` and + :attr:`fake_inputs`) and returns the logits of being positive. The + signature of `clas_fn` must be: + :python:`logits (, ...) = clas_fn(inputs)`. + The return value of `clas_fn` can be the logits, or + a tuple where the logits are the first element. + pos_inputs: The positive data fed into `clas_fn`. + neg_inputs: The negative data fed into `clas_fn`. + average_across_batch (bool): If set, average the loss across the + batch dimension. Must not set `average_across_batch`' + and `sum_over_batch` at the same time. + average_across_classes (bool): If set, average the loss across the + class dimension (if exists). Must not set + `average_across_classes`' and `sum_over_classes` at + the same time. Ignored if :attr:`logits` is a 1D Tensor. + sum_over_batch (bool): If set, sum the loss across the + batch dimension. Must not set `average_across_batch` + and `sum_over_batch` at the same time. + sum_over_classes (bool): If set, sum the loss across the + class dimension. Must not set `average_across_classes` + and `sum_over_classes` at the same time. Ignored if + :attr:`logits` is a 2D Tensor. + return_pos_neg_losses (bool): If set, additionally returns the losses + on :attr:`pos_logits` and :attr:`neg_logits`, respectively. + name (str, optional): A name for the operation. + Returns: + By default, a Tensor containing the loss, of rank 0, 1, or 2 depending + on the arguments :attr:`{average_across}/{sum_over}_{batch}/{classes}`. + For example: + - If :attr:`sum_over_batch` and :attr:`average_across_classes` \ + are `True` (default), the return Tensor is of rank 0. 
+ - If arguments are `False`, the return Tensor is of shape \ + `[batch_size(, num_classes)]`. + If :attr:`return_pos_neg_losses`=`True`, returns a tuple + `(loss, pos_loss, neg_loss)`, where `loss` is the loss above; + `pos_loss` is the loss on `pos_logits` only; and + `neg_loss` is the loss on `neg_logits` only. They have + `loss = pos_loss + neg_loss`. + """ + pos_logits = None + if pos_inputs is not None: + pos_logits = clas_fn(pos_inputs) + if isinstance(pos_logits, (list, tuple)): + pos_logits = pos_logits[0] + + neg_logits = None + if neg_inputs is not None: + neg_logits = clas_fn(neg_inputs) + if isinstance(neg_logits, (list, tuple)): + neg_logits = neg_logits[0] + + return binary_sigmoid_cross_entropy( + pos_logits=pos_logits, + neg_logits=neg_logits, + average_across_batch=average_across_batch, + average_across_classes=average_across_classes, + sum_over_batch=sum_over_batch, + sum_over_classes=sum_over_classes, + return_pos_neg_losses=return_pos_neg_losses, + name=name) From e15cf3346659314452673692edc3c7bb0c4410b2 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Fri, 3 May 2019 16:06:44 -0400 Subject: [PATCH 09/24] update --- texar/losses/mle_losses_test.py | 80 +++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 texar/losses/mle_losses_test.py diff --git a/texar/losses/mle_losses_test.py b/texar/losses/mle_losses_test.py new file mode 100644 index 000000000..04826b5a3 --- /dev/null +++ b/texar/losses/mle_losses_test.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# +""" +Unit tests for mle losses. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +# pylint: disable=invalid-name + +import unittest +import torch +import torch.nn.functional as F + +import texar as tx + + +class MLELossesTest(unittest.TestCase): + """Tests mle losses. 
+ """ + + def setUp(self): + self._batch_size = 64 + self._max_time = 16 + self._num_classes = 100 + self._labels = torch.ones(self._batch_size, self._max_time, + dtype=torch.int32) + one_hot_labels = F.one_hot(self._labels, self._num_classes) + self._one_hot_labels = torch.reshape( + one_hot_labels, [self._batch_size, self._max_time, -1]) + self._logits = torch.rand(self._batch_size, self._max_time, + self._num_classes) + self._sequence_length = torch.rand(self._batch_size) * self._max_time + + def _test_sequence_loss(self, loss_fn, labels, logits, sequence_length): + loss = loss_fn(labels, logits, sequence_length) + rank = len(loss.shape) + self.assertEqual(rank, 0) + + loss = loss_fn(labels, logits, sequence_length, + sum_over_timesteps=False) + rank = len(loss.shape) + self.assertEqual(rank, 1) + self.assertEqual(loss.shape, torch.Size([self._max_time])) + + loss = loss_fn( + labels, logits, sequence_length, sum_over_timesteps=False, + average_across_timesteps=True, average_across_batch=False) + rank = len(loss.shape) + self.assertEqual(rank, 1) + self.assertEqual(loss.shape, torch.Size([self._batch_size])) + + loss = loss_fn( + labels, logits, sequence_length, sum_over_timesteps=False, + average_across_batch=False) + rank = len(loss.shape) + self.assertEqual(rank, 2) + self.assertEqual(loss.shape, torch.Size([self._batch_size, + self._max_time])) + + sequence_length_time = torch.rand(self._max_time) * self._max_time + loss = loss_fn( + labels, logits, sequence_length_time, sum_over_timesteps=False, + average_across_batch=False, time_major=True) + self.assertEqual(loss.shape, torch.Size([self._batch_size, + self._max_time])) + + def test_sequence_softmax_cross_entropy(self): + """Tests `sequence_softmax_cross_entropy` + """ + self._test_sequence_loss( + tx.losses.sequence_softmax_cross_entropy, + self._one_hot_labels, self._logits, self._sequence_length) + + +if __name__ == "__main__": + unittest.main() From 1cf83341ea37b09fbb458de9381bbbbc2e3dfbdd Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Mon, 6 May 2019 13:36:13 -0400 Subject: [PATCH 10/24] add more tests in mle_losses_test.py --- texar/losses/mle_losses_test.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/texar/losses/mle_losses_test.py b/texar/losses/mle_losses_test.py index 04826b5a3..b5d32c007 100644 --- a/texar/losses/mle_losses_test.py +++ b/texar/losses/mle_losses_test.py @@ -75,6 +75,33 @@ def test_sequence_softmax_cross_entropy(self): tx.losses.sequence_softmax_cross_entropy, self._one_hot_labels, self._logits, self._sequence_length) + def test_sequence_sparse_softmax_cross_entropy(self): + """Tests `sequence_sparse_softmax_cross_entropy` + """ + self._test_sequence_loss( + tx.losses.sequence_sparse_softmax_cross_entropy, + self._labels, self._logits, self._sequence_length) + + def test_sequence_sigmoid_cross_entropy(self): + """Tests `texar.losses.test_sequence_sigmoid_cross_entropy`. 
+ """ + self._test_sequence_loss( + tx.losses.sequence_sigmoid_cross_entropy, + self._one_hot_labels, self._logits, self._sequence_length) + + self._test_sequence_loss( + tx.losses.sequence_sigmoid_cross_entropy, + self._one_hot_labels[:, :, 0], + self._logits[:, :, 0], + self._sequence_length) + + loss = tx.losses.sequence_sigmoid_cross_entropy( + logits=self._logits[:, :, 0], + labels=np.ones([self._batch_size, self._max_time]), + sequence_length=self._sequence_length) + rank = len(loss.shape) + self.assertEqual(rank, 0) + if __name__ == "__main__": unittest.main() From 46e356c13c0a70a3a6e9b551ef8f2d6c0a525032 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Mon, 6 May 2019 13:57:51 -0400 Subject: [PATCH 11/24] add entropy.py in losses --- texar/losses/entropy.py | 207 ++++++++++++++++++++++++++++++++ texar/losses/mle_losses_test.py | 1 + 2 files changed, 208 insertions(+) create mode 100644 texar/losses/entropy.py diff --git a/texar/losses/entropy.py b/texar/losses/entropy.py new file mode 100644 index 000000000..8dca1717e --- /dev/null +++ b/texar/losses/entropy.py @@ -0,0 +1,207 @@ +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Various entropies. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn.functional as F + +from texar.losses.losses_utils import mask_and_reduce, reduce_dimensions +from texar.utils.shapes import get_rank + +# pylint: disable=too-many-arguments + +__all__ = [ + "entropy_with_logits", + "sequence_entropy_with_logits" +] + + +def _get_entropy(logits): + probs = F.softmax(logits) + 1e-8 + entropy = - probs * torch.log(probs) + entropy = torch.sum(entropy, -1) + return entropy + + +def entropy_with_logits(logits, + rank=None, + average_across_batch=True, + average_across_remaining=False, + sum_over_batch=False, + sum_over_remaining=True): + """Shannon entropy given logits. + + Args: + logits: Unscaled log probabilities of shape + `[batch_size, d_2, ..., d_{rank-1}, distribution_dim]` + and of dtype `float32` or `float64`. + + The rank of the tensor is optionally specified by the argument + :attr:`rank`. + + The tensor is considered as having `[batch_size, .., d_{rank-1}]` + elements, each of which has a distribution of length `d_rank` + (i.e., `distribution_dim`). So the last dimension is always + summed out to compute the entropy. + rank (int, optional): The rank of :attr:`logits`. + If `None` (default), `rank` is inferred automatically from + `logits`. If the inference fails, `rank` is + set to 2, i.e., assuming :attr:`logits` is of shape + `[batch_size, distribution_dim]` + average_across_batch (bool): If set, average the entropy across the + batch dimension. Must not set `average_across_batch`' + and `sum_over_batch` at the same time. + average_across_remaining (bool): If set, average the entropy across the + remaining dimensions. 
Must not set `average_across_remaining`' + and `sum_over_remaining` at the same time. + Used only when :attr:`logits` has rank >= 3. + sum_over_batch (bool): If set, sum the entropy across the + batch dimension. Must not set `average_across_batch` + and `sum_over_batch` at the same time. + sum_over_remaining (bool): If set, sum the entropy across the + remaining dimension. Must not set `average_across_remaining` + and `sum_over_remaining` at the same time. + Used only when :attr:`logits` has rank >= 3. + + Returns: + A Tensor containing the shannon entropy. The dimensionality of the + Tensor depends on the configuration of reduction arguments. For + example, if both batch and remaining dimensions are reduced (by + either sum or average), the returned Tensor is a scalar Tensor. + """ + entropy = _get_entropy(logits) + + if rank is None: + rank = get_rank(logits) + if rank is None: + rank = 2 + rank -= 1 + + if average_across_batch and sum_over_batch: + raise ValueError("Only one of `average_across_batch` and " + "`sum_over_batch` can be set.") + if average_across_remaining and sum_over_remaining: + raise ValueError("Only one of `average_across_remaining` and " + "`sum_over_remaining` can be set.") + sum_axes, average_axes = [], [] + if sum_over_batch: + sum_axes.append(0) + if average_across_batch: + average_axes.append(0) + if sum_over_remaining and rank >= 2: + sum_axes += list(range(1, rank)) + if average_across_remaining and rank >= 2: + average_axes += list(range(1, rank)) + + entropy = reduce_dimensions( + entropy, average_axes=average_axes, sum_axes=sum_axes + ) + + return entropy + + +def sequence_entropy_with_logits(logits, + rank=None, + sequence_length=None, + average_across_batch=True, + average_across_timesteps=False, + average_across_remaining=False, + sum_over_batch=False, + sum_over_timesteps=True, + sum_over_remaining=True, + time_major=False): + """Shannon entropy given logits. + + Args: + logits: Unscaled log probabilities of shape + `[batch_size, max_time, d_3, ..., d_{rank-1}, distribution_dim]` + and of dtype `float32` or `float64`. + + The rank of the tensor is optionally specified by the argument + :attr:`rank`. + + The tensor is considered as having `[batch_size, .., d_{rank-1}]` + elements, each of which has a distribution of length `d_rank` + (i.e., `distribution_dim`). So the last dimension is always + summed out to compute the entropy. + + The batch and time dimensions are exchanged if :attr:`time_major` + is `True`. + rank (int, optional): The rank of :attr:`logits`. + If `None` (default), `rank` is inferred automatically from + `logits`. If the inference fails, `rank` is + set to 3, i.e., assuming `logits` is of shape + `[batch_size, max_time, distribution_dim]` + sequence_length (optional): A Tensor of shape `[batch_size]`. + Time steps beyond the respective sequence lengths are + counted into the entropy. + average_across_timesteps (bool): If set, average the entropy across + the time dimension. Must not set `average_across_timesteps` + and `sum_over_timesteps` at the same time. + average_across_batch (bool): If set, average the entropy across the + batch dimension. Must not set `average_across_batch`' + and `sum_over_batch` at the same time. + average_across_remaining (bool): If set, average the entropy across the + remaining dimensions. Must not set `average_across_remaining`' + and `sum_over_remaining` at the same time. + Used only when :attr:`logits` has rank >= 4. + sum_over_timesteps (bool): If set, sum the entropy across the + time dimension. 
Must not set `average_across_timesteps` + and `sum_over_timesteps` at the same time. + sum_over_batch (bool): If set, sum the entropy across the + batch dimension. Must not set `average_across_batch` + and `sum_over_batch` at the same time. + sum_over_remaining (bool): If set, sum the entropy across the + remaining dimension. Must not set `average_across_remaining` + and `sum_over_remaining` at the same time. + Used only when :attr:`logits` has rank >= 4. + time_major (bool): The shape format of the inputs. If `True`, + :attr:`logits` must have shape `[max_time, batch_size, ...]`. + If `False` (default), it must have shape + `[batch_size, max_time, ...]`. + + Returns: + A Tensor containing the shannon entropy. The dimensionality of the + Tensor depends on the configuration of reduction arguments. For + example, if batch, time, and remaining dimensions are all reduced (by + either sum or average), the returned Tensor is a scalar Tensor. + """ + entropy = _get_entropy(logits) + + if rank is None: + rank = get_rank(logits) + if rank is None: + rank = 3 + rank -= 1 + + entropy = mask_and_reduce( + entropy, + sequence_length, + rank=rank, + average_across_batch=average_across_batch, + average_across_timesteps=average_across_timesteps, + average_across_remaining=average_across_remaining, + sum_over_batch=sum_over_batch, + sum_over_timesteps=sum_over_timesteps, + sum_over_remaining=sum_over_remaining, + time_major=time_major + ) + + return entropy diff --git a/texar/losses/mle_losses_test.py b/texar/losses/mle_losses_test.py index b5d32c007..5fe6a2b56 100644 --- a/texar/losses/mle_losses_test.py +++ b/texar/losses/mle_losses_test.py @@ -15,6 +15,7 @@ import torch import torch.nn.functional as F +import numpy as np import texar as tx From 714cdf8bb4131fee844df712152f34fb22c70285 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Tue, 7 May 2019 15:17:52 -0400 Subject: [PATCH 12/24] add entropy_test --- texar/losses/entropy_test.py | 102 ++++++++++++++++++++++++++++++++ texar/losses/mle_losses_test.py | 6 +- 2 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 texar/losses/entropy_test.py diff --git a/texar/losses/entropy_test.py b/texar/losses/entropy_test.py new file mode 100644 index 000000000..4b27ada54 --- /dev/null +++ b/texar/losses/entropy_test.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +# +""" +Unit tests for entropy. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +# pylint: disable=invalid-name + +import unittest +import torch +import torch.nn.functional as F + +import numpy as np +import texar as tx + + +class EntropyTest(unittest.TestCase): + """Tests entropy. 
+ """ + + def setUp(self): + self._batch_size = 64 + self._max_time = 128 + self._d = 16 + self._distribution_dim = 32 + self._logits = torch.rand(self._batch_size, self._d, + self._distribution_dim) + self._sequence_logits = torch.rand(self._batch_size, + self._max_time, + self._d, + self._distribution_dim) + self._sequence_length = torch.randint(size=(self._batch_size,), + high=self._max_time) + + def _test_entropy(self, entropy_fn, logits, sequence_length=None): + if sequence_length is None: + entropy = entropy_fn(logits) + rank = len(entropy.shape) + self.assertEqual(rank, 0) + + entropy = entropy_fn(logits, average_across_batch=False) + rank = len(entropy.shape) + self.assertEqual(rank, 1) + self.assertEqual(entropy.shape, torch.Size([self._batch_size])) + else: + entropy = entropy_fn(logits, sequence_length=sequence_length) + rank = len(entropy.shape) + self.assertEqual(rank, 0) + + entropy = entropy_fn(logits, sequence_length=sequence_length, + sum_over_timesteps=False) + rank = len(entropy.shape) + self.assertEqual(rank, 1) + self.assertEqual(entropy.shape, torch.Size([self._max_time])) + + entropy = entropy_fn(logits, sequence_length=sequence_length, + sum_over_timesteps=False, + average_across_timesteps=True, + average_across_batch=False) + rank = len(entropy.shape) + self.assertEqual(rank, 1) + self.assertEqual(entropy.shape, torch.Szie([self._batch_size])) + + entropy = entropy_fn(logits, sequence_length=sequence_length, + sum_over_timesteps=False, + average_across_batch=False) + rank = len(entropy.shape) + self.assertEqual(rank, 2) + self.assertEqual(entropy.shape, torch.Szie([self._batch_size, + self._max_time])) + + sequence_length_time = torch.randint(size=(self._max_time,), + high=self._batch_size) + entropy = entropy_fn(logits, + sequence_length=sequence_length_time, + sum_over_timesteps=False, + average_across_batch=False, + time_major=True) + self.assertEqual(entropy.shape, torch.Size([self._batch_size, + self._max_time])) + + def test_entropy_with_logits(self): + """Tests `entropy_with_logits` + """ + self._test_entropy( + tx.losses.entropy_with_logits, self._logits) + + def test_sequence_entropy_with_logits(self): + """Tests `sequence_entropy_with_logits` + """ + self._test_entropy( + tx.losses.sequence_entropy_with_logits, self._sequence_logits, + sequence_length=self._sequence_length) + + +if __name__ == "__main__": + unittest.main() diff --git a/texar/losses/mle_losses_test.py b/texar/losses/mle_losses_test.py index 5fe6a2b56..0f153ffc5 100644 --- a/texar/losses/mle_losses_test.py +++ b/texar/losses/mle_losses_test.py @@ -34,7 +34,8 @@ def setUp(self): one_hot_labels, [self._batch_size, self._max_time, -1]) self._logits = torch.rand(self._batch_size, self._max_time, self._num_classes) - self._sequence_length = torch.rand(self._batch_size) * self._max_time + self._sequence_length = torch.randint(size=(self._batch_size,), + high=self._max_time) def _test_sequence_loss(self, loss_fn, labels, logits, sequence_length): loss = loss_fn(labels, logits, sequence_length) @@ -62,7 +63,8 @@ def _test_sequence_loss(self, loss_fn, labels, logits, sequence_length): self.assertEqual(loss.shape, torch.Size([self._batch_size, self._max_time])) - sequence_length_time = torch.rand(self._max_time) * self._max_time + sequence_length_time = torch.randint(size=[self._max_time], + high=self._batch_size) loss = loss_fn( labels, logits, sequence_length_time, sum_over_timesteps=False, average_across_batch=False, time_major=True) From 2119f844d67aafc949d851f27a41a3c65049d9f1 Mon Sep 17 
00:00:00 2001 From: Pengzhi Gao Date: Tue, 7 May 2019 15:22:23 -0400 Subject: [PATCH 13/24] fix error in mle_losses_test --- texar/losses/mle_losses_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/texar/losses/mle_losses_test.py b/texar/losses/mle_losses_test.py index 0f153ffc5..e3dd211f1 100644 --- a/texar/losses/mle_losses_test.py +++ b/texar/losses/mle_losses_test.py @@ -63,7 +63,7 @@ def _test_sequence_loss(self, loss_fn, labels, logits, sequence_length): self.assertEqual(loss.shape, torch.Size([self._batch_size, self._max_time])) - sequence_length_time = torch.randint(size=[self._max_time], + sequence_length_time = torch.randint(size=(self._max_time,), high=self._batch_size) loss = loss_fn( labels, logits, sequence_length_time, sum_over_timesteps=False, From 5270df7b3a4c451f33fd7bbed76f0d05695ae0ce Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Wed, 8 May 2019 11:04:06 -0400 Subject: [PATCH 14/24] bugfix in entropy --- texar/__init__.py | 1 + texar/losses/__init__.py | 4 ++++ texar/losses/entropy_test.py | 17 ++++++++--------- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/texar/__init__.py b/texar/__init__.py index 681227fcd..9d7d62f7e 100644 --- a/texar/__init__.py +++ b/texar/__init__.py @@ -19,5 +19,6 @@ from texar import core from texar import utils +from texar import losses from texar.hyperparams import * from texar.module_base import * diff --git a/texar/losses/__init__.py b/texar/losses/__init__.py index 26b591f9b..051f87eba 100644 --- a/texar/losses/__init__.py +++ b/texar/losses/__init__.py @@ -20,3 +20,7 @@ from __future__ import print_function # pylint: disable=wildcard-import + +from texar.losses.entropy import * +from texar.losses.mle_losses import * +from texar.losses.losses_utils import * diff --git a/texar/losses/entropy_test.py b/texar/losses/entropy_test.py index 4b27ada54..98f65ced3 100644 --- a/texar/losses/entropy_test.py +++ b/texar/losses/entropy_test.py @@ -13,11 +13,10 @@ import unittest import torch -import torch.nn.functional as F - -import numpy as np import texar as tx +from texar.utils.shapes import get_rank + class EntropyTest(unittest.TestCase): """Tests entropy. 
@@ -40,21 +39,21 @@ def setUp(self): def _test_entropy(self, entropy_fn, logits, sequence_length=None): if sequence_length is None: entropy = entropy_fn(logits) - rank = len(entropy.shape) + rank = entropy.dim() self.assertEqual(rank, 0) entropy = entropy_fn(logits, average_across_batch=False) - rank = len(entropy.shape) + rank = entropy.dim() self.assertEqual(rank, 1) self.assertEqual(entropy.shape, torch.Size([self._batch_size])) else: entropy = entropy_fn(logits, sequence_length=sequence_length) - rank = len(entropy.shape) + rank = entropy.dim() self.assertEqual(rank, 0) entropy = entropy_fn(logits, sequence_length=sequence_length, sum_over_timesteps=False) - rank = len(entropy.shape) + rank = entropy.dim() self.assertEqual(rank, 1) self.assertEqual(entropy.shape, torch.Size([self._max_time])) @@ -62,14 +61,14 @@ def _test_entropy(self, entropy_fn, logits, sequence_length=None): sum_over_timesteps=False, average_across_timesteps=True, average_across_batch=False) - rank = len(entropy.shape) + rank = entropy.dim() self.assertEqual(rank, 1) self.assertEqual(entropy.shape, torch.Szie([self._batch_size])) entropy = entropy_fn(logits, sequence_length=sequence_length, sum_over_timesteps=False, average_across_batch=False) - rank = len(entropy.shape) + rank = entropy.dim() self.assertEqual(rank, 2) self.assertEqual(entropy.shape, torch.Szie([self._batch_size, self._max_time])) From c4289e68ea4d2821f5e258adafd77829f1d60204 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Wed, 8 May 2019 13:54:03 -0400 Subject: [PATCH 15/24] make entropy_test and mle_losses_test run correctly --- texar/losses/entropy.py | 2 +- texar/losses/entropy_test.py | 16 ++++++++-------- texar/losses/losses_utils.py | 4 ++-- texar/losses/mle_losses.py | 11 +++++++---- texar/losses/mle_losses_test.py | 18 +++++++++--------- texar/utils/shapes.py | 2 +- 6 files changed, 28 insertions(+), 25 deletions(-) diff --git a/texar/losses/entropy.py b/texar/losses/entropy.py index 8dca1717e..208966b3c 100644 --- a/texar/losses/entropy.py +++ b/texar/losses/entropy.py @@ -34,7 +34,7 @@ def _get_entropy(logits): - probs = F.softmax(logits) + 1e-8 + probs = F.softmax(logits, -1) + 1e-8 entropy = - probs * torch.log(probs) entropy = torch.sum(entropy, -1) return entropy diff --git a/texar/losses/entropy_test.py b/texar/losses/entropy_test.py index 98f65ced3..59638e8ca 100644 --- a/texar/losses/entropy_test.py +++ b/texar/losses/entropy_test.py @@ -39,21 +39,21 @@ def setUp(self): def _test_entropy(self, entropy_fn, logits, sequence_length=None): if sequence_length is None: entropy = entropy_fn(logits) - rank = entropy.dim() + rank = get_rank(entropy) self.assertEqual(rank, 0) entropy = entropy_fn(logits, average_across_batch=False) - rank = entropy.dim() + rank = get_rank(entropy) self.assertEqual(rank, 1) self.assertEqual(entropy.shape, torch.Size([self._batch_size])) else: entropy = entropy_fn(logits, sequence_length=sequence_length) - rank = entropy.dim() + rank = get_rank(entropy) self.assertEqual(rank, 0) entropy = entropy_fn(logits, sequence_length=sequence_length, sum_over_timesteps=False) - rank = entropy.dim() + rank = get_rank(entropy) self.assertEqual(rank, 1) self.assertEqual(entropy.shape, torch.Size([self._max_time])) @@ -61,16 +61,16 @@ def _test_entropy(self, entropy_fn, logits, sequence_length=None): sum_over_timesteps=False, average_across_timesteps=True, average_across_batch=False) - rank = entropy.dim() + rank = get_rank(entropy) self.assertEqual(rank, 1) - self.assertEqual(entropy.shape, torch.Szie([self._batch_size])) 
+ self.assertEqual(entropy.shape, torch.Size([self._batch_size])) entropy = entropy_fn(logits, sequence_length=sequence_length, sum_over_timesteps=False, average_across_batch=False) - rank = entropy.dim() + rank = get_rank(entropy) self.assertEqual(rank, 2) - self.assertEqual(entropy.shape, torch.Szie([self._batch_size, + self.assertEqual(entropy.shape, torch.Size([self._batch_size, self._max_time])) sequence_length_time = torch.randint(size=(self._max_time,), diff --git a/texar/losses/losses_utils.py b/texar/losses/losses_utils.py index 079e9e437..dea4bc10d 100644 --- a/texar/losses/losses_utils.py +++ b/texar/losses/losses_utils.py @@ -203,6 +203,6 @@ def reduce_dimensions(tensor, average_axes=None, sum_axes=None, keepdims=None): raise ValueError('`average_axes` and `sum_axes` must not ' 'have overlapped elements.') if not keepdims: - tensor = torch.squeeze(tensor, dim=list(reduced_axes)) - + for axis in sorted(list(reduced_axes), reverse=True): + tensor = torch.squeeze(tensor, dim=axis) return tensor diff --git a/texar/losses/mle_losses.py b/texar/losses/mle_losses.py index 9289610e0..9510bf9a6 100644 --- a/texar/losses/mle_losses.py +++ b/texar/losses/mle_losses.py @@ -102,7 +102,7 @@ def sequence_softmax_cross_entropy(labels, if stop_gradient_to_label: labels = labels.detach() - losses = torch.sum(- labels * F.log_softmax(logits, -1), -1) + losses = torch.sum(-labels.type(logits.dtype)*F.log_softmax(logits, -1), -1) losses = mask_and_reduce(losses, sequence_length, @@ -188,7 +188,9 @@ def sequence_sparse_softmax_cross_entropy(labels, sequence_length=data_batch['length']-1) """ - losses = F.nll_loss(F.log_softmax(logits, dim=1), labels) + logits = F.log_softmax(logits, dim=2) + logits = logits.permute(0, 2, 1) + losses = F.nll_loss(logits, labels, reduction='none') losses = mask_and_reduce(losses, sequence_length, @@ -274,8 +276,9 @@ class dimension. Must not set `average_across_classes` """ if stop_gradient_to_label: labels = labels.detach() - losses = torch.nn.BCEWithLogitsLoss(reduction=None) - losses = losses(logits, labels) + losses = torch.nn.BCEWithLogitsLoss(reduction='none') + + losses = losses(logits, labels.type(logits.dtype)) rank = shapes.get_rank(logits) or shapes.get_rank(labels) if rank is None: diff --git a/texar/losses/mle_losses_test.py b/texar/losses/mle_losses_test.py index e3dd211f1..736af4085 100644 --- a/texar/losses/mle_losses_test.py +++ b/texar/losses/mle_losses_test.py @@ -14,10 +14,10 @@ import unittest import torch import torch.nn.functional as F - -import numpy as np import texar as tx +from texar.utils.shapes import get_rank + class MLELossesTest(unittest.TestCase): """Tests mle losses. 
@@ -28,7 +28,7 @@ def setUp(self): self._max_time = 16 self._num_classes = 100 self._labels = torch.ones(self._batch_size, self._max_time, - dtype=torch.int32) + dtype=torch.int64) one_hot_labels = F.one_hot(self._labels, self._num_classes) self._one_hot_labels = torch.reshape( one_hot_labels, [self._batch_size, self._max_time, -1]) @@ -39,26 +39,26 @@ def setUp(self): def _test_sequence_loss(self, loss_fn, labels, logits, sequence_length): loss = loss_fn(labels, logits, sequence_length) - rank = len(loss.shape) + rank = get_rank(loss) self.assertEqual(rank, 0) loss = loss_fn(labels, logits, sequence_length, sum_over_timesteps=False) - rank = len(loss.shape) + rank = get_rank(loss) self.assertEqual(rank, 1) self.assertEqual(loss.shape, torch.Size([self._max_time])) loss = loss_fn( labels, logits, sequence_length, sum_over_timesteps=False, average_across_timesteps=True, average_across_batch=False) - rank = len(loss.shape) + rank = get_rank(loss) self.assertEqual(rank, 1) self.assertEqual(loss.shape, torch.Size([self._batch_size])) loss = loss_fn( labels, logits, sequence_length, sum_over_timesteps=False, average_across_batch=False) - rank = len(loss.shape) + rank = get_rank(loss) self.assertEqual(rank, 2) self.assertEqual(loss.shape, torch.Size([self._batch_size, self._max_time])) @@ -100,9 +100,9 @@ def test_sequence_sigmoid_cross_entropy(self): loss = tx.losses.sequence_sigmoid_cross_entropy( logits=self._logits[:, :, 0], - labels=np.ones([self._batch_size, self._max_time]), + labels=torch.ones([self._batch_size, self._max_time]), sequence_length=self._sequence_length) - rank = len(loss.shape) + rank = get_rank(loss) self.assertEqual(rank, 0) diff --git a/texar/utils/shapes.py b/texar/utils/shapes.py index 4e7b1d665..f2e87b8d9 100644 --- a/texar/utils/shapes.py +++ b/texar/utils/shapes.py @@ -68,7 +68,7 @@ def get_rank(tensor: torch.Tensor) -> int: `None` if the rank cannot be determined. """ if torch.is_tensor(tensor): - rank = len(tensor.dim()) + rank = tensor.dim() else: array = np.asarray(tensor) rank = array.ndim From a8395dbc7ecae682f687395c3a392850ba3a29be Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Wed, 8 May 2019 15:36:57 -0400 Subject: [PATCH 16/24] add adv_losses and adv_losses_test --- texar/losses/__init__.py | 3 +- texar/losses/adv_losses.py | 81 +++++++++++++++++++++++++++++++++ texar/losses/adv_losses_test.py | 44 ++++++++++++++++++ 3 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 texar/losses/adv_losses.py create mode 100644 texar/losses/adv_losses_test.py diff --git a/texar/losses/__init__.py b/texar/losses/__init__.py index 051f87eba..1f4a52335 100644 --- a/texar/losses/__init__.py +++ b/texar/losses/__init__.py @@ -21,6 +21,7 @@ # pylint: disable=wildcard-import +from texar.losses.losses_utils import * from texar.losses.entropy import * from texar.losses.mle_losses import * -from texar.losses.losses_utils import * +from texar.losses.adv_losses import * diff --git a/texar/losses/adv_losses.py b/texar/losses/adv_losses.py new file mode 100644 index 000000000..4894f3c6c --- /dev/null +++ b/texar/losses/adv_losses.py @@ -0,0 +1,81 @@ +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Adversarial losses.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch
+
+
+def binary_adversarial_losses(real_data,
+                              fake_data,
+                              discriminator_fn,
+                              mode="max_real"):
+    """Computes adversarial losses of real/fake binary discrimination game.
+
+    .. role:: python(code)
+       :language: python
+
+    Args:
+        real_data (Tensor or array): Real data of shape
+            `[num_real_examples, ...]`.
+        fake_data (Tensor or array): Fake data of shape
+            `[num_fake_examples, ...]`. `num_real_examples` does not
+            necessarily equal `num_fake_examples`.
+        discriminator_fn: A callable that takes data (e.g., :attr:`real_data`
+            and :attr:`fake_data`) and returns the logits of being real. The
+            signature of `discriminator_fn` must be:
+            :python:`logits, ... = discriminator_fn(data)`.
+            The return value of `discriminator_fn` can be the logits, or
+            a tuple where the logits are the first element.
+
+        mode (str): Mode of the generator loss. Either "max_real" or
+            "min_fake".
+
+            - **"max_real"** (default): minimizing the generator loss\
+            maximizes the probability of fake data being classified as real.
+
+            - **"min_fake"**: minimizing the generator loss minimizes the\
+            probability of fake data being classified as fake.
+
+    Returns:
+        A tuple `(generator_loss, discriminator_loss)`, each a scalar Tensor
+        to be minimized.
+    """
+    real_logits = discriminator_fn(real_data)
+    if isinstance(real_logits, (list, tuple)):
+        real_logits = real_logits[0]
+    real_loss = torch.nn.BCEWithLogitsLoss(reduction='mean')
+    real_loss = real_loss(real_logits, torch.ones_like(real_logits))
+
+    fake_logits = discriminator_fn(fake_data)
+    if isinstance(fake_logits, (list, tuple)):
+        fake_logits = fake_logits[0]
+    fake_loss = torch.nn.BCEWithLogitsLoss(reduction='mean')
+    fake_loss = fake_loss(fake_logits, torch.zeros_like(fake_logits))
+
+    d_loss = real_loss + fake_loss
+
+    if mode == "min_fake":
+        g_loss = -fake_loss
+    elif mode == "max_real":
+        g_loss = torch.nn.BCEWithLogitsLoss(reduction='mean')
+        g_loss = g_loss(fake_logits, torch.ones_like(fake_logits))
+    else:
+        raise ValueError("Unknown mode: %s. Only 'min_fake' and 'max_real' "
+                         "are allowed." % mode)
+
+    return g_loss, d_loss
diff --git a/texar/losses/adv_losses_test.py b/texar/losses/adv_losses_test.py
new file mode 100644
index 000000000..93f287c23
--- /dev/null
+++ b/texar/losses/adv_losses_test.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+#
+"""
+Unit tests for adv_losses.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+# pylint: disable=invalid-name
+
+import unittest
+import torch
+import texar as tx
+
+
+class AdvLossesTest(unittest.TestCase):
+    """Tests adversarial losses.
+    """
+
+    def test_binary_adversarial_losses(self):
+        """Tests :meth:`~texar.losses.adv_losses.binary_adversarial_losses`.
+ """ + batch_size = 16 + data_dim = 64 + real_data = torch.zeros(size=(batch_size, data_dim), + dtype=torch.float32) + fake_data = torch.ones(size=(batch_size, data_dim), + dtype=torch.float32) + const_logits = torch.zeros(size=(batch_size,), dtype=torch.float32) + # Use a dumb discriminator that always outputs logits=0. + gen_loss, disc_loss = tx.losses.binary_adversarial_losses( + real_data, fake_data, lambda x: const_logits) + gen_loss_2, disc_loss_2 = tx.losses.binary_adversarial_losses( + real_data, fake_data, lambda x: const_logits, mode="min_fake") + + self.assertAlmostEqual(gen_loss.item(), -gen_loss_2.item()) + self.assertAlmostEqual(disc_loss.item(), disc_loss_2.item()) + + +if __name__ == "__main__": + unittest.main() From 012bc8abc46150cb0136b1c4459b4acbb869a092 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Wed, 8 May 2019 16:37:02 -0400 Subject: [PATCH 17/24] add pg_loss_with_log_probs --- texar/losses/pg_losses.py | 243 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 texar/losses/pg_losses.py diff --git a/texar/losses/pg_losses.py b/texar/losses/pg_losses.py new file mode 100644 index 000000000..8b97823e1 --- /dev/null +++ b/texar/losses/pg_losses.py @@ -0,0 +1,243 @@ +# Copyright 2018 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Various loss functions for policy gradients. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch + +from texar.losses.losses_utils import mask_and_reduce +from texar.utils.shapes import get_rank + +# pylint: disable=too-many-arguments, protected-access + +__all__ = [ + "pg_loss_with_logits", + "pg_loss_with_log_probs" +] + + +def pg_loss_with_logits(actions, + logits, + advantages, + rank=None, + batched=False, + sequence_length=None, + average_across_batch=True, + average_across_timesteps=False, + average_across_remaining=False, + sum_over_batch=False, + sum_over_timesteps=True, + sum_over_remaining=True, + time_major=False): + """Policy gradient loss with logits. Used for discrete actions. + + `pg_loss = reduce( advantages * -log_prob( actions ) )`, + where `advantages` and `actions` do not back-propagate gradients. + + All arguments except :attr:`logits` and :attr:`actions` are the same with + :func:`pg_loss_with_log_probs`. + + Args: + actions: Tensor of shape + `[(batch_size,) max_time, d_3, ..., d_rank]` and of dtype + `int32` or `int64`. + The rank of the Tensor is specified with :attr:`rank`. + + The batch dimension exists only if :attr:`batched` is `True`. + + The batch and time dimensions + are exchanged, i.e., `[max_time, batch_size, ...]` if + :attr:`time_major` is `True`. + logits: Unscaled log probabilities of shape + `[(batch_size,) max_time, d_3, ..., d_{rank+1}]` + and dtype `float32` or `float64`. + The batch and time dimensions are exchanged if `time_major` + is `True`. 
+        advantages: Tensor of shape
+            `[(batch_size,) max_time, d_3, ..., d_rank]` and
+            dtype `float32` or `float64`.
+            The batch and time dimensions are exchanged if `time_major`
+            is `True`.
+        rank (int, optional): The rank of :attr:`actions`.
+            If `None` (default), rank is automatically inferred from
+            `actions` or `advantages`. If the inference fails,
+            `rank` is set to 1 if :attr:`batched` is `False`,
+            and set to 2 if :attr:`batched` is `True`.
+        batched (bool): `True` if the inputs are batched.
+        sequence_length (optional): A Tensor of shape `[batch_size]`.
+            Time steps beyond the respective sequence lengths will have zero
+            losses. Used if :attr:`batched` is `True`.
+        average_across_timesteps (bool): If set, average the loss across
+            the time dimension. Must not set `average_across_timesteps`
+            and `sum_over_timesteps` at the same time.
+        average_across_batch (bool): If set, average the loss across the
+            batch dimension. Must not set `average_across_batch`
+            and `sum_over_batch` at the same time.
+            Ignored if `batched` is `False`.
+        average_across_remaining (bool): If set, average the loss across the
+            remaining dimensions. Must not set `average_across_remaining`
+            and `sum_over_remaining` at the same time. Ignored if there are
+            no dimensions other than the batch and time dimensions.
+        sum_over_timesteps (bool): If set, sum the loss across the
+            time dimension. Must not set `average_across_timesteps`
+            and `sum_over_timesteps` at the same time.
+        sum_over_batch (bool): If set, sum the loss across the
+            batch dimension. Must not set `average_across_batch`
+            and `sum_over_batch` at the same time.
+            Ignored if `batched` is `False`.
+        sum_over_remaining (bool): If set, sum the loss across the
+            remaining dimensions. Must not set `average_across_remaining`
+            and `sum_over_remaining` at the same time. Ignored if there are
+            no dimensions other than the batch and time dimensions.
+        time_major (bool): The shape format of the inputs. If `True`,
+            :attr:`logits`, :attr:`actions` and :attr:`advantages` must
+            have shape `[max_time, batch_size, ...]`. If `False` (default),
+            they must have shape `[batch_size, max_time, ...]`.
+            Ignored if `batched` is `False`.
+
+    Returns:
+        A Tensor containing the loss to minimize, whose rank depends on the
+        reduce arguments. For example, the batch dimension is reduced if
+        either :attr:`average_across_batch` or :attr:`sum_over_batch` is
+        `True`, which decreases the rank of output tensor by 1.
+    """
+    return
+
+
+def pg_loss_with_log_probs(log_probs,
+                           advantages,
+                           rank=None,
+                           batched=False,
+                           sequence_length=None,
+                           average_across_batch=True,
+                           average_across_timesteps=False,
+                           average_across_remaining=False,
+                           sum_over_batch=False,
+                           sum_over_timesteps=True,
+                           sum_over_remaining=True,
+                           time_major=False):
+    """Policy gradient loss with log probs of actions.
+
+    `pg_loss = reduce( advantages * -log_probs )`,
+    where `advantages` does not back-propagate gradients.
+
+    All arguments except :attr:`log_probs` are the same as
+    :func:`pg_loss_with_logits`.
+
+    Args:
+        log_probs: Log probabilities of shape
+            `[(batch_size,) max_time, ..., d_rank]` and dtype `float32`
+            or `float64`. The rank of the Tensor is specified
+            with :attr:`rank`.
+
+            The batch dimension exists only if :attr:`batched` is `True`.
+
+            The batch and time dimensions are exchanged, i.e.,
+            `[max_time, batch_size, ...]` if :attr:`time_major` is `True`.
+        advantages: Tensor of shape
+            `[(batch_size,) max_time, d_3, ..., d_rank]` and
+            dtype `float32` or `float64`.
+            The batch dimension exists only if `batched` is `True`.
+            The batch and time dimensions
+            are exchanged if `time_major` is `True`.
+        rank (int, optional): The rank of :attr:`log_probs`.
+            If `None` (default), rank is automatically inferred from
+            `log_probs` or `advantages`. If the inference fails,
+            `rank` is set to 1 if :attr:`batched` is `False`,
+            and set to 2 if :attr:`batched` is `True`.
+        batched (bool): `True` if the inputs are batched.
+        sequence_length (optional): A Tensor of shape `[batch_size]`.
+            Time steps beyond the respective sequence lengths will have zero
+            losses. Used if :attr:`batched` is `True`.
+        average_across_timesteps (bool): If set, average the loss across
+            the time dimension. Must not set `average_across_timesteps`
+            and `sum_over_timesteps` at the same time.
+        average_across_batch (bool): If set, average the loss across the
+            batch dimension. Must not set `average_across_batch`
+            and `sum_over_batch` at the same time.
+            Ignored if `batched` is `False`.
+        average_across_remaining (bool): If set, average the loss across the
+            remaining dimensions. Must not set `average_across_remaining`
+            and `sum_over_remaining` at the same time. Ignored if there are
+            no dimensions other than the batch and time dimensions.
+        sum_over_timesteps (bool): If set, sum the loss across the
+            time dimension. Must not set `average_across_timesteps`
+            and `sum_over_timesteps` at the same time.
+        sum_over_batch (bool): If set, sum the loss across the
+            batch dimension. Must not set `average_across_batch`
+            and `sum_over_batch` at the same time.
+            Ignored if `batched` is `False`.
+        sum_over_remaining (bool): If set, sum the loss across the
+            remaining dimensions. Must not set `average_across_remaining`
+            and `sum_over_remaining` at the same time. Ignored if there are
+            no dimensions other than the batch and time dimensions.
+        time_major (bool): The shape format of the inputs. If `True`,
+            :attr:`log_probs` and :attr:`advantages` must have shape
+            `[max_time, batch_size, ...]`. If `False` (default),
+            they must have shape `[batch_size, max_time, ...]`.
+            Ignored if :attr:`batched` is `False`.
+
+    Returns:
+        A Tensor containing the loss to minimize, whose rank depends on the
+        reduce arguments. For example, the batch dimension is reduced if
+        either :attr:`average_across_batch` or :attr:`sum_over_batch` is
+        `True`, which decreases the rank of output tensor by 1.
+ """ + advantages = advantages.detach() + + losses = -log_probs * advantages + + if rank is None: + rank = get_rank(log_probs) or get_rank(advantages) + if rank is None: + rank = 2 if batched else 1 + + if batched: + losses = mask_and_reduce( + losses, + sequence_length, + rank=rank, + average_across_batch=average_across_batch, + average_across_timesteps=average_across_timesteps, + average_across_remaining=average_across_remaining, + sum_over_batch=sum_over_batch, + sum_over_timesteps=sum_over_timesteps, + sum_over_remaining=sum_over_remaining, + time_major=time_major) + elif rank > 1: + if average_across_remaining and sum_over_remaining: + raise ValueError("Only one of `average_across_remaining` and " + "`sum_over_remaining` can be set.") + if average_across_remaining: + for average_axis in range(1, rank): + losses = torch.mean(losses, dim=average_axis) + elif sum_over_remaining: + for sum_axis in range(1, rank): + losses = torch.sum(losses, dim=sum_axis) + + if not batched: + if average_across_timesteps and sum_over_timesteps: + raise ValueError("Only one of `average_across_timesteps` and " + "`sum_over_timesteps` can be set.") + if average_across_timesteps: + losses = torch.mean(losses, dim=0) + elif sum_over_timesteps: + losses = torch.sum(losses, dim=0) + + return losses From c1a3b19b8efdf2d2794c479fd6b88c3c1afd991c Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Fri, 10 May 2019 17:26:04 -0400 Subject: [PATCH 18/24] add pg_losses --- texar/losses/pg_losses.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/texar/losses/pg_losses.py b/texar/losses/pg_losses.py index 8b97823e1..0e6f1fcc4 100644 --- a/texar/losses/pg_losses.py +++ b/texar/losses/pg_losses.py @@ -20,6 +20,7 @@ from __future__ import print_function import torch +import torch.nn.functional as F from texar.losses.losses_utils import mask_and_reduce from texar.utils.shapes import get_rank @@ -117,7 +118,24 @@ def pg_loss_with_logits(actions, either :attr:`average_across_batch` or :attr:`sum_over_batch` is `True`, which decreases the rank of output tensor by 1. 
""" - return + actions = actions.detach() + logits = F.log_softmax(logits, dim=-1) + logits = logits.permute([0, -1] + list(range(1, logits.dim()-1))) + neg_log_probs = F.nll_loss(logits, actions, reduction='none') + + return pg_loss_with_log_probs( + log_probs=-neg_log_probs, + advantages=advantages, + rank=rank, + batched=batched, + sequence_length=sequence_length, + average_across_batch=average_across_batch, + average_across_timesteps=average_across_timesteps, + average_across_remaining=average_across_remaining, + sum_over_batch=sum_over_batch, + sum_over_timesteps=sum_over_timesteps, + sum_over_remaining=sum_over_remaining, + time_major=time_major) def pg_loss_with_log_probs(log_probs, From b970e2d4d6d25899fa0dce6dc88d99e24f642fc2 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Mon, 13 May 2019 15:21:02 -0400 Subject: [PATCH 19/24] fix bug in pg_loss --- texar/losses/__init__.py | 1 + texar/losses/losses_utils.py | 4 +- texar/losses/mle_losses_test.py | 2 +- texar/losses/pg_losses.py | 4 +- texar/losses/pg_losses_test.py | 121 ++++++++++++++++++++++++++++++++ 5 files changed, 127 insertions(+), 5 deletions(-) create mode 100644 texar/losses/pg_losses_test.py diff --git a/texar/losses/__init__.py b/texar/losses/__init__.py index 1f4a52335..ae0ceb1fa 100644 --- a/texar/losses/__init__.py +++ b/texar/losses/__init__.py @@ -25,3 +25,4 @@ from texar.losses.entropy import * from texar.losses.mle_losses import * from texar.losses.adv_losses import * +from texar.losses.pg_losses import * diff --git a/texar/losses/losses_utils.py b/texar/losses/losses_utils.py index dea4bc10d..690f13570 100644 --- a/texar/losses/losses_utils.py +++ b/texar/losses/losses_utils.py @@ -108,10 +108,10 @@ def mask_and_reduce(sequence, raise ValueError("Only one of `average_across_remaining` and " "`sum_over_remaining` can be set.") if average_across_remaining: - for axis in range(2, rank): + for axis in sorted(list(range(2, rank)), reverse=True): sequence = torch.mean(sequence, dim=axis) elif sum_over_remaining: - for axis in range(2, rank): + for axis in sorted(list(range(2, rank)), reverse=True): sequence = torch.sum(sequence, dim=axis) sequence = reduce_batch_time(sequence, diff --git a/texar/losses/mle_losses_test.py b/texar/losses/mle_losses_test.py index 736af4085..0c5b71b09 100644 --- a/texar/losses/mle_losses_test.py +++ b/texar/losses/mle_losses_test.py @@ -86,7 +86,7 @@ def test_sequence_sparse_softmax_cross_entropy(self): self._labels, self._logits, self._sequence_length) def test_sequence_sigmoid_cross_entropy(self): - """Tests `texar.losses.test_sequence_sigmoid_cross_entropy`. + """Tests `texar.losses.sequence_sigmoid_cross_entropy`. 
""" self._test_sequence_loss( tx.losses.sequence_sigmoid_cross_entropy, diff --git a/texar/losses/pg_losses.py b/texar/losses/pg_losses.py index 0e6f1fcc4..7a2928b16 100644 --- a/texar/losses/pg_losses.py +++ b/texar/losses/pg_losses.py @@ -243,10 +243,10 @@ def pg_loss_with_log_probs(log_probs, raise ValueError("Only one of `average_across_remaining` and " "`sum_over_remaining` can be set.") if average_across_remaining: - for average_axis in range(1, rank): + for average_axis in sorted(list(range(1, rank)), reverse=True): losses = torch.mean(losses, dim=average_axis) elif sum_over_remaining: - for sum_axis in range(1, rank): + for sum_axis in sorted(list(range(1, rank)), reverse=True): losses = torch.sum(losses, dim=sum_axis) if not batched: diff --git a/texar/losses/pg_losses_test.py b/texar/losses/pg_losses_test.py new file mode 100644 index 000000000..1a2a468d4 --- /dev/null +++ b/texar/losses/pg_losses_test.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +# +""" +Unit tests for pg losses. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +# pylint: disable=invalid-name + +import unittest +import torch +import texar as tx + +from texar.utils.shapes import get_rank + + +class PGLossesTest(unittest.TestCase): + """Tests pg losses + """ + + def setUp(self): + self._batch_size = 64 + self._max_time = 16 + self._d1 = 32 + self._d2 = 32 + self._d3 = 32 + self._num_classes = 10 + self._actions_batch = torch.ones(self._batch_size, self._max_time, + self._d1, self._d2, self._d3, + dtype=torch.int64) + self._logits_batch = torch.rand(self._batch_size, self._max_time, + self._d1, self._d2, self._d3, + self._num_classes) + self._advantages_batch = torch.rand(self._batch_size, self._max_time, + self._d1, self._d2, self._d3) + self._actions_no_batch = torch.ones(self._max_time, self._d1, self._d2, + self._d3, dtype=torch.int64) + self._logits_no_batch = torch.rand(self._max_time, self._d1, self._d2, + self._d3, self._num_classes) + self._advantages_no_batch = torch.rand(self._max_time, self._d1, + self._d2, self._d3) + self._sequence_length = torch.randint(size=(self._batch_size,), + high=self._max_time) + + def _test_sequence_loss(self, loss_fn, actions, logits, advantages, batched, + sequence_length): + loss = loss_fn(actions, logits, advantages, batched=batched, + sequence_length=sequence_length) + rank = get_rank(loss) + self.assertEqual(rank, 0) + + loss = loss_fn(actions, logits, advantages, batched=batched, + sequence_length=sequence_length, + sum_over_timesteps=False) + rank = get_rank(loss) + self.assertEqual(rank, 1) + self.assertEqual(loss.shape, torch.Size([self._max_time])) + + loss = loss_fn(actions, logits, advantages, batched=batched, + sequence_length=sequence_length, + sum_over_timesteps=False, + average_across_timesteps=True, + average_across_batch=False) + rank = get_rank(loss) + if batched: + self.assertEqual(rank, 1) + self.assertEqual(loss.shape, torch.Size([self._batch_size])) + else: + self.assertEqual(rank, 0) + + loss = loss_fn(actions, logits, advantages, batched=batched, + sequence_length=sequence_length, + sum_over_timesteps=False, + average_across_batch=False) + rank = get_rank(loss) + if batched: + self.assertEqual(rank, 2) + self.assertEqual(loss.shape, + torch.Size([self._batch_size, self._max_time])) + else: + self.assertEqual(rank, 1) + self.assertEqual(loss.shape, + torch.Size([self._max_time])) + + sequence_length_time = 
torch.randint(size=(self._max_time,), + high=self._batch_size) + loss = loss_fn(actions, logits, advantages, batched=batched, + sequence_length=sequence_length_time, + sum_over_timesteps=False, + average_across_batch=False, + time_major=True) + if batched: + self.assertEqual(loss.shape, torch.Size([self._batch_size, + self._max_time])) + else: + self.assertEqual(loss.shape, torch.Size([self._max_time])) + + def test_pg_loss_with_logits(self): + """Tests `texar.losses.pg_loss_with_logits`. + """ + self._test_sequence_loss(tx.losses.pg_loss_with_logits, + self._actions_batch, + self._logits_batch, + self._advantages_batch, + True, + self._sequence_length) + + self._test_sequence_loss(tx.losses.pg_loss_with_logits, + self._actions_no_batch, + self._logits_no_batch, + self._advantages_no_batch, + False, + self._sequence_length) + + +if __name__ == "__main__": + unittest.main() From 0db4ad471a937905ad08257fde0ea3eab9ab91ce Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Thu, 16 May 2019 11:03:39 -0400 Subject: [PATCH 20/24] resolve comments --- texar/losses/__init__.py | 6 +----- texar/losses/adv_losses.py | 14 ++++++-------- texar/losses/adv_losses_test.py | 5 ----- texar/losses/entropy.py | 6 +----- texar/losses/entropy_test.py | 5 ----- texar/losses/losses_utils.py | 7 +------ texar/losses/mle_losses.py | 27 +++++++-------------------- texar/losses/mle_losses_test.py | 5 ----- texar/losses/pg_losses.py | 6 +----- texar/losses/pg_losses_test.py | 5 ----- 10 files changed, 17 insertions(+), 69 deletions(-) diff --git a/texar/losses/__init__.py b/texar/losses/__init__.py index ae0ceb1fa..91ecc96b3 100644 --- a/texar/losses/__init__.py +++ b/texar/losses/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Texar Authors. All Rights Reserved. +# Copyright 2019 The Texar Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +15,6 @@ Modules of texar losses. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - # pylint: disable=wildcard-import from texar.losses.losses_utils import * diff --git a/texar/losses/adv_losses.py b/texar/losses/adv_losses.py index 4894f3c6c..d16e52002 100644 --- a/texar/losses/adv_losses.py +++ b/texar/losses/adv_losses.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Texar Authors. All Rights Reserved. +# Copyright 2019 The Texar Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +14,9 @@ """ Adversarial losses. 
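# A hypothetical call of pg_loss_with_logits mirroring the test above,
# with smaller made-up sizes:
import torch
import texar as tx

actions = torch.randint(10, (4, 8))  # [batch, time], int64
logits = torch.randn(4, 8, 10)       # [batch, time, classes]
advantages = torch.rand(4, 8)
lengths = torch.tensor([8, 5, 6, 2])
loss = tx.losses.pg_loss_with_logits(actions, logits, advantages,
                                     batched=True, sequence_length=lengths)
assert loss.dim() == 0               # defaults: average batch, sum time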
""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function import torch +import torch.nn.functional as F def binary_adversarial_losses(real_data, @@ -58,14 +56,14 @@ def binary_adversarial_losses(real_data, real_logits = discriminator_fn(real_data) if isinstance(real_logits, (list, tuple)): real_logits = real_logits[0] - real_loss = torch.nn.BCEWithLogitsLoss(reduction='mean') - real_loss = real_loss(real_logits, torch.ones_like(real_logits)) + real_loss = F.binary_cross_entropy_with_logits( + real_logits, torch.ones_like(real_logits)) fake_logits = discriminator_fn(fake_data) if isinstance(fake_logits, (list, tuple)): fake_logits = fake_logits[0] - fake_loss = torch.nn.BCEWithLogitsLoss(reduction='mean') - fake_loss = fake_loss(fake_logits, torch.zeros_like(fake_logits)) + fake_loss = F.binary_cross_entropy_with_logits( + fake_logits, torch.zeros_like(fake_logits)) d_loss = real_loss + fake_loss diff --git a/texar/losses/adv_losses_test.py b/texar/losses/adv_losses_test.py index 93f287c23..0165784a8 100644 --- a/texar/losses/adv_losses_test.py +++ b/texar/losses/adv_losses_test.py @@ -4,11 +4,6 @@ Unit tests for adv_losses. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - # pylint: disable=invalid-name import unittest diff --git a/texar/losses/entropy.py b/texar/losses/entropy.py index 208966b3c..852a8e475 100644 --- a/texar/losses/entropy.py +++ b/texar/losses/entropy.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Texar Authors. All Rights Reserved. +# Copyright 2019 The Texar Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +15,6 @@ Various entropies. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import torch import torch.nn.functional as F diff --git a/texar/losses/entropy_test.py b/texar/losses/entropy_test.py index 59638e8ca..7add6feb2 100644 --- a/texar/losses/entropy_test.py +++ b/texar/losses/entropy_test.py @@ -4,11 +4,6 @@ Unit tests for entropy. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - # pylint: disable=invalid-name import unittest diff --git a/texar/losses/losses_utils.py b/texar/losses/losses_utils.py index 690f13570..703e13a70 100644 --- a/texar/losses/losses_utils.py +++ b/texar/losses/losses_utils.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Texar Authors. All Rights Reserved. +# Copyright 2019 The Texar Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,11 +15,6 @@ Various utilities for losses. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - import torch from texar.utils.shapes import transpose_batch_time, mask_sequences diff --git a/texar/losses/mle_losses.py b/texar/losses/mle_losses.py index 9510bf9a6..17ed1220a 100644 --- a/texar/losses/mle_losses.py +++ b/texar/losses/mle_losses.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Texar Authors. All Rights Reserved. +# Copyright 2019 The Texar Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +15,6 @@ Various losses """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import torch import torch.nn.functional as F @@ -45,8 +41,7 @@ def sequence_softmax_cross_entropy(labels, sum_over_batch=False, sum_over_timesteps=True, time_major=False, - stop_gradient_to_label=False, - name=None): + stop_gradient_to_label=False): """Computes softmax cross entropy for each time step of sequence predictions. @@ -122,8 +117,7 @@ def sequence_sparse_softmax_cross_entropy(labels, average_across_timesteps=False, sum_over_batch=False, sum_over_timesteps=True, - time_major=False, - name=None): + time_major=False): """Computes sparse softmax cross entropy for each time step of sequence predictions. @@ -213,8 +207,7 @@ def sequence_sigmoid_cross_entropy(labels, sum_over_timesteps=True, sum_over_classes=False, time_major=False, - stop_gradient_to_label=False, - name=None): + stop_gradient_to_label=False): """Computes sigmoid cross entropy for each time step of sequence predictions. @@ -281,9 +274,6 @@ class dimension. Must not set `average_across_classes` losses = losses(logits, labels.type(logits.dtype)) rank = shapes.get_rank(logits) or shapes.get_rank(labels) - if rank is None: - raise ValueError( - 'Cannot determine the rank of `logits` or `labels`.') losses = mask_and_reduce(losses, sequence_length, @@ -305,8 +295,7 @@ def binary_sigmoid_cross_entropy(pos_logits=None, average_across_classes=True, sum_over_batch=False, sum_over_classes=False, - return_pos_neg_losses=False, - name=None): + return_pos_neg_losses=False): """Computes sigmoid cross entropy of binary predictions. Args: @@ -384,8 +373,7 @@ def binary_sigmoid_cross_entropy_with_clas(clas_fn, average_across_classes=True, sum_over_batch=False, sum_over_classes=False, - return_pos_neg_losses=False, - name=None): + return_pos_neg_losses=False): """Computes sigmoid cross entropy of binary classifier. .. role:: python(code) @@ -454,5 +442,4 @@ class dimension. Must not set `average_across_classes` average_across_classes=average_across_classes, sum_over_batch=sum_over_batch, sum_over_classes=sum_over_classes, - return_pos_neg_losses=return_pos_neg_losses, - name=name) + return_pos_neg_losses=return_pos_neg_losses) diff --git a/texar/losses/mle_losses_test.py b/texar/losses/mle_losses_test.py index 0c5b71b09..30aeca484 100644 --- a/texar/losses/mle_losses_test.py +++ b/texar/losses/mle_losses_test.py @@ -4,11 +4,6 @@ Unit tests for mle losses. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - # pylint: disable=invalid-name import unittest diff --git a/texar/losses/pg_losses.py b/texar/losses/pg_losses.py index 7a2928b16..c6a193558 100644 --- a/texar/losses/pg_losses.py +++ b/texar/losses/pg_losses.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Texar Authors. All Rights Reserved. +# Copyright 2019 The Texar Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +15,6 @@ Various loss functions for policy gradients. 
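# A hypothetical use of binary_sigmoid_cross_entropy_with_clas after the
# cleanup above (and once the reduction='none' fix from the next patch is
# in); the toy classifier and sizes are made up. The classifier returns
# [batch, 1] logits so that both the batch and class dimensions exist to
# be averaged.
import torch
import texar as tx

clas_fn = lambda x: x.sum(dim=1, keepdim=True)
pos_inputs = torch.randn(8, 4)
neg_inputs = torch.randn(8, 4)
loss = tx.losses.binary_sigmoid_cross_entropy_with_clas(
    clas_fn, pos_inputs=pos_inputs, neg_inputs=neg_inputs)
assert loss.dim() == 0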
""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import torch import torch.nn.functional as F diff --git a/texar/losses/pg_losses_test.py b/texar/losses/pg_losses_test.py index 1a2a468d4..c93e8c6ae 100644 --- a/texar/losses/pg_losses_test.py +++ b/texar/losses/pg_losses_test.py @@ -4,11 +4,6 @@ Unit tests for pg losses. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - # pylint: disable=invalid-name import unittest From c9db6283a69436973f4f1b941af0eb1bb311e833 Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Thu, 16 May 2019 11:18:40 -0400 Subject: [PATCH 21/24] replace BCEWithLogitsLoss --- texar/losses/mle_losses.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/texar/losses/mle_losses.py b/texar/losses/mle_losses.py index 17ed1220a..a5d9a9a6d 100644 --- a/texar/losses/mle_losses.py +++ b/texar/losses/mle_losses.py @@ -269,9 +269,8 @@ class dimension. Must not set `average_across_classes` """ if stop_gradient_to_label: labels = labels.detach() - losses = torch.nn.BCEWithLogitsLoss(reduction='none') - - losses = losses(logits, labels.type(logits.dtype)) + losses = F.binary_cross_entropy_with_logits( + logits, labels.type(logits.dtype), reduction='none') rank = shapes.get_rank(logits) or shapes.get_rank(labels) @@ -346,15 +345,15 @@ class dimension. Must not set `average_across_classes` pos_loss = 0 if pos_logits is not None: - pos_loss = torch.nn.BCEWithLogitsLoss(reduction=None) - pos_loss = pos_loss(pos_logits, torch.ones_like(pos_logits)) + pos_loss = F.binary_cross_entropy_with_logits( + pos_logits, torch.ones_like(pos_logits), reduction='none') pos_loss = reduce_dimensions(pos_loss, average_axes, sum_axes) neg_loss = 0 if neg_logits is not None: - neg_loss = torch.nn.BCEWithLogitsLoss(reduction=None) - neg_loss = neg_loss(neg_logits, torch.zeros_like(neg_logits)) + neg_loss = F.binary_cross_entropy_with_logits( + neg_logits, torch.zeros_like(neg_logits), reduction='none') neg_loss = reduce_dimensions(neg_loss, average_axes, sum_axes) From 010821bb6d37d1f55525d56a6d3b6c693ab6b5ea Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Thu, 16 May 2019 16:37:41 -0400 Subject: [PATCH 22/24] add type annotations --- texar/losses/adv_losses.py | 14 ++++-- texar/losses/entropy.py | 34 ++++++------- texar/losses/losses_utils.py | 41 +++++++++------- texar/losses/mle_losses.py | 93 +++++++++++++++++++----------------- texar/losses/pg_losses.py | 52 ++++++++++---------- 5 files changed, 127 insertions(+), 107 deletions(-) diff --git a/texar/losses/adv_losses.py b/texar/losses/adv_losses.py index d16e52002..7a0fa4d24 100644 --- a/texar/losses/adv_losses.py +++ b/texar/losses/adv_losses.py @@ -18,11 +18,17 @@ import torch import torch.nn.functional as F +from typing import List, Union, Callable, Tuple -def binary_adversarial_losses(real_data, - fake_data, - discriminator_fn, - mode="max_real"): + +def binary_adversarial_losses( + real_data: Union[torch.Tensor, List[Union[int, float]]], + fake_data: Union[torch.Tensor, List[Union[int, float]]], + discriminator_fn: Callable[[Union[torch.Tensor, + List[Union[int, float]]]], + Union[torch.Tensor, List[torch.Tensor], + Tuple[torch.Tensor]]], + mode: str = "max_real") -> Tuple[torch.Tensor, torch.Tensor]: """Computes adversarial losses of real/fake binary discrimination game. .. 
role:: python(code) diff --git a/texar/losses/entropy.py b/texar/losses/entropy.py index 852a8e475..48f1180f0 100644 --- a/texar/losses/entropy.py +++ b/texar/losses/entropy.py @@ -29,19 +29,19 @@ ] -def _get_entropy(logits): +def _get_entropy(logits: torch.Tensor) -> torch.Tensor: probs = F.softmax(logits, -1) + 1e-8 entropy = - probs * torch.log(probs) entropy = torch.sum(entropy, -1) return entropy -def entropy_with_logits(logits, - rank=None, - average_across_batch=True, - average_across_remaining=False, - sum_over_batch=False, - sum_over_remaining=True): +def entropy_with_logits(logits: torch.Tensor, + rank: int = None, + average_across_batch: bool = True, + average_across_remaining: bool = False, + sum_over_batch: bool = False, + sum_over_remaining: bool = True) -> torch.Tensor: """Shannon entropy given logits. Args: @@ -113,16 +113,16 @@ def entropy_with_logits(logits, return entropy -def sequence_entropy_with_logits(logits, - rank=None, - sequence_length=None, - average_across_batch=True, - average_across_timesteps=False, - average_across_remaining=False, - sum_over_batch=False, - sum_over_timesteps=True, - sum_over_remaining=True, - time_major=False): +def sequence_entropy_with_logits(logits: torch.Tensor, + rank: int = None, + sequence_length: torch.Tensor = None, + average_across_batch: bool = True, + average_across_timesteps: bool = False, + average_across_remaining: bool = False, + sum_over_batch: bool = False, + sum_over_timesteps: bool = True, + sum_over_remaining: bool = True, + time_major: bool = False) -> torch.Tensor: """Shannon entropy given logits. Args: diff --git a/texar/losses/losses_utils.py b/texar/losses/losses_utils.py index 703e13a70..121d6683f 100644 --- a/texar/losses/losses_utils.py +++ b/texar/losses/losses_utils.py @@ -19,6 +19,8 @@ from texar.utils.shapes import transpose_batch_time, mask_sequences +from typing import Optional, Union, List + # pylint: disable=invalid-name, not-context-manager, protected-access, # pylint: disable=too-many-arguments @@ -29,17 +31,17 @@ ] -def mask_and_reduce(sequence, - sequence_length, - rank=2, - average_across_batch=True, - average_across_timesteps=False, - average_across_remaining=False, - sum_over_batch=False, - sum_over_timesteps=True, - sum_over_remaining=True, - dtype=None, - time_major=False): +def mask_and_reduce(sequence: torch.Tensor, + sequence_length: torch.Tensor, + rank: int = 2, + average_across_batch: bool = True, + average_across_timesteps: bool = False, + average_across_remaining: bool = False, + sum_over_batch: bool = False, + sum_over_timesteps: bool = True, + sum_over_remaining: bool = True, + dtype: Optional[torch.Tensor] = None, + time_major: bool = False) -> torch.Tensor: """Masks out sequence entries that are beyond the respective sequence lengths, and reduces (average or sum) away dimensions. @@ -124,12 +126,12 @@ def mask_and_reduce(sequence, return sequence -def reduce_batch_time(sequence, - sequence_length, - average_across_batch=True, - average_across_timesteps=False, - sum_over_batch=False, - sum_over_timesteps=True): +def reduce_batch_time(sequence: torch.Tensor, + sequence_length: torch.Tensor, + average_across_batch: bool = True, + average_across_timesteps: bool = False, + sum_over_batch: bool = False, + sum_over_timesteps: bool = True) -> torch.Tensor: """Average or sum over the respective dimensions of :attr:`sequence`, which is of shape `[batch_size, max_time, ...]`. 
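# Sketch of how the reduction flags in the annotated signature above
# interact (sizes made up): the defaults yield a scalar, while disabling
# them keeps the masked [batch, time] matrix.
import torch
from texar.losses.losses_utils import mask_and_reduce

losses = torch.rand(4, 6)                 # [batch, time]
lengths = torch.tensor([6, 3, 5, 2])
scalar = mask_and_reduce(losses, lengths)
matrix = mask_and_reduce(losses, lengths,
                         average_across_batch=False,
                         sum_over_timesteps=False)
assert scalar.dim() == 0 and matrix.shape == (4, 6)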
@@ -160,7 +162,10 @@ def reduce_batch_time(sequence, return sequence -def reduce_dimensions(tensor, average_axes=None, sum_axes=None, keepdims=None): +def reduce_dimensions(tensor: torch.Tensor, + average_axes: Optional[Union[int, List[int]]] = None, + sum_axes: Optional[Union[int, List[int]]] = None, + keepdims: Optional[bool] = None) -> torch.Tensor: """Average or sum over dimensions of :attr:`tensor`. :attr:`average_axes` and :attr:`sum_axes` must be mutually exclusive. That diff --git a/texar/losses/mle_losses.py b/texar/losses/mle_losses.py index a5d9a9a6d..b2c4a5602 100644 --- a/texar/losses/mle_losses.py +++ b/texar/losses/mle_losses.py @@ -21,6 +21,8 @@ from texar.losses.losses_utils import mask_and_reduce, reduce_dimensions from texar.utils import shapes +from typing import Optional, Callable, Union, Tuple, Any + # pylint: disable=invalid-name, not-context-manager, protected-access, # pylint: disable=too-many-arguments @@ -33,15 +35,16 @@ ] -def sequence_softmax_cross_entropy(labels, - logits, - sequence_length, - average_across_batch=True, - average_across_timesteps=False, - sum_over_batch=False, - sum_over_timesteps=True, - time_major=False, - stop_gradient_to_label=False): +def sequence_softmax_cross_entropy( + labels: torch.Tensor, + logits: torch.Tensor, + sequence_length: torch.Tensor, + average_across_batch: bool = True, + average_across_timesteps: bool = False, + sum_over_batch: bool = False, + sum_over_timesteps: bool = True, + time_major: bool = False, + stop_gradient_to_label: bool = False) -> torch.Tensor: """Computes softmax cross entropy for each time step of sequence predictions. @@ -110,14 +113,15 @@ def sequence_softmax_cross_entropy(labels, return losses -def sequence_sparse_softmax_cross_entropy(labels, - logits, - sequence_length, - average_across_batch=True, - average_across_timesteps=False, - sum_over_batch=False, - sum_over_timesteps=True, - time_major=False): +def sequence_sparse_softmax_cross_entropy( + labels: torch.Tensor, + logits: torch.Tensor, + sequence_length: torch.Tensor, + average_across_batch: bool = True, + average_across_timesteps: bool = False, + sum_over_batch: bool = False, + sum_over_timesteps: bool = True, + time_major: bool = False) -> torch.Tensor: """Computes sparse softmax cross entropy for each time step of sequence predictions. @@ -197,17 +201,18 @@ def sequence_sparse_softmax_cross_entropy(labels, return losses -def sequence_sigmoid_cross_entropy(labels, - logits, - sequence_length, - average_across_batch=True, - average_across_timesteps=False, - average_across_classes=True, - sum_over_batch=False, - sum_over_timesteps=True, - sum_over_classes=False, - time_major=False, - stop_gradient_to_label=False): +def sequence_sigmoid_cross_entropy( + labels: torch.Tensor, + logits: torch.Tensor, + sequence_length: torch.Tensor, + average_across_batch: bool = True, + average_across_timesteps: bool = False, + average_across_classes: bool = True, + sum_over_batch: bool = False, + sum_over_timesteps: bool = True, + sum_over_classes: bool = False, + time_major: bool = False, + stop_gradient_to_label: bool = False) -> torch.Tensor: """Computes sigmoid cross entropy for each time step of sequence predictions. @@ -288,13 +293,14 @@ class dimension. 
Must not set `average_across_classes` return losses -def binary_sigmoid_cross_entropy(pos_logits=None, - neg_logits=None, - average_across_batch=True, - average_across_classes=True, - sum_over_batch=False, - sum_over_classes=False, - return_pos_neg_losses=False): +def binary_sigmoid_cross_entropy( + pos_logits: Optional[torch.Tensor] = None, + neg_logits: Optional[torch.Tensor] = None, + average_across_batch: bool = True, + average_across_classes: bool = True, + sum_over_batch: bool = False, + sum_over_classes: bool = False, + return_pos_neg_losses: bool = False) -> torch.Tensor: """Computes sigmoid cross entropy of binary predictions. Args: @@ -365,14 +371,15 @@ class dimension. Must not set `average_across_classes` return loss -def binary_sigmoid_cross_entropy_with_clas(clas_fn, - pos_inputs=None, - neg_inputs=None, - average_across_batch=True, - average_across_classes=True, - sum_over_batch=False, - sum_over_classes=False, - return_pos_neg_losses=False): +def binary_sigmoid_cross_entropy_with_clas( + clas_fn: Callable[[Any], Union[torch.Tensor, Tuple[torch.Tensor]]], + pos_inputs: Any = None, + neg_inputs: Any = None, + average_across_batch: bool = True, + average_across_classes: bool = True, + sum_over_batch: bool = False, + sum_over_classes: bool = False, + return_pos_neg_losses: bool = False) -> torch.Tensor: """Computes sigmoid cross entropy of binary classifier. .. role:: python(code) diff --git a/texar/losses/pg_losses.py b/texar/losses/pg_losses.py index c6a193558..96953b840 100644 --- a/texar/losses/pg_losses.py +++ b/texar/losses/pg_losses.py @@ -21,6 +21,8 @@ from texar.losses.losses_utils import mask_and_reduce from texar.utils.shapes import get_rank +from typing import Optional + # pylint: disable=too-many-arguments, protected-access __all__ = [ @@ -29,19 +31,19 @@ ] -def pg_loss_with_logits(actions, - logits, - advantages, - rank=None, - batched=False, - sequence_length=None, - average_across_batch=True, - average_across_timesteps=False, - average_across_remaining=False, - sum_over_batch=False, - sum_over_timesteps=True, - sum_over_remaining=True, - time_major=False): +def pg_loss_with_logits(actions: torch.Tensor, + logits: torch.Tensor, + advantages: torch.Tensor, + rank: Optional[int] = None, + batched: bool = False, + sequence_length: Optional[torch.Tensor] = None, + average_across_batch: bool = True, + average_across_timesteps: bool = False, + average_across_remaining: bool = False, + sum_over_batch: bool = False, + sum_over_timesteps: bool = True, + sum_over_remaining: bool = True, + time_major: bool = False) -> torch.Tensor: """Policy gradient loss with logits. Used for discrete actions. 
`pg_loss = reduce( advantages * -log_prob( actions ) )`, @@ -134,18 +136,18 @@ def pg_loss_with_logits(actions, time_major=time_major) -def pg_loss_with_log_probs(log_probs, - advantages, - rank=None, - batched=False, - sequence_length=None, - average_across_batch=True, - average_across_timesteps=False, - average_across_remaining=False, - sum_over_batch=False, - sum_over_timesteps=True, - sum_over_remaining=True, - time_major=False): +def pg_loss_with_log_probs(log_probs: torch.Tensor, + advantages: torch.Tensor, + rank: Optional[int] = None, + batched: bool = False, + sequence_length: Optional[torch.Tensor] = None, + average_across_batch: bool = True, + average_across_timesteps: bool = False, + average_across_remaining: bool = False, + sum_over_batch: bool = False, + sum_over_timesteps: bool = True, + sum_over_remaining: bool = True, + time_major: bool = False) -> torch.Tensor: """Policy gradient loss with log probs of actions. `pg_loss = reduce( advantages * -log_probs )`, From 58283f866cb269b8a3b81b8c7ecffadaedc2e81a Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Thu, 16 May 2019 16:44:04 -0400 Subject: [PATCH 23/24] fix type annotation --- texar/losses/entropy.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/texar/losses/entropy.py b/texar/losses/entropy.py index 48f1180f0..744c63080 100644 --- a/texar/losses/entropy.py +++ b/texar/losses/entropy.py @@ -21,6 +21,8 @@ from texar.losses.losses_utils import mask_and_reduce, reduce_dimensions from texar.utils.shapes import get_rank +from typing import Optional + # pylint: disable=too-many-arguments __all__ = [ @@ -37,7 +39,7 @@ def _get_entropy(logits: torch.Tensor) -> torch.Tensor: def entropy_with_logits(logits: torch.Tensor, - rank: int = None, + rank: Optional[int] = None, average_across_batch: bool = True, average_across_remaining: bool = False, sum_over_batch: bool = False, @@ -114,7 +116,7 @@ def entropy_with_logits(logits: torch.Tensor, def sequence_entropy_with_logits(logits: torch.Tensor, - rank: int = None, + rank: Optional[int] = None, sequence_length: torch.Tensor = None, average_across_batch: bool = True, average_across_timesteps: bool = False, From a048a68639ecb43724ae7136529c877c08023471 Mon Sep 17 00:00:00 2001 From: Zecong Hu Date: Fri, 17 May 2019 10:04:29 +0800 Subject: [PATCH 24/24] Fix type-checking issues --- texar/losses/adv_losses.py | 13 ++++---- texar/losses/adv_losses_test.py | 2 ++ texar/losses/entropy.py | 25 ++++++++------- texar/losses/losses_utils.py | 21 ++++++------ texar/losses/mle_losses.py | 57 ++++++++++++++++----------------- texar/losses/mle_losses_test.py | 3 +- texar/losses/pg_losses.py | 8 ++--- texar/losses/pg_losses_test.py | 9 +++--- 8 files changed, 70 insertions(+), 68 deletions(-) diff --git a/texar/losses/adv_losses.py b/texar/losses/adv_losses.py index 7a0fa4d24..6e1c3587f 100644 --- a/texar/losses/adv_losses.py +++ b/texar/losses/adv_losses.py @@ -15,19 +15,18 @@ Adversarial losses. 
""" +from typing import Callable, Tuple + import torch import torch.nn.functional as F -from typing import List, Union, Callable, Tuple +from texar.utils.types import MaybeTuple def binary_adversarial_losses( - real_data: Union[torch.Tensor, List[Union[int, float]]], - fake_data: Union[torch.Tensor, List[Union[int, float]]], - discriminator_fn: Callable[[Union[torch.Tensor, - List[Union[int, float]]]], - Union[torch.Tensor, List[torch.Tensor], - Tuple[torch.Tensor]]], + real_data: torch.Tensor, + fake_data: torch.Tensor, + discriminator_fn: Callable[[torch.Tensor], MaybeTuple[torch.Tensor]], mode: str = "max_real") -> Tuple[torch.Tensor, torch.Tensor]: """Computes adversarial losses of real/fake binary discrimination game. diff --git a/texar/losses/adv_losses_test.py b/texar/losses/adv_losses_test.py index 0165784a8..1cbceede5 100644 --- a/texar/losses/adv_losses_test.py +++ b/texar/losses/adv_losses_test.py @@ -7,7 +7,9 @@ # pylint: disable=invalid-name import unittest + import torch + import texar as tx diff --git a/texar/losses/entropy.py b/texar/losses/entropy.py index 744c63080..1cc71b4ea 100644 --- a/texar/losses/entropy.py +++ b/texar/losses/entropy.py @@ -15,14 +15,14 @@ Various entropies. """ +from typing import Optional + import torch import torch.nn.functional as F from texar.losses.losses_utils import mask_and_reduce, reduce_dimensions from texar.utils.shapes import get_rank -from typing import Optional - # pylint: disable=too-many-arguments __all__ = [ @@ -115,16 +115,17 @@ def entropy_with_logits(logits: torch.Tensor, return entropy -def sequence_entropy_with_logits(logits: torch.Tensor, - rank: Optional[int] = None, - sequence_length: torch.Tensor = None, - average_across_batch: bool = True, - average_across_timesteps: bool = False, - average_across_remaining: bool = False, - sum_over_batch: bool = False, - sum_over_timesteps: bool = True, - sum_over_remaining: bool = True, - time_major: bool = False) -> torch.Tensor: +def sequence_entropy_with_logits( + logits: torch.Tensor, + rank: Optional[int] = None, + sequence_length: Optional[torch.LongTensor] = None, + average_across_batch: bool = True, + average_across_timesteps: bool = False, + average_across_remaining: bool = False, + sum_over_batch: bool = False, + sum_over_timesteps: bool = True, + sum_over_remaining: bool = True, + time_major: bool = False) -> torch.Tensor: """Shannon entropy given logits. Args: diff --git a/texar/losses/losses_utils.py b/texar/losses/losses_utils.py index 121d6683f..e48cd34c7 100644 --- a/texar/losses/losses_utils.py +++ b/texar/losses/losses_utils.py @@ -15,14 +15,15 @@ Various utilities for losses. 
""" -import torch +# pylint: disable=invalid-name, not-context-manager, protected-access, +# pylint: disable=too-many-arguments -from texar.utils.shapes import transpose_batch_time, mask_sequences +from typing import Optional -from typing import Optional, Union, List +import torch -# pylint: disable=invalid-name, not-context-manager, protected-access, -# pylint: disable=too-many-arguments +from texar.utils.shapes import mask_sequences, transpose_batch_time +from texar.utils.types import MaybeList __all__ = [ "mask_and_reduce", @@ -32,7 +33,7 @@ def mask_and_reduce(sequence: torch.Tensor, - sequence_length: torch.Tensor, + sequence_length: Optional[torch.LongTensor], rank: int = 2, average_across_batch: bool = True, average_across_timesteps: bool = False, @@ -40,7 +41,7 @@ def mask_and_reduce(sequence: torch.Tensor, sum_over_batch: bool = False, sum_over_timesteps: bool = True, sum_over_remaining: bool = True, - dtype: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, time_major: bool = False) -> torch.Tensor: """Masks out sequence entries that are beyond the respective sequence lengths, and reduces (average or sum) away dimensions. @@ -127,7 +128,7 @@ def mask_and_reduce(sequence: torch.Tensor, def reduce_batch_time(sequence: torch.Tensor, - sequence_length: torch.Tensor, + sequence_length: Optional[torch.LongTensor], average_across_batch: bool = True, average_across_timesteps: bool = False, sum_over_batch: bool = False, @@ -163,8 +164,8 @@ def reduce_batch_time(sequence: torch.Tensor, def reduce_dimensions(tensor: torch.Tensor, - average_axes: Optional[Union[int, List[int]]] = None, - sum_axes: Optional[Union[int, List[int]]] = None, + average_axes: Optional[MaybeList[int]] = None, + sum_axes: Optional[MaybeList[int]] = None, keepdims: Optional[bool] = None) -> torch.Tensor: """Average or sum over dimensions of :attr:`tensor`. diff --git a/texar/losses/mle_losses.py b/texar/losses/mle_losses.py index b2c4a5602..3b094b266 100644 --- a/texar/losses/mle_losses.py +++ b/texar/losses/mle_losses.py @@ -15,16 +15,17 @@ Various losses """ +# pylint: disable=invalid-name, not-context-manager, protected-access, +# pylint: disable=too-many-arguments + +from typing import Callable, Optional, Tuple, Union + import torch import torch.nn.functional as F from texar.losses.losses_utils import mask_and_reduce, reduce_dimensions from texar.utils import shapes - -from typing import Optional, Callable, Union, Tuple, Any - -# pylint: disable=invalid-name, not-context-manager, protected-access, -# pylint: disable=too-many-arguments +from texar.utils.types import MaybeTuple __all__ = [ "sequence_softmax_cross_entropy", @@ -38,7 +39,7 @@ def sequence_softmax_cross_entropy( labels: torch.Tensor, logits: torch.Tensor, - sequence_length: torch.Tensor, + sequence_length: torch.LongTensor, average_across_batch: bool = True, average_across_timesteps: bool = False, sum_over_batch: bool = False, @@ -84,7 +85,6 @@ def sequence_softmax_cross_entropy( (default), they must have shape `[batch_size, max_time, ...]`. stop_gradient_to_label (bool): If set, gradient propagation to :attr:`labels` will be disabled. - name (str, optional): A name for the operation. 
Returns: A Tensor containing the loss, of rank 0, 1, or 2 depending on the @@ -100,7 +100,7 @@ def sequence_softmax_cross_entropy( if stop_gradient_to_label: labels = labels.detach() - losses = torch.sum(-labels.type(logits.dtype)*F.log_softmax(logits, -1), -1) + losses = torch.sum(-labels.type(logits.dtype) * F.log_softmax(logits, -1), -1) losses = mask_and_reduce(losses, sequence_length, @@ -116,7 +116,7 @@ def sequence_softmax_cross_entropy( def sequence_sparse_softmax_cross_entropy( labels: torch.Tensor, logits: torch.Tensor, - sequence_length: torch.Tensor, + sequence_length: torch.LongTensor, average_across_batch: bool = True, average_across_timesteps: bool = False, sum_over_batch: bool = False, @@ -156,7 +156,6 @@ def sequence_sparse_softmax_cross_entropy( :attr:`labels` and :attr:`logits` must have shape `[max_time, batch_size, ...]`. If `False` (default), they must have shape `[batch_size, max_time, ...]`. - name (str, optional): A name for the operation. Returns: A Tensor containing the loss, of rank 0, 1, or 2 depending on the @@ -204,7 +203,7 @@ def sequence_sparse_softmax_cross_entropy( def sequence_sigmoid_cross_entropy( labels: torch.Tensor, logits: torch.Tensor, - sequence_length: torch.Tensor, + sequence_length: torch.LongTensor, average_across_batch: bool = True, average_across_timesteps: bool = False, average_across_classes: bool = True, @@ -258,7 +257,6 @@ class dimension. Must not set `average_across_classes` (default), they must have shape `[batch_size, max_time, ...]`. stop_gradient_to_label (bool): If set, gradient propagation to :attr:`labels` will be disabled. - name (str, optional): A name for the operation. Returns: A Tensor containing the loss, of rank 0, 1, or 2 depending on the @@ -300,7 +298,8 @@ def binary_sigmoid_cross_entropy( average_across_classes: bool = True, sum_over_batch: bool = False, sum_over_classes: bool = False, - return_pos_neg_losses: bool = False) -> torch.Tensor: + return_pos_neg_losses: bool = False) \ + -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: """Computes sigmoid cross entropy of binary predictions. Args: @@ -324,7 +323,6 @@ class dimension. Must not set `average_across_classes` :attr:`logits` is a 2D Tensor. return_pos_neg_losses (bool): If set, additionally returns the losses on :attr:`pos_logits` and :attr:`neg_logits`, respectively. - name (str, optional): A name for the operation. Returns: By default, a Tensor containing the loss, of rank 0, 1, or 2 depending @@ -343,25 +341,26 @@ class dimension. Must not set `average_across_classes` `neg_loss` is the loss on `neg_logits` only. They have `loss = pos_loss + neg_loss`. """ - average_axes, sum_axes = [], [] - average_axes += [0] if average_across_batch else [] + average_axes = [0] if average_across_batch else [] average_axes += [1] if average_across_classes else [] - sum_axes += [0] if sum_over_batch else [] + sum_axes = [0] if sum_over_batch else [] sum_axes += [1] if sum_over_classes else [] - pos_loss = 0 if pos_logits is not None: pos_loss = F.binary_cross_entropy_with_logits( pos_logits, torch.ones_like(pos_logits), reduction='none') pos_loss = reduce_dimensions(pos_loss, average_axes, sum_axes) + else: + pos_loss = 0 # type: ignore - neg_loss = 0 if neg_logits is not None: neg_loss = F.binary_cross_entropy_with_logits( neg_logits, torch.zeros_like(neg_logits), reduction='none') neg_loss = reduce_dimensions(neg_loss, average_axes, sum_axes) + else: + neg_loss = 0 # type: ignore loss = pos_loss + neg_loss @@ -372,14 +371,15 @@ class dimension. 
Must not set `average_across_classes` def binary_sigmoid_cross_entropy_with_clas( - clas_fn: Callable[[Any], Union[torch.Tensor, Tuple[torch.Tensor]]], - pos_inputs: Any = None, - neg_inputs: Any = None, + clas_fn: Callable[[torch.Tensor], MaybeTuple[torch.Tensor]], + pos_inputs: Optional[torch.Tensor] = None, + neg_inputs: Optional[torch.Tensor] = None, average_across_batch: bool = True, average_across_classes: bool = True, sum_over_batch: bool = False, sum_over_classes: bool = False, - return_pos_neg_losses: bool = False) -> torch.Tensor: + return_pos_neg_losses: bool = False) \ + -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: """Computes sigmoid cross entropy of binary classifier. .. role:: python(code) @@ -410,7 +410,6 @@ class dimension. Must not set `average_across_classes` :attr:`logits` is a 2D Tensor. return_pos_neg_losses (bool): If set, additionally returns the losses on :attr:`pos_logits` and :attr:`neg_logits`, respectively. - name (str, optional): A name for the operation. Returns: By default, a Tensor containing the loss, of rank 0, 1, or 2 depending @@ -431,15 +430,13 @@ class dimension. Must not set `average_across_classes` """ pos_logits = None if pos_inputs is not None: - pos_logits = clas_fn(pos_inputs) - if isinstance(pos_logits, (list, tuple)): - pos_logits = pos_logits[0] + out = clas_fn(pos_inputs) + pos_logits = out[0] if isinstance(out, (list, tuple)) else out neg_logits = None if neg_inputs is not None: - neg_logits = clas_fn(neg_inputs) - if isinstance(neg_logits, (list, tuple)): - neg_logits = neg_logits[0] + out = clas_fn(neg_inputs) + neg_logits = out[0] if isinstance(out, (list, tuple)) else out return binary_sigmoid_cross_entropy( pos_logits=pos_logits, diff --git a/texar/losses/mle_losses_test.py b/texar/losses/mle_losses_test.py index 30aeca484..aef0c7575 100644 --- a/texar/losses/mle_losses_test.py +++ b/texar/losses/mle_losses_test.py @@ -7,10 +7,11 @@ # pylint: disable=invalid-name import unittest + import torch import torch.nn.functional as F -import texar as tx +import texar as tx from texar.utils.shapes import get_rank diff --git a/texar/losses/pg_losses.py b/texar/losses/pg_losses.py index 96953b840..7ed118c6b 100644 --- a/texar/losses/pg_losses.py +++ b/texar/losses/pg_losses.py @@ -15,14 +15,14 @@ Various loss functions for policy gradients. 
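# Sketch of the `out[0] if isinstance(out, (list, tuple)) else out`
# unpacking used above (assumed MaybeTuple semantics: a bare tensor, or a
# tuple whose first element holds the logits):
import torch

def first_logits(out):
    return out[0] if isinstance(out, (list, tuple)) else out

assert first_logits(torch.ones(3)).shape == (3,)
assert first_logits((torch.ones(3), "extra state")).shape == (3,)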
""" +from typing import Optional + import torch import torch.nn.functional as F from texar.losses.losses_utils import mask_and_reduce from texar.utils.shapes import get_rank -from typing import Optional - # pylint: disable=too-many-arguments, protected-access __all__ = [ @@ -36,7 +36,7 @@ def pg_loss_with_logits(actions: torch.Tensor, advantages: torch.Tensor, rank: Optional[int] = None, batched: bool = False, - sequence_length: Optional[torch.Tensor] = None, + sequence_length: Optional[torch.LongTensor] = None, average_across_batch: bool = True, average_across_timesteps: bool = False, average_across_remaining: bool = False, @@ -140,7 +140,7 @@ def pg_loss_with_log_probs(log_probs: torch.Tensor, advantages: torch.Tensor, rank: Optional[int] = None, batched: bool = False, - sequence_length: Optional[torch.Tensor] = None, + sequence_length: Optional[torch.LongTensor] = None, average_across_batch: bool = True, average_across_timesteps: bool = False, average_across_remaining: bool = False, diff --git a/texar/losses/pg_losses_test.py b/texar/losses/pg_losses_test.py index c93e8c6ae..038476ffc 100644 --- a/texar/losses/pg_losses_test.py +++ b/texar/losses/pg_losses_test.py @@ -7,9 +7,10 @@ # pylint: disable=invalid-name import unittest + import torch -import texar as tx +import texar as tx from texar.utils.shapes import get_rank @@ -20,9 +21,9 @@ class PGLossesTest(unittest.TestCase): def setUp(self): self._batch_size = 64 self._max_time = 16 - self._d1 = 32 - self._d2 = 32 - self._d3 = 32 + self._d1 = 3 # use smaller values to speedup testing + self._d2 = 4 + self._d3 = 5 self._num_classes = 10 self._actions_batch = torch.ones(self._batch_size, self._max_time, self._d1, self._d2, self._d3,