forked from ageitgey/age-gender-estimation
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
834 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
### https://raw.github.com/github/gitignore/f57304e9762876ae4c9b02867ed0cb887316387e/Python.gitignore | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
env/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*,cover | ||
.hypothesis/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# dotenv | ||
.env | ||
|
||
# virtualenv | ||
.venv | ||
venv/ | ||
ENV/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
/.idea/ | ||
|
||
/data/ | ||
/models/ | ||
/checkpoints/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Age and Gender Estimation | ||
This is a Keras implementation of a CNN network for estimating age and gender from a face image. | ||
In training, [the IMDB-WIKI dataset](https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/) is used. | ||
|
||
## Dependencies | ||
- Python3.5+ | ||
- Keras | ||
- scipy, numpy, Pandas, tqdm | ||
- OpenCV3 | ||
|
||
## Usage | ||
|
||
Download the dataset. The dataset is downloaded and extracted to the `data` directory. | ||
|
||
```sh | ||
./download.sh | ||
``` | ||
|
||
Filter out noisy data and serialize the images and labels for training into a `.mat` file. | ||
Please check `check_dataset.ipynb` for the details of the dataset. | ||
```sh | ||
python3 create_db.py --output data/imdb_db.mat --db imdb --img_size 64 | ||
``` | ||
|
||
Train the network using the training data created above. | ||
|
||
```sh | ||
python3 train.py --input data/imdb_db.mat | ||
``` | ||
|
||
## Network architecture | ||
In [the original paper](https://www.vision.ee.ethz.ch/en/publications/papers/articles/eth_biwi_01299.pdf), the pretrained VGG network is adopted. | ||
Here the Wide Residual Network (WideResNet) is trained from scratch. | ||
I modified @asmith26's implementation of the WideResNet; two classification layers (for age and gender estimation) are added on top of the WideResNet. | ||
Note that age and gender are estimated independently using two different CNNs. | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import numpy as np | ||
import cv2 | ||
import scipy.io | ||
import argparse | ||
from tqdm import tqdm | ||
from utils import get_meta | ||
|
||
|
||
def get_args():
    """Parse the command-line options for building the training database.

    Returns:
        argparse.Namespace with the attributes ``output`` (required path of
        the ``.mat`` file to write), ``db`` (dataset name, default
        ``"wiki"``), ``img_size`` (default 32) and ``min_score``
        (default 1.0).
    """
    # The description used to be a leftover from another script ("blur input
    # video"); it now states what this script actually does.
    parser = argparse.ArgumentParser(
        description="This script filters the IMDB-WIKI dataset and serializes "
                    "face images with age and gender labels into a mat file.")
    parser.add_argument("--output", "-o", type=str, required=True,
                        help="path to output database mat file")
    parser.add_argument("--db", type=str, default="wiki",
                        help="dataset; wiki or imdb")
    parser.add_argument("--img_size", type=int, default=32,
                        help="output image size")
    parser.add_argument("--min_score", type=float, default=1.0,
                        help="minimum face_score")
    args = parser.parse_args()
    return args
|
||
|
||
def main():
    """Filter the selected dataset and save images plus labels to a mat file.

    Reads the metadata of ``data/<db>_crop/<db>.mat`` through ``get_meta``,
    keeps only the samples that pass every filter below, resizes each kept
    image to ``img_size`` x ``img_size`` and writes the result with
    ``scipy.io.savemat`` to the path given by ``--output``.

    A sample is dropped when:
      * its ``face_score`` is below ``--min_score``;
      * its ``second_face_score`` is a number greater than 0;
      * its age is outside 0..100 (train.py one-hot encodes age into 101
        classes);
      * its gender label is NaN;
      * its image file cannot be read.
    """
    args = get_args()
    output_path = args.output
    db = args.db
    img_size = args.img_size
    min_score = args.min_score

    root_path = "data/{}_crop/".format(db)
    mat_path = root_path + "{}.mat".format(db)
    full_path, _, gender, _, face_score, second_face_score, age = get_meta(mat_path, db)

    out_genders = []
    out_ages = []
    out_imgs = []
    unreadable = 0

    for i in tqdm(range(len(face_score))):
        if face_score[i] < min_score:
            continue

        # `not` instead of bitwise `~`: `~` is only a logical negation for
        # NumPy booleans; on a plain Python bool it yields -1/-2, both truthy.
        if (not np.isnan(second_face_score[i])) and second_face_score[i] > 0.0:
            continue

        if not (0 <= age[i] <= 100):
            continue

        if np.isnan(gender[i]):
            continue

        # cv2.imread signals failure by returning None; load the image before
        # recording any label so the three output lists stay aligned.
        img = cv2.imread(root_path + str(full_path[i][0]))
        if img is None:
            unreadable += 1
            continue

        out_genders.append(int(gender[i]))
        out_ages.append(age[i])
        out_imgs.append(cv2.resize(img, (img_size, img_size)))

    if unreadable:
        print("skipped {} unreadable image(s)".format(unreadable))

    output = {"image": np.array(out_imgs), "gender": np.array(out_genders), "age": np.array(out_ages),
              "db": db, "img_size": img_size, "min_score": min_score}
    scipy.io.savemat(output_path, output)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#!/bin/bash
# Download and unpack the IMDB-WIKI cropped-face archives into ./data.
# Steps whose result already exists are skipped, so the script can be re-run.

# Stop at the first failing command instead of carrying on with missing or
# partial data (e.g. extracting in the wrong directory after a failed cd).
set -e

mkdir -p data
cd data

# fetch NAME: download NAME_crop.tar unless it is already present, then
# extract it unless the NAME_crop directory already exists.
fetch() {
    local name="$1"
    local archive="${name}_crop.tar"

    if [ ! -f "$archive" ]; then
        # Download under a temporary name so an interrupted transfer is not
        # mistaken for a complete archive on the next run.
        wget -O "${archive}.part" "https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/${archive}"
        mv "${archive}.part" "$archive"
    fi

    if [ ! -d "${name}_crop" ]; then
        tar xf "$archive"
    fi
}

fetch imdb
fetch wiki
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import pandas as pd | ||
import logging | ||
import argparse | ||
import os | ||
from keras.callbacks import LearningRateScheduler, ModelCheckpoint | ||
from keras.optimizers import SGD | ||
from keras.utils import np_utils | ||
from wide_resnet import WideResNet | ||
from utils import mk_dir, load_data | ||
|
||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
|
||
class Schedule:
    """Step-wise learning-rate schedule over a fixed number of epochs.

    The rate is 0.1 for the first quarter of training, 0.02 for the second,
    0.004 for the third and 0.0008 afterwards.
    """

    def __init__(self, nb_epochs):
        """Remember the total number of training epochs."""
        self.epochs = nb_epochs

    def __call__(self, epoch_idx):
        """Return the learning rate to use for epoch ``epoch_idx``."""
        total = self.epochs
        # (fraction of training, rate used before that fraction is reached)
        steps = ((0.25, 0.1), (0.5, 0.02), (0.75, 0.004))
        for fraction, rate in steps:
            if epoch_idx < total * fraction:
                return rate
        return 0.0008
|
||
|
||
def get_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with ``input`` (required), ``batch_size``,
        ``nb_epochs``, ``depth``, ``width`` and ``validation_split``.
    """
    parser = argparse.ArgumentParser(description="This script trains the CNN model for age and gender estimation.")
    parser.add_argument("--input", "-i", type=str, required=True,
                        help="path to input database mat file")

    # All integer hyper-parameters share the same shape: flag, default, help.
    integer_options = (
        ("--batch_size", 32, "batch size"),
        ("--nb_epochs", 30, "number of epochs"),
        ("--depth", 16, "depth of network"),
        ("--width", 8, "width of network"),
    )
    for flag, default, help_text in integer_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)

    parser.add_argument("--validation_split", type=float, default=0.1,
                        help="validation split ratio")
    return parser.parse_args()
|
||
|
||
def main():
    """Train the age/gender network on a database produced by create_db.py.

    Loads the ``.mat`` database given by ``--input``, one-hot encodes gender
    (2 classes) and age (101 classes), builds and compiles the WideResNet,
    writes its architecture as JSON under ``models/``, trains it while
    checkpointing to ``checkpoints/``, and finally stores the weights and the
    training history under ``models/``.
    """
    opts = get_args()
    depth = opts.depth
    k = opts.width

    logging.debug("Loading data...")
    image, gender, age, _, image_size, _ = load_data(opts.input)
    # One target per network output: gender first, then age.
    targets = [np_utils.to_categorical(gender, 2),
               np_utils.to_categorical(age, 101)]

    model = WideResNet(image_size, depth=depth, k=k)()
    model.compile(optimizer=SGD(lr=0.1, momentum=0.9, nesterov=True),
                  loss=["categorical_crossentropy", "categorical_crossentropy"],
                  metrics=['accuracy'])

    logging.debug("Model summary...")
    model.count_params()
    model.summary()

    logging.debug("Saving model...")
    mk_dir("models")
    architecture_path = os.path.join("models", "WRN_{}_{}.json".format(depth, k))
    with open(architecture_path, "w") as f:
        f.write(model.to_json())

    mk_dir("checkpoints")
    lr_scheduler = LearningRateScheduler(schedule=Schedule(opts.nb_epochs))
    checkpoint = ModelCheckpoint("checkpoints/weights.{epoch:02d}-{val_loss:.2f}.hdf5",
                                 monitor="val_loss",
                                 verbose=1,
                                 save_best_only=True,
                                 mode="auto")

    logging.debug("Running training...")
    hist = model.fit(image, targets,
                     batch_size=opts.batch_size,
                     nb_epoch=opts.nb_epochs,
                     callbacks=[lr_scheduler, checkpoint],
                     validation_split=opts.validation_split)

    logging.debug("Saving weights...")
    weights_path = os.path.join("models", "WRN_{}_{}.h5".format(depth, k))
    model.save_weights(weights_path, overwrite=True)
    history_path = os.path.join("models", "history_{}_{}.h5".format(depth, k))
    pd.DataFrame(hist.history).to_hdf(history_path, "history")


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from scipy.io import loadmat | ||
from datetime import datetime | ||
import os | ||
|
||
|
||
def calc_age(taken, dob):
    """Return the age in years at the time the photo was taken.

    Args:
        taken: year in which the photo was taken.
        dob: date of birth as a Matlab serial date number.
    """
    # Shift the serial number by 366 days and clamp it to the smallest
    # ordinal that datetime.fromordinal accepts.
    ordinal = int(dob) - 366
    if ordinal < 1:
        ordinal = 1
    birth = datetime.fromordinal(ordinal)

    # The photo is assumed to be taken mid-year, so a birthday in the second
    # half of the year has not happened yet.
    years = taken - birth.year
    return years if birth.month < 7 else years - 1
|
||
|
||
def get_meta(mat_path, db):
    """Read the per-image metadata columns from a dataset ``.mat`` file.

    Args:
        mat_path: path of the ``.mat`` metadata file.
        db: name of the top-level variable inside the file.

    Returns:
        Tuple ``(full_path, dob, gender, photo_taken, face_score,
        second_face_score, age)`` where ``age`` is a list computed with
        ``calc_age`` from ``photo_taken`` and ``dob``.
    """
    record = loadmat(mat_path)[db][0, 0]

    def column(field):
        # Every field is stored with a leading singleton axis; strip it.
        return record[field][0]

    full_path = column("full_path")
    dob = column("dob")  # Matlab serial date number
    gender = column("gender")
    photo_taken = column("photo_taken")  # year
    face_score = column("face_score")
    second_face_score = column("second_face_score")
    age = [calc_age(photo_taken[i], birth) for i, birth in enumerate(dob)]

    return full_path, dob, gender, photo_taken, face_score, second_face_score, age
|
||
|
||
def load_data(mat_path):
    """Load a training database ``.mat`` file and unpack its fields.

    Returns:
        Tuple ``(image, gender, age, db, img_size, min_score)`` read from the
        keys of the same names, with the wrapper axes added by the mat format
        stripped from everything except ``image``.
    """
    d = loadmat(mat_path)

    image = d["image"]
    gender = d["gender"][0]
    age = d["age"][0]
    db = d["db"][0]
    img_size = d["img_size"][0, 0]
    min_score = d["min_score"][0, 0]

    return image, gender, age, db, img_size, min_score
|
||
|
||
def mk_dir(dir):
    """Create directory ``dir`` (and any missing parents) if it is absent.

    An already existing directory is left untouched.

    Raises:
        OSError: if the directory cannot be created, e.g. permission denied
            or the path exists but is not a directory.
    """
    # Only "already exists" is tolerated. Swallowing every OSError hid real
    # failures until a later write into the directory broke.
    os.makedirs(dir, exist_ok=True)
Oops, something went wrong.