Skip to content

Commit

Permalink
更加严格
Browse files Browse the repository at this point in the history
  • Loading branch information
bojone authored May 7, 2022
1 parent 9271dcf commit b465d50
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions bert4keras/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,9 @@ def multilabel_categorical_crossentropy(y_true, y_pred):
n_mask = K.less_equal(y_true, 1 - K.epsilon())
p_mask = K.greater_equal(y_true, K.epsilon())
y_true = K.clip(y_true, K.epsilon(), 1 - K.epsilon())
y_neg = K.switch(n_mask, y_pred, -y_pred) + K.log(1 - y_true)
y_pos = K.switch(p_mask, -y_pred, y_pred) + K.log(y_true)
infs = K.zeros_like(y_pred) + K.infinity()
y_neg = K.switch(n_mask, y_pred, -infs) + K.log(1 - y_true)
y_pos = K.switch(p_mask, -y_pred, -infs) + K.log(y_true)
zeros = K.zeros_like(y_pred[..., :1])
y_neg = K.concatenate([y_neg, zeros], axis=-1)
y_pos = K.concatenate([y_pos, zeros], axis=-1)
Expand Down

0 comments on commit b465d50

Please sign in to comment.