Skip to content

Commit

Permalink
Added possibility to use fixed image standardization
Browse files Browse the repository at this point in the history
  • Loading branch information
davidsandberg committed Apr 1, 2018
1 parent a7590ff commit b4e4397
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions src/train_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def main(args):
learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, control_placeholder, global_step,
total_loss, train_op, summary_op, summary_writer, regularization_losses, args.learning_rate_schedule_file,
stat, cross_entropy_mean, accuracy, learning_rate,
prelogits, prelogits_center_loss, args.random_rotate, args.random_crop, args.random_flip, prelogits_norm, args.prelogits_hist_max)
prelogits, prelogits_center_loss, args.random_rotate, args.random_crop, args.random_flip, prelogits_norm, args.prelogits_hist_max, args.use_fixed_image_standardization)
stat['time_train'][epoch-1] = time.time() - t

if not cont:
Expand All @@ -248,7 +248,7 @@ def main(args):
if len(val_image_list)>0 and ((epoch-1) % args.validate_every_n_epochs == args.validate_every_n_epochs-1 or epoch==args.max_nrof_epochs):
validate(args, sess, epoch, val_image_list, val_label_list, enqueue_op, image_paths_placeholder, labels_placeholder, control_placeholder,
phase_train_placeholder, batch_size_placeholder,
stat, total_loss, regularization_losses, cross_entropy_mean, accuracy, args.validate_every_n_epochs)
stat, total_loss, regularization_losses, cross_entropy_mean, accuracy, args.validate_every_n_epochs, args.use_fixed_image_standardization)
stat['time_validate'][epoch-1] = time.time() - t

# Save variables and the metagraph if it doesn't exist already
Expand All @@ -259,7 +259,7 @@ def main(args):
if args.lfw_dir:
evaluate(sess, enqueue_op, image_paths_placeholder, labels_placeholder, phase_train_placeholder, batch_size_placeholder, control_placeholder,
embeddings, label_batch, lfw_paths, actual_issame, args.lfw_batch_size, args.lfw_nrof_folds, log_dir, step, summary_writer, stat, epoch,
args.lfw_distance_metric, args.lfw_subtract_mean, args.lfw_use_flipped_images)
args.lfw_distance_metric, args.lfw_subtract_mean, args.lfw_use_flipped_images, args.use_fixed_image_standardization)
stat['time_evaluate'][epoch-1] = time.time() - t

print('Saving statistics')
Expand Down Expand Up @@ -304,7 +304,7 @@ def train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_o
learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, control_placeholder, step,
loss, train_op, summary_op, summary_writer, reg_losses, learning_rate_schedule_file,
stat, cross_entropy_mean, accuracy,
learning_rate, prelogits, prelogits_center_loss, random_rotate, random_crop, random_flip, prelogits_norm, prelogits_hist_max):
learning_rate, prelogits, prelogits_center_loss, random_rotate, random_crop, random_flip, prelogits_norm, prelogits_hist_max, use_fixed_image_standardization):
batch_number = 0

if args.learning_rate>0.0:
Expand All @@ -322,7 +322,7 @@ def train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_o
# Enqueue one epoch of image paths and labels
labels_array = np.expand_dims(np.array(label_epoch),1)
image_paths_array = np.expand_dims(np.array(image_epoch),1)
control_value = facenet.RANDOM_ROTATE * random_rotate + facenet.RANDOM_CROP * random_crop + facenet.RANDOM_FLIP * random_flip
control_value = facenet.RANDOM_ROTATE * random_rotate + facenet.RANDOM_CROP * random_crop + facenet.RANDOM_FLIP * random_flip + facenet.FIXED_STANDARDIZATION * use_fixed_image_standardization
control_array = np.ones_like(labels_array) * control_value
sess.run(enqueue_op, {image_paths_placeholder: image_paths_array, labels_placeholder: labels_array, control_placeholder: control_array})

Expand Down Expand Up @@ -362,7 +362,7 @@ def train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_o

def validate(args, sess, epoch, image_list, label_list, enqueue_op, image_paths_placeholder, labels_placeholder, control_placeholder,
phase_train_placeholder, batch_size_placeholder,
stat, loss, regularization_losses, cross_entropy_mean, accuracy, validate_every_n_epochs):
stat, loss, regularization_losses, cross_entropy_mean, accuracy, validate_every_n_epochs, use_fixed_image_standardization):

print('Running forward pass on validation set')

Expand All @@ -372,7 +372,7 @@ def validate(args, sess, epoch, image_list, label_list, enqueue_op, image_paths_
# Enqueue one epoch of image paths and labels
labels_array = np.expand_dims(np.array(label_list[:nrof_images]),1)
image_paths_array = np.expand_dims(np.array(image_list[:nrof_images]),1)
control_array = np.zeros_like(labels_array, np.int32)
control_array = np.ones_like(labels_array, np.int32)*facenet.FIXED_STANDARDIZATION * use_fixed_image_standardization
sess.run(enqueue_op, {image_paths_placeholder: image_paths_array, labels_placeholder: labels_array, control_placeholder: control_array})

loss_array = np.zeros((nrof_batches,), np.float32)
Expand Down Expand Up @@ -402,7 +402,7 @@ def validate(args, sess, epoch, image_list, label_list, enqueue_op, image_paths_


def evaluate(sess, enqueue_op, image_paths_placeholder, labels_placeholder, phase_train_placeholder, batch_size_placeholder, control_placeholder,
embeddings, labels, image_paths, actual_issame, batch_size, nrof_folds, log_dir, step, summary_writer, stat, epoch, distance_metric, subtract_mean, use_flipped_images):
embeddings, labels, image_paths, actual_issame, batch_size, nrof_folds, log_dir, step, summary_writer, stat, epoch, distance_metric, subtract_mean, use_flipped_images, use_fixed_image_standardization):
start_time = time.time()
# Run forward pass to calculate embeddings
print('Runnning forward pass on LFW images')
Expand All @@ -413,11 +413,12 @@ def evaluate(sess, enqueue_op, image_paths_placeholder, labels_placeholder, phas
nrof_images = nrof_embeddings * nrof_flips
labels_array = np.expand_dims(np.arange(0,nrof_images),1)
image_paths_array = np.expand_dims(np.repeat(np.array(image_paths),nrof_flips),1)
control_array = np.zeros_like(labels_array, np.int32)
if use_fixed_image_standardization:
control_array += np.ones_like(labels_array)*facenet.FIXED_STANDARDIZATION
if use_flipped_images:
# Flip every second image
control_array = (labels_array % 2)*16
else:
control_array = np.zeros_like(labels_array)
control_array += (labels_array % 2)*facenet.FLIP
sess.run(enqueue_op, {image_paths_placeholder: image_paths_array, labels_placeholder: labels_array, control_placeholder: control_array})

embedding_size = int(embeddings.get_shape()[1])
Expand Down Expand Up @@ -516,6 +517,8 @@ def parse_arguments(argv):
help='Performs random horizontal flipping of training images.', action='store_true')
parser.add_argument('--random_rotate',
help='Performs random rotations of training images.', action='store_true')
parser.add_argument('--use_fixed_image_standardization',
help='Performs fixed standardization of images.', action='store_true')
parser.add_argument('--keep_probability', type=float,
help='Keep probability of dropout for the fully connected layer(s).', default=1.0)
parser.add_argument('--weight_decay', type=float,
Expand Down

0 comments on commit b4e4397

Please sign in to comment.