Skip to content

Commit

Permalink
add adam
Browse files Browse the repository at this point in the history
  • Loading branch information
yusuke-a-uchida committed Nov 11, 2018
1 parent 1b7c8d7 commit f0ac0a3
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions age_estimation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
import numpy as np
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.optimizers import SGD
from keras.optimizers import SGD, Adam
from generator import FaceGenerator, ValGenerator
from model import get_model, age_mae

Expand All @@ -22,6 +22,8 @@ def get_args():
help="number of epochs")
parser.add_argument("--lr", type=float, default=0.1,
help="learning rate")
parser.add_argument("--opt", type=str, default="sgd",
help="optimizer name; 'sgd' or 'adam'")
parser.add_argument("--model_name", type=str, default="ResNet50",
help="model name: 'ResNet50' or 'InceptionResNetV2'")
args = parser.parse_args()
Expand All @@ -43,6 +45,15 @@ def __call__(self, epoch_idx):
return self.initial_lr * 0.008


def get_optimizer(opt_name, lr):
if opt_name == "sgd":
return SGD(lr=lr, momentum=0.9, nesterov=True)
elif opt_name == "adam":
return Adam(lr=lr)
else:
raise ValueError("optimizer name should be 'sgd' or 'adam'")


def main():
args = get_args()
appa_dir = args.appa_dir
Expand All @@ -51,6 +62,7 @@ def main():
batch_size = args.batch_size
nb_epochs = args.nb_epochs
lr = args.lr
opt_name = args.opt

if model_name == "ResNet50":
image_size = 224
Expand All @@ -60,8 +72,8 @@ def main():
train_gen = FaceGenerator(appa_dir, utk_dir=utk_dir, batch_size=batch_size, image_size=image_size)
val_gen = ValGenerator(appa_dir, batch_size=batch_size, image_size=image_size)
model = get_model(model_name=model_name)
sgd = SGD(lr=0.1, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss="categorical_crossentropy", metrics=[age_mae])
opt = get_optimizer(opt_name, lr)
model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=[age_mae])
model.summary()
output_dir = Path(__file__).resolve().parent.joinpath(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
Expand Down

0 comments on commit f0ac0a3

Please sign in to comment.