Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gpengzhi committed May 16, 2019
1 parent b970e2d commit 0db4ad4
Show file tree
Hide file tree
Showing 10 changed files with 17 additions and 69 deletions.
6 changes: 1 addition & 5 deletions texar/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,10 +15,6 @@
Modules of texar losses.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=wildcard-import

from texar.losses.losses_utils import *
Expand Down
14 changes: 6 additions & 8 deletions texar/losses/adv_losses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,11 +14,9 @@
"""
Adversarial losses.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn.functional as F


def binary_adversarial_losses(real_data,
Expand Down Expand Up @@ -58,14 +56,14 @@ def binary_adversarial_losses(real_data,
real_logits = discriminator_fn(real_data)
if isinstance(real_logits, (list, tuple)):
real_logits = real_logits[0]
real_loss = torch.nn.BCEWithLogitsLoss(reduction='mean')
real_loss = real_loss(real_logits, torch.ones_like(real_logits))
real_loss = F.binary_cross_entropy_with_logits(
real_logits, torch.ones_like(real_logits))

fake_logits = discriminator_fn(fake_data)
if isinstance(fake_logits, (list, tuple)):
fake_logits = fake_logits[0]
fake_loss = torch.nn.BCEWithLogitsLoss(reduction='mean')
fake_loss = fake_loss(fake_logits, torch.zeros_like(fake_logits))
fake_loss = F.binary_cross_entropy_with_logits(
fake_logits, torch.zeros_like(fake_logits))

d_loss = real_loss + fake_loss

Expand Down
5 changes: 0 additions & 5 deletions texar/losses/adv_losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@
Unit tests for adv_losses.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

# pylint: disable=invalid-name

import unittest
Expand Down
6 changes: 1 addition & 5 deletions texar/losses/entropy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,10 +15,6 @@
Various entropies.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn.functional as F

Expand Down
5 changes: 0 additions & 5 deletions texar/losses/entropy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@
Unit tests for entropy.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

# pylint: disable=invalid-name

import unittest
Expand Down
7 changes: 1 addition & 6 deletions texar/losses/losses_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,11 +15,6 @@
Various utilities for losses.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import torch

from texar.utils.shapes import transpose_batch_time, mask_sequences
Expand Down
27 changes: 7 additions & 20 deletions texar/losses/mle_losses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,10 +15,6 @@
Various losses
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn.functional as F

Expand All @@ -45,8 +41,7 @@ def sequence_softmax_cross_entropy(labels,
sum_over_batch=False,
sum_over_timesteps=True,
time_major=False,
stop_gradient_to_label=False,
name=None):
stop_gradient_to_label=False):
"""Computes softmax cross entropy for each time step of sequence
predictions.
Expand Down Expand Up @@ -122,8 +117,7 @@ def sequence_sparse_softmax_cross_entropy(labels,
average_across_timesteps=False,
sum_over_batch=False,
sum_over_timesteps=True,
time_major=False,
name=None):
time_major=False):
"""Computes sparse softmax cross entropy for each time step of sequence
predictions.
Expand Down Expand Up @@ -213,8 +207,7 @@ def sequence_sigmoid_cross_entropy(labels,
sum_over_timesteps=True,
sum_over_classes=False,
time_major=False,
stop_gradient_to_label=False,
name=None):
stop_gradient_to_label=False):
"""Computes sigmoid cross entropy for each time step of sequence
predictions.
Expand Down Expand Up @@ -281,9 +274,6 @@ class dimension. Must not set `average_across_classes`
losses = losses(logits, labels.type(logits.dtype))

rank = shapes.get_rank(logits) or shapes.get_rank(labels)
if rank is None:
raise ValueError(
'Cannot determine the rank of `logits` or `labels`.')

losses = mask_and_reduce(losses,
sequence_length,
Expand All @@ -305,8 +295,7 @@ def binary_sigmoid_cross_entropy(pos_logits=None,
average_across_classes=True,
sum_over_batch=False,
sum_over_classes=False,
return_pos_neg_losses=False,
name=None):
return_pos_neg_losses=False):
"""Computes sigmoid cross entropy of binary predictions.
Args:
Expand Down Expand Up @@ -384,8 +373,7 @@ def binary_sigmoid_cross_entropy_with_clas(clas_fn,
average_across_classes=True,
sum_over_batch=False,
sum_over_classes=False,
return_pos_neg_losses=False,
name=None):
return_pos_neg_losses=False):
"""Computes sigmoid cross entropy of binary classifier.
.. role:: python(code)
Expand Down Expand Up @@ -454,5 +442,4 @@ class dimension. Must not set `average_across_classes`
average_across_classes=average_across_classes,
sum_over_batch=sum_over_batch,
sum_over_classes=sum_over_classes,
return_pos_neg_losses=return_pos_neg_losses,
name=name)
return_pos_neg_losses=return_pos_neg_losses)
5 changes: 0 additions & 5 deletions texar/losses/mle_losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@
Unit tests for mle losses.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

# pylint: disable=invalid-name

import unittest
Expand Down
6 changes: 1 addition & 5 deletions texar/losses/pg_losses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,10 +15,6 @@
Various loss functions for policy gradients.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn.functional as F

Expand Down
5 changes: 0 additions & 5 deletions texar/losses/pg_losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@
Unit tests for pg losses.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

# pylint: disable=invalid-name

import unittest
Expand Down

0 comments on commit 0db4ad4

Please sign in to comment.