Added validation on a subset of the dataset
davidsandberg committed Mar 31, 2018
1 parent efc2f94 commit 8ebebe4
Showing 2 changed files with 71 additions and 13 deletions.
17 changes: 9 additions & 8 deletions src/facenet.py
@@ -370,26 +370,27 @@ def get_image_paths(facedir):
         image_paths = [os.path.join(facedir,img) for img in images]
     return image_paths

-def split_dataset(dataset, split_ratio, mode):
+def split_dataset(dataset, split_ratio, min_nrof_images_per_class, mode):
     if mode=='SPLIT_CLASSES':
         nrof_classes = len(dataset)
         class_indices = np.arange(nrof_classes)
         np.random.shuffle(class_indices)
-        split = int(round(nrof_classes*split_ratio))
+        split = int(round(nrof_classes*(1-split_ratio)))
         train_set = [dataset[i] for i in class_indices[0:split]]
         test_set = [dataset[i] for i in class_indices[split:-1]]
     elif mode=='SPLIT_IMAGES':
         train_set = []
         test_set = []
-        min_nrof_images = 2
         for cls in dataset:
             paths = cls.image_paths
             np.random.shuffle(paths)
-            split = int(round(len(paths)*split_ratio))
-            if split<min_nrof_images:
-                continue  # Not enough images for test set. Skip class...
-            train_set.append(ImageClass(cls.name, paths[0:split]))
-            test_set.append(ImageClass(cls.name, paths[split:-1]))
+            nrof_images_in_class = len(paths)
+            split = int(math.floor(nrof_images_in_class*(1-split_ratio)))
+            if split==nrof_images_in_class:
+                split = nrof_images_in_class-1
+            if split>=min_nrof_images_per_class and nrof_images_in_class-split>=1:
+                train_set.append(ImageClass(cls.name, paths[:split]))
+                test_set.append(ImageClass(cls.name, paths[split:]))
     else:
         raise ValueError('Invalid train/test split mode "%s"' % mode)
     return train_set, test_set
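
For reference, the per-class arithmetic of the revised SPLIT_IMAGES branch can be exercised on its own. The helper below is a hypothetical stand-in written for illustration; it mirrors the committed logic but is not part of the repository:

import math
import numpy as np

def split_class_paths(paths, split_ratio, min_nrof_images_per_class):
    # Hypothetical helper mirroring the committed SPLIT_IMAGES arithmetic:
    # shuffle one class, send the first `split` images to the training set
    # and the remainder to the validation set.
    paths = list(paths)
    np.random.shuffle(paths)
    nrof_images_in_class = len(paths)
    split = int(math.floor(nrof_images_in_class*(1-split_ratio)))
    if split==nrof_images_in_class:
        split = nrof_images_in_class-1  # always leave at least one validation image
    if split>=min_nrof_images_per_class and nrof_images_in_class-split>=1:
        return paths[:split], paths[split:]
    return None  # the class is dropped from both sets

# Example: 10 images with split_ratio=0.2 gives 8 training and 2 validation images.
train_paths, val_paths = split_class_paths(['img_%02d.png' % i for i in range(10)], 0.2, 1)

Note that with the default validation_set_split_ratio of 0.0 added to train_softmax.py below, split_dataset is not called at all and training behaves as before.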
67 changes: 62 additions & 5 deletions src/train_softmax.py
@@ -65,10 +65,16 @@ def main(args):

     np.random.seed(seed=args.seed)
     random.seed(args.seed)
-    train_set = facenet.get_dataset(args.data_dir)
+    dataset = facenet.get_dataset(args.data_dir)
     if args.filter_filename:
-        train_set = filter_dataset(train_set, os.path.expanduser(args.filter_filename),
+        dataset = filter_dataset(dataset, os.path.expanduser(args.filter_filename),
             args.filter_percentile, args.filter_min_nrof_images_per_class)
+
+    if args.validation_set_split_ratio>0.0:
+        train_set, val_set = facenet.split_dataset(dataset, args.validation_set_split_ratio, args.min_nrof_val_images_per_class, 'SPLIT_IMAGES')
+    else:
+        train_set, val_set = dataset, []
+
     nrof_classes = len(train_set)

     print('Model directory: %s' % model_dir)
@@ -91,8 +97,10 @@ def main(args):

         # Get a list of image paths and their labels
         image_list, label_list = facenet.get_image_paths_and_labels(train_set)
-        assert len(image_list)>0, 'The dataset should not be empty'
+        assert len(image_list)>0, 'The training set should not be empty'
+
+        val_image_list, val_label_list = facenet.get_image_paths_and_labels(val_set)

         # Create a queue that produces indices into the image_list and label_list
         labels = ops.convert_to_tensor(label_list, dtype=tf.int32)
         range_size = array_ops.shape(labels)[0]
@@ -127,8 +135,11 @@ def main(args):
         image_batch = tf.identity(image_batch, 'input')
         label_batch = tf.identity(label_batch, 'label_batch')

-        print('Total number of classes: %d' % nrof_classes)
-        print('Total number of examples: %d' % len(image_list))
+        print('Number of classes in training set: %d' % nrof_classes)
+        print('Number of examples in training set: %d' % len(image_list))
+
+        print('Number of classes in validation set: %d' % len(val_set))
+        print('Number of examples in validation set: %d' % len(val_image_list))

         print('Building training graph')

@@ -199,6 +210,11 @@ def main(args):
                     total_loss, train_op, summary_op, summary_writer, regularization_losses, args.learning_rate_schedule_file,
                     args.random_rotate, args.random_crop, args.random_flip)

+                if len(val_image_list)>0 and ((epoch-1) % args.validate_every_n_epochs == args.validate_every_n_epochs-1 or epoch==args.max_nrof_epochs):
+                    validate(args, sess, epoch, val_image_list, val_label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder, control_placeholder,
+                        learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder,
+                        total_loss, regularization_losses, cross_entropy_mean, args.validate_every_n_epochs)
+
                 # Save variables and the metagraph if it doesn't exist already
                 save_variables_and_metagraph(sess, saver, summary_writer, model_dir, subdir, step)
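
A quick check of the scheduling condition above (not part of the commit): with the default --validate_every_n_epochs of 5, the first clause holds at epochs 5, 10, 15, ... (and, since Python evaluates -1 % 5 as 4, also at epoch 0), while the second clause forces one final validation pass at max_nrof_epochs:

validate_every_n_epochs = 5
max_nrof_epochs = 12  # hypothetical value, for illustration only
validation_epochs = [epoch for epoch in range(max_nrof_epochs+1)
    if (epoch-1) % validate_every_n_epochs == validate_every_n_epochs-1 or epoch==max_nrof_epochs]
print(validation_epochs)  # [0, 5, 10, 12]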

@@ -283,6 +299,41 @@ def train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_o
     summary_writer.add_summary(summary, step)
     return step

+def validate(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder, control_placeholder,
+        learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder,
+        loss, regularization_losses, cross_entropy_mean, validate_every_n_epochs):
+
+    print('Running forward pass on validation set')
+
+    nrof_batches = len(label_list) // args.lfw_batch_size
+    nrof_images = nrof_batches * args.lfw_batch_size
+
+    # Enqueue one epoch of image paths and labels
+    labels_array = np.expand_dims(np.array(label_list[:nrof_images]),1)
+    image_paths_array = np.expand_dims(np.array(image_list[:nrof_images]),1)
+    control_array = np.zeros_like(labels_array, np.int32)
+    sess.run(enqueue_op, {image_paths_placeholder: image_paths_array, labels_placeholder: labels_array, control_placeholder: control_array})
+
+    loss_array = np.zeros((nrof_batches,), np.float32)
+    xent_array = np.zeros((nrof_batches,), np.float32)
+
+    # Validation loop
+    start_time = time.time()
+    for i in range(nrof_batches):
+        feed_dict = {phase_train_placeholder:False, batch_size_placeholder:args.lfw_batch_size}
+        err, cross_entropy_mean_ = sess.run([loss, cross_entropy_mean], feed_dict=feed_dict)
+        loss_array[i], xent_array[i] = (err, cross_entropy_mean_)
+        if i % 10 == 9:
+            print('.', end='')
+            sys.stdout.flush()
+    print('')
+
+    duration = time.time() - start_time
+
+    print('Validation Epoch: %d\tTime %.3f\tLoss %2.3f\tXent %2.3f' %
+          (epoch, duration, np.mean(loss_array), np.mean(xent_array)))
+
+
 def evaluate(sess, enqueue_op, image_paths_placeholder, labels_placeholder, phase_train_placeholder, batch_size_placeholder, control_placeholder,
         embeddings, labels, image_paths, actual_issame, batch_size, nrof_folds, log_dir, step, summary_writer, distance_metric, subtract_mean, use_flipped_images):
     start_time = time.time()
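
Note that validate evaluates a whole number of batches and silently drops the remainder of the validation set. A worked example with hypothetical sizes (not taken from the commit):

lfw_batch_size = 100      # the batch size reused for the validation forward pass
nrof_val_examples = 1050  # hypothetical number of validation images
nrof_batches = nrof_val_examples // lfw_batch_size  # 10 full batches
nrof_images = nrof_batches * lfw_batch_size         # 1000 images evaluated, 50 left out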
@@ -429,6 +480,12 @@ def parse_arguments(argv):
         help='Keep only the percentile images closed to its class center', default=100.0)
     parser.add_argument('--filter_min_nrof_images_per_class', type=int,
         help='Keep only the classes with this number of examples or more', default=0)
+    parser.add_argument('--validate_every_n_epochs', type=int,
+        help='Number of epochs between validation', default=5)
+    parser.add_argument('--validation_set_split_ratio', type=float,
+        help='The ratio of the total dataset to use for validation', default=0.0)
+    parser.add_argument('--min_nrof_val_images_per_class', type=float,
+        help='Classes with fewer images will be removed from the validation set', default=0)

     # Parameters for validation on LFW
     parser.add_argument('--lfw_pairs', type=str,
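
As a rough sketch of how the new options surface through parse_arguments (a hypothetical invocation; it assumes that importing train_softmax works in your environment, that the script's remaining defaults are acceptable, and that the dataset path is a placeholder):

import train_softmax

# Hold out 5% of each class for validation, keep the default cadence of
# validating every 5 epochs, and pass a minimum-images threshold of 3.
args = train_softmax.parse_arguments([
    '--data_dir', '/path/to/aligned/dataset',
    '--validation_set_split_ratio', '0.05',
    '--min_nrof_val_images_per_class', '3',
    '--validate_every_n_epochs', '5',
])
assert args.validation_set_split_ratio == 0.05
assert args.min_nrof_val_images_per_class == 3.0  # declared with type=float above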
