first commit
yu4u committed Apr 9, 2017
1 parent edcb76d commit e6f0655
Showing 8 changed files with 834 additions and 0 deletions.
105 changes: 105 additions & 0 deletions .gitignore
@@ -0,0 +1,105 @@
### https://raw.github.com/github/gitignore/f57304e9762876ae4c9b02867ed0cb887316387e/Python.gitignore

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

/.idea/

/data/
/models/
/checkpoints/
36 changes: 36 additions & 0 deletions README.md
@@ -0,0 +1,36 @@
# Age and Gender Estimation
This is a Keras implementation of a convolutional neural network (CNN) for estimating age and gender from a face image.
In training, [the IMDB-WIKI dataset](https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/) is used.

## Dependencies
- Python 3.5+
- Keras
- scipy, numpy, pandas, tqdm
- OpenCV 3

## Usage

Download the dataset; it is extracted to the `data` directory:

```sh
./download.sh
```

Filter out noisy data and serialize the images and labels for training into a `.mat` file.
See `check_dataset.ipynb` for details of the dataset.
```sh
python3 create_db.py --output data/imdb_db.mat --db imdb --img_size 64
```
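
The resulting `.mat` file stores the images together with their labels and the creation settings (keys `image`, `gender`, `age`, `db`, `img_size`, `min_score`, as written by `create_db.py`). A quick sanity check, as a minimal sketch assuming the command above has been run:

```python
from scipy.io import loadmat

# keys as written by create_db.py via scipy.io.savemat
d = loadmat("data/imdb_db.mat")
print(d["image"].shape)      # e.g. (N, 64, 64, 3) cropped face images
print(d["gender"][0].shape)  # (N,) labels; scipy stores 1-D arrays as (1, N)
print(d["age"][0][:10])      # integer ages in [0, 100]
print(d["img_size"][0, 0])   # 64 for the command above
```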

Train the network using the training data created above.

```sh
python3 train.py --input data/imdb_db.mat
```

## Network architecture
In [the original paper](https://www.vision.ee.ethz.ch/en/publications/papers/articles/eth_biwi_01299.pdf), a pretrained VGG network is adopted.
Here, a Wide Residual Network (WideResNet) is trained from scratch instead.
I modified @asmith26's implementation of the WideResNet; two classification layers (one for age and one for gender estimation) are added on top of the WideResNet.
Note that while the original paper estimates age and gender independently using two different CNNs, here they are estimated jointly by a single CNN with two output heads.
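
For illustration, here is a minimal Keras sketch of the two-head design (assuming Keras 2; the tiny convolutional body is a hypothetical stand-in for the actual WideResNet defined in `wide_resnet.py`):

```python
from keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dense
from keras.models import Model

# hypothetical stand-in for the WideResNet feature extractor
inputs = Input(shape=(64, 64, 3))
x = Conv2D(16, (3, 3), padding="same", activation="relu")(inputs)
features = GlobalAveragePooling2D()(x)

# two classification heads on top of the shared features
pred_gender = Dense(2, activation="softmax", name="gender")(features)
pred_age = Dense(101, activation="softmax", name="age")(features)

model = Model(inputs=inputs, outputs=[pred_gender, pred_age])
model.compile(optimizer="sgd",
              loss=["categorical_crossentropy", "categorical_crossentropy"],
              metrics=["accuracy"])
```

`train.py` then fits the model against a list of two one-hot label arrays, `[y_data_g, y_data_a]`, in the same output order.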

337 changes: 337 additions & 0 deletions check_dataset.ipynb

Large diffs are not rendered by default.

62 changes: 62 additions & 0 deletions create_db.py
@@ -0,0 +1,62 @@
import numpy as np
import cv2
import scipy.io
import argparse
from tqdm import tqdm
from utils import get_meta


def get_args():
    parser = argparse.ArgumentParser(description="This script cleans the dataset and creates a .mat database for training.")
    parser.add_argument("--output", "-o", type=str, required=True,
                        help="path to output database mat file")
    parser.add_argument("--db", type=str, default="wiki",
                        help="dataset; wiki or imdb")
    parser.add_argument("--img_size", type=int, default=32,
                        help="output image size")
    parser.add_argument("--min_score", type=float, default=1.0,
                        help="minimum face_score")
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    output_path = args.output
    db = args.db
    img_size = args.img_size
    min_score = args.min_score

    root_path = "data/{}_crop/".format(db)
    mat_path = root_path + "{}.mat".format(db)
    full_path, dob, gender, photo_taken, face_score, second_face_score, age = get_meta(mat_path, db)

    out_genders = []
    out_ages = []
    out_imgs = []

    for i in tqdm(range(len(face_score))):
        if face_score[i] < min_score:  # skip images with a low face detection score
            continue

        if not np.isnan(second_face_score[i]) and second_face_score[i] > 0.0:  # skip images with a second face
            continue

        if not (0 <= age[i] <= 100):  # skip implausible ages
            continue

        if np.isnan(gender[i]):  # skip images without a gender label
            continue

        out_genders.append(int(gender[i]))
        out_ages.append(age[i])
        img = cv2.imread(root_path + str(full_path[i][0]))
        out_imgs.append(cv2.resize(img, (img_size, img_size)))

    output = {"image": np.array(out_imgs), "gender": np.array(out_genders), "age": np.array(out_ages),
              "db": db, "img_size": img_size, "min_score": min_score}
    scipy.io.savemat(output_path, output)


if __name__ == '__main__':
    main()
19 changes: 19 additions & 0 deletions download.sh
@@ -0,0 +1,19 @@
#!/bin/bash
mkdir -p data
cd data

if [ ! -f imdb_crop.tar ]; then
    wget https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/imdb_crop.tar
fi

if [ ! -d imdb_crop ]; then
    tar xf imdb_crop.tar
fi

if [ ! -f wiki_crop.tar ]; then
    wget https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/wiki_crop.tar
fi

if [ ! -d wiki_crop ]; then
    tar xf wiki_crop.tar
fi
94 changes: 94 additions & 0 deletions train.py
@@ -0,0 +1,94 @@
import pandas as pd
import logging
import argparse
import os
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.optimizers import SGD
from keras.utils import np_utils
from wide_resnet import WideResNet
from utils import mk_dir, load_data

logging.basicConfig(level=logging.DEBUG)


class Schedule:
    def __init__(self, nb_epochs):
        self.epochs = nb_epochs

    def __call__(self, epoch_idx):
        # step decay: divide the learning rate by 5 at 25%, 50%, and 75% of training
        if epoch_idx < self.epochs * 0.25:
            return 0.1
        elif epoch_idx < self.epochs * 0.5:
            return 0.02
        elif epoch_idx < self.epochs * 0.75:
            return 0.004
        return 0.0008


def get_args():
    parser = argparse.ArgumentParser(description="This script trains the CNN model for age and gender estimation.")
    parser.add_argument("--input", "-i", type=str, required=True,
                        help="path to input database mat file")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="batch size")
    parser.add_argument("--nb_epochs", type=int, default=30,
                        help="number of epochs")
    parser.add_argument("--depth", type=int, default=16,
                        help="depth of network")
    parser.add_argument("--width", type=int, default=8,
                        help="width of network")
    parser.add_argument("--validation_split", type=float, default=0.1,
                        help="validation split ratio")
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    input_path = args.input
    batch_size = args.batch_size
    nb_epochs = args.nb_epochs
    depth = args.depth
    k = args.width
    validation_split = args.validation_split

    logging.debug("Loading data...")
    image, gender, age, _, image_size, _ = load_data(input_path)
    X_data = image
    y_data_g = np_utils.to_categorical(gender, 2)    # gender: 2 classes
    y_data_a = np_utils.to_categorical(age, 101)     # age: 101 classes (0-100)

    model = WideResNet(image_size, depth=depth, k=k)()
    sgd = SGD(lr=0.1, momentum=0.9, nesterov=True)
    model.compile(optimizer=sgd, loss=["categorical_crossentropy", "categorical_crossentropy"],
                  metrics=['accuracy'])

    logging.debug("Model summary...")
    model.count_params()
    model.summary()

    logging.debug("Saving model...")
    mk_dir("models")
    with open(os.path.join("models", "WRN_{}_{}.json".format(depth, k)), "w") as f:
        f.write(model.to_json())

    mk_dir("checkpoints")
    callbacks = [LearningRateScheduler(schedule=Schedule(nb_epochs)),
                 ModelCheckpoint("checkpoints/weights.{epoch:02d}-{val_loss:.2f}.hdf5",
                                 monitor="val_loss",
                                 verbose=1,
                                 save_best_only=True,
                                 mode="auto")
                 ]

    logging.debug("Running training...")
    hist = model.fit(X_data, [y_data_g, y_data_a], batch_size=batch_size, nb_epoch=nb_epochs, callbacks=callbacks,
                     validation_split=validation_split)

    logging.debug("Saving weights...")
    model.save_weights(os.path.join("models", "WRN_{}_{}.h5".format(depth, k)), overwrite=True)
    pd.DataFrame(hist.history).to_hdf(os.path.join("models", "history_{}_{}.h5".format(depth, k)), "history")


if __name__ == '__main__':
    main()
39 changes: 39 additions & 0 deletions utils.py
@@ -0,0 +1,39 @@
from scipy.io import loadmat
from datetime import datetime
import os


def calc_age(taken, dob):
    # dob is a Matlab serial date number; Python ordinals are offset from it by 366 days
    birth = datetime.fromordinal(max(int(dob) - 366, 1))

    # assume the photo was taken in the middle of the year
    if birth.month < 7:
        return taken - birth.year
    else:
        return taken - birth.year - 1


def get_meta(mat_path, db):
    meta = loadmat(mat_path)
    full_path = meta[db][0, 0]["full_path"][0]
    dob = meta[db][0, 0]["dob"][0]  # Matlab serial date number
    gender = meta[db][0, 0]["gender"][0]
    photo_taken = meta[db][0, 0]["photo_taken"][0]  # year
    face_score = meta[db][0, 0]["face_score"][0]
    second_face_score = meta[db][0, 0]["second_face_score"][0]
    age = [calc_age(photo_taken[i], dob[i]) for i in range(len(dob))]

    return full_path, dob, gender, photo_taken, face_score, second_face_score, age


def load_data(mat_path):
    d = loadmat(mat_path)

    return d["image"], d["gender"][0], d["age"][0], d["db"][0], d["img_size"][0, 0], d["min_score"][0, 0]


def mk_dir(directory):
    try:
        os.mkdir(directory)
    except OSError:
        pass