Skip to content

Commit

Permalink
add options for lr and opt
Browse files Browse the repository at this point in the history
  • Loading branch information
yusuke-a-uchida committed Nov 11, 2018
1 parent 6dea1cc commit a639c7b
Showing 1 changed file with 33 additions and 17 deletions.
50 changes: 33 additions & 17 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
import numpy as np
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.optimizers import SGD
from keras.optimizers import SGD, Adam
from keras.utils import np_utils
from wide_resnet import WideResNet
from utils import load_data
Expand All @@ -15,20 +15,6 @@
logging.basicConfig(level=logging.DEBUG)


class Schedule:
def __init__(self, nb_epochs):
self.epochs = nb_epochs

def __call__(self, epoch_idx):
if epoch_idx < self.epochs * 0.25:
return 0.1
elif epoch_idx < self.epochs * 0.5:
return 0.02
elif epoch_idx < self.epochs * 0.75:
return 0.004
return 0.0008


def get_args():
parser = argparse.ArgumentParser(description="This script trains the CNN model for age and gender estimation.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
Expand All @@ -38,6 +24,10 @@ def get_args():
help="batch size")
parser.add_argument("--nb_epochs", type=int, default=30,
help="number of epochs")
parser.add_argument("--lr", type=float, default=0.1,
help="initial learning rate")
parser.add_argument("--opt", type=str, default="sgd",
help="optimizer name; 'sgd' or 'adam'")
parser.add_argument("--depth", type=int, default=16,
help="depth of network (should be 10, 16, 22, 28, ...)")
parser.add_argument("--width", type=int, default=8,
Expand All @@ -52,11 +42,37 @@ def get_args():
return args


class Schedule:
def __init__(self, nb_epochs, initial_lr):
self.epochs = nb_epochs
self.initial_lr = initial_lr

def __call__(self, epoch_idx):
if epoch_idx < self.epochs * 0.25:
return self.initial_lr
elif epoch_idx < self.epochs * 0.50:
return self.initial_lr * 0.2
elif epoch_idx < self.epochs * 0.75:
return self.initial_lr * 0.04
return self.initial_lr * 0.008


def get_optimizer(opt_name, lr):
if opt_name == "sgd":
return SGD(lr=lr, momentum=0.9, nesterov=True)
elif opt_name == "adam":
return Adam(lr=lr)
else:
raise ValueError("optimizer name should be 'sgd' or 'adam'")


def main():
args = get_args()
input_path = args.input
batch_size = args.batch_size
nb_epochs = args.nb_epochs
lr = args.lr
opt_name = args.opt
depth = args.depth
k = args.width
validation_split = args.validation_split
Expand All @@ -71,8 +87,8 @@ def main():
y_data_a = np_utils.to_categorical(age, 101)

model = WideResNet(image_size, depth=depth, k=k)()
sgd = SGD(lr=0.1, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss=["categorical_crossentropy", "categorical_crossentropy"],
opt = get_optimizer(opt_name, lr)
model.compile(optimizer=opt, loss=["categorical_crossentropy", "categorical_crossentropy"],
metrics=['accuracy'])

logging.debug("Model summary...")
Expand Down

0 comments on commit a639c7b

Please sign in to comment.