Skip to content

Commit

Permalink
lr
Browse files Browse the repository at this point in the history
  • Loading branch information
yusuke-a-uchida committed Aug 5, 2018
1 parent 111beb5 commit f05ee66
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions age_estimation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,50 @@
from model import get_model, age_mae


class Schedule:
def __init__(self, nb_epochs):
self.epochs = nb_epochs

def __call__(self, epoch_idx):
if epoch_idx < self.epochs * 0.25:
return 0.1
elif epoch_idx < self.epochs * 0.5:
return 0.02
elif epoch_idx < self.epochs * 0.75:
return 0.004
return 0.0008


def get_args():
parser = argparse.ArgumentParser(description="This script trains the CNN model for age estimation.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--appa_dir", type=str, required=True,
help="path to the APPA-REAL dataset")
parser.add_argument("--utk_dir", type=str, default=None,
help="path to the UTK face dataset")
parser.add_argument("--output_dir", type=str, default="checkpoints",
help="checkpoint dir")
parser.add_argument("--batch_size", type=int, default=32,
help="batch size")
parser.add_argument("--nb_epochs", type=int, default=30,
help="number of epochs")
parser.add_argument("--lr", type=float, default=0.1,
help="learning rate")
parser.add_argument("--model_name", type=str, default="ResNet50",
help="model name: 'ResNet50' or 'InceptionResNetV2'")
parser.add_argument("--output_dir", type=str, default="checkpoints",
help="checkpoint dir")
args = parser.parse_args()
return args


class Schedule:
def __init__(self, nb_epochs, initial_lr):
self.epochs = nb_epochs
self.initial_lr = initial_lr

def __call__(self, epoch_idx):
if epoch_idx < self.epochs * 0.25:
return self.initial_lr
elif epoch_idx < self.epochs * 0.50:
return self.initial_lr * 0.2
elif epoch_idx < self.epochs * 0.75:
return self.initial_lr * 0.04
return self.initial_lr * 0.008


def main():
args = get_args()
appa_dir = args.appa_dir
utk_dir = args.utk_dir
model_name = args.model_name
batch_size = args.batch_size
nb_epochs = args.nb_epochs
model_name = args.model_name
lr = args.lr

if model_name == "ResNet50":
image_size = 224
Expand All @@ -61,7 +65,7 @@ def main():
model.summary()
output_dir = Path(__file__).resolve().parent.joinpath(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
callbacks = [LearningRateScheduler(schedule=Schedule(nb_epochs)),
callbacks = [LearningRateScheduler(schedule=Schedule(nb_epochs, initial_lr=lr)),
ModelCheckpoint(str(output_dir) + "/weights.{epoch:03d}-{val_loss:.3f}-{val_age_mae:.3f}.hdf5",
monitor="val_age_mae",
verbose=1,
Expand Down

0 comments on commit f05ee66

Please sign in to comment.