
Commit 58283f8

fix type annotation

gpengzhi committed May 16, 2019
1 parent 010821b

Showing 1 changed file with 4 additions and 2 deletions.
texar/losses/entropy.py (4 additions, 2 deletions)

```diff
@@ -21,6 +21,8 @@
 from texar.losses.losses_utils import mask_and_reduce, reduce_dimensions
 from texar.utils.shapes import get_rank
 
+from typing import Optional
+
 # pylint: disable=too-many-arguments
 
 __all__ = [
@@ -37,7 +39,7 @@ def _get_entropy(logits: torch.Tensor) -> torch.Tensor:
 
 
 def entropy_with_logits(logits: torch.Tensor,
-                        rank: int = None,
+                        rank: Optional[int] = None,
                         average_across_batch: bool = True,
                         average_across_remaining: bool = False,
                         sum_over_batch: bool = False,
@@ -114,7 +116,7 @@ def entropy_with_logits(logits: torch.Tensor,
 
 
 def sequence_entropy_with_logits(logits: torch.Tensor,
-                                 rank: int = None,
+                                 rank: Optional[int] = None,
                                  sequence_length: torch.Tensor = None,
                                  average_across_batch: bool = True,
                                  average_across_timesteps: bool = False,
```
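The fix follows PEP 484: a parameter whose default is `None` does not have type `int`, so the accurate annotation is `Optional[int]`, and strict type checkers such as mypy reject the implicit-Optional form. A minimal sketch of the distinction, using a hypothetical helper name rather than texar code:

```python
from typing import Optional

# Hypothetical illustration (not texar code): a parameter defaulting to
# None must be annotated Optional[...] to satisfy PEP 484 / mypy.

def get_rank_or_default(rank: Optional[int] = None) -> int:
    # Fall back to a default rank when the caller passes nothing.
    return 2 if rank is None else rank

# The pre-fix form, `rank: int = None`, draws a mypy error along the
# lines of: Incompatible default for argument "rank" (default has type
# "None", argument has type "int").

print(get_rank_or_default())   # 2
print(get_rank_or_default(3))  # 3
```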
