Skip to content

Commit

Permalink
Update joint_ctc_cross_entropy.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sooftware authored Apr 20, 2021
1 parent 64ebd95 commit 36decbe
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions kospeech/criterion/joint_ctc_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,29 @@


class JointCTCCrossEntropyLoss(nn.Module):
"""
Privides Joint CTC-CrossEntropy Loss function
Args:
num_classes (int): the number of classification
ignore_index (int): indexes that are ignored when calculating loss
dim (int): dimension of calculation loss
reduction (str): reduction method [sum, mean] (default: mean)
ctc_weight (float): weight of ctc loss
cross_entropy_weight (float): weight of cross entropy loss
blank_id (int): identification of blank for ctc
"""

def __init__(
self,
num_classes: int, # the number of classfication
ignore_index: int, # indexes that are ignored when calcuating loss
dim: int = -1, # dimension of caculation loss
reduction='mean', # reduction method [sum, mean]
ctc_weight: float = 0.3, # weight of ctc loss
cross_entropy_weight: float = 0.7, # weight of cross entropy loss
blank_id: int = None, # identification of blank token
smoothing: float = 0.1, # ratio of smoothing (confidence = 1.0 - smoothing)
num_classes: int,
ignore_index: int,
dim: int = -1,
reduction='mean',
ctc_weight: float = 0.3,
cross_entropy_weight: float = 0.7,
blank_id: int = None,
smoothing: float = 0.1,
) -> None:
super(JointCTCCrossEntropyLoss, self).__init__()
self.num_classes = num_classes
Expand Down

0 comments on commit 36decbe

Please sign in to comment.