Skip to content

Commit

Permalink
train_models/mtcnn_model.py: Use OHEM for classification only
Browse files Browse the repository at this point in the history
Acoording to the original paper, on-line hard example
mining technique is used in face/non-face classification
task only.

Signed-off-by: Ilya Nelkenbaum <[email protected]>
  • Loading branch information
Ilya Nelkenbaum committed Oct 29, 2017
1 parent c977fb9 commit 7c1a5ea
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions train_models/mtcnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def bbox_ohem(bbox_pred,bbox_target,label):
square_error = tf.reduce_sum(square_error,axis=1)
#keep_num scalar
num_valid = tf.reduce_sum(valid_inds)
keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32)
#keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32)
keep_num = tf.cast(num_valid, dtype=tf.int32)
#keep valid index square_error
square_error = square_error*valid_inds
_, k_index = tf.nn.top_k(square_error, k=keep_num)
Expand All @@ -92,7 +93,8 @@ def landmark_ohem(landmark_pred,landmark_target,label):
square_error = tf.square(landmark_pred-landmark_target)
square_error = tf.reduce_sum(square_error,axis=1)
num_valid = tf.reduce_sum(valid_inds)
keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32)
#keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32)
keep_num = tf.cast(num_valid, dtype=tf.int32)
square_error = square_error*valid_inds
_, k_index = tf.nn.top_k(square_error, k=keep_num)
square_error = tf.gather(square_error, k_index)
Expand Down

0 comments on commit 7c1a5ea

Please sign in to comment.