Skip to content

Commit

Permalink
mae
Browse files Browse the repository at this point in the history
  • Loading branch information
yusuke-a-uchida committed Aug 5, 2018
1 parent 655c0ed commit 8b0e683
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
8 changes: 8 additions & 0 deletions age_estimation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
from keras.applications import ResNet50, InceptionResNetV2
from keras.layers import Dense
from keras.models import Model
from keras import backend as K


def age_mae(y_true, y_pred):
true_age = K.mean(y_true * K.arange(0, 101), axis=-1)
pred_age = K.mean(y_pred * K.arange(0, 101), axis=-1)
mae = K.mean(K.abs(true_age - pred_age))
return mae


def get_model(model_name="ResNet50"):
Expand Down
12 changes: 6 additions & 6 deletions age_estimation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.optimizers import SGD
from generator import FaceGenerator, ValGenerator
from model import get_model
from model import get_model, age_mae


class Schedule:
def __init__(self, nb_epochs):
Expand Down Expand Up @@ -56,17 +57,16 @@ def main():
val_gen = ValGenerator(appa_dir, batch_size=batch_size, image_size=image_size)
model = get_model(model_name=model_name)
sgd = SGD(lr=0.1, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss="categorical_crossentropy",
metrics=['accuracy'])
model.compile(optimizer=sgd, loss="categorical_crossentropy", metrics=[age_mae])
model.summary()
output_dir = Path(__file__).resolve().parent.joinpath(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
callbacks = [LearningRateScheduler(schedule=Schedule(nb_epochs)),
ModelCheckpoint(str(output_dir) + "/weights.{epoch:03d}-{val_loss:.3f}.hdf5",
monitor="val_loss",
ModelCheckpoint(str(output_dir) + "/weights.{epoch:03d}-{val_loss:.3f}-{val_age_mae:.3f}.hdf5",
monitor="val_age_mae",
verbose=1,
save_best_only=True,
mode="auto")
mode=min")
]

hist = model.fit_generator(generator=train_gen,
Expand Down

0 comments on commit 8b0e683

Please sign in to comment.