Skip to content

Commit

Permalink
add demo
Browse files Browse the repository at this point in the history
  • Loading branch information
yu4u committed Apr 15, 2017
1 parent e4a2172 commit 36c43d8
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 6 deletions.
32 changes: 26 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,50 @@ In training, [the IMDB-WIKI dataset](https://data.vision.ee.ethz.ch/cvl/rrothe/i


## Dependencies
Tested on Ubuntu 16.04, Python 3.5.2, CUDA 8.0, cuDNN 5.0.

- Python3.5+
- Keras
- scipy, numpy, Pandas, tqdm
- scipy, numpy, Pandas, tqdm, tables, h5py
- dlib (for demo)
- OpenCV3

Tested on:
- Ubuntu 16.04, Python 3.5.2, Keras 2.0.3, Tensorflow(-gpu) 1.0.1, CUDA 8.0, cuDNN 5.0
- macOS Sierra, Python 3.6.0, Keras 2.0.2, Tensorflow 1.0.0


## Usage
### Download the dataset

### Use pretrained model
Download the pretrained model:

```sh
mkdir -p pretrained_models
wget -O pretrained_models/weights.18-4.06.hdf5 https://www.dropbox.com/s/rf8hgoev8uqjv3z/weights.18-4.06.hdf5
```

Run the demo script (requires a webcam):

```sh
python3 demo.py
```

### Train a model using the IMDB-WIKI dataset

#### Download the dataset
The dataset is downloaded and extracted to the `data` directory.

```sh
./download.sh
```

### Create data
#### Create data
Filter out noisy data and serialize the images and labels for training into a `.mat` file.
Please check `check_dataset.ipynb` for the details of the dataset.
```sh
python3 create_db.py --output data/imdb_db.mat --db imdb --img_size 64
```

### Train network
#### Train network
Train the network using the training data created above.

```sh
Expand Down
70 changes: 70 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os
import cv2
import dlib
import numpy as np
from wide_resnet import WideResNet


def draw_label(image, point, label, font=cv2.FONT_HERSHEY_SIMPLEX,
               font_scale=1, thickness=2):
    """Draw ``label`` on ``image`` over a filled background box.

    ``point`` is the ``(x, y)`` position handed to ``cv2.putText`` as the
    text origin. A filled rectangle of colour ``(255, 0, 0)``, sized from
    ``cv2.getTextSize``, is painted directly above that point first, and
    the text is then rendered on top of it in ``(255, 255, 255)``.
    ``image`` is modified in place; nothing is returned.
    """
    # getTextSize returns ((width, height), baseline); only the size is used.
    (text_w, text_h), _baseline = cv2.getTextSize(label, font, font_scale, thickness)
    left, bottom = point
    box_top_left = (left, bottom - text_h)
    box_bottom_right = (left + text_w, bottom)
    cv2.rectangle(image, box_top_left, box_bottom_right, (255, 0, 0), cv2.FILLED)
    cv2.putText(image, label, point, font, font_scale, (255, 255, 255), thickness)


def main():
    """Run the webcam demo: detect faces and overlay predicted age and gender.

    Frames are read from capture device 0 and faces are located with dlib's
    frontal face detector. Each detection is enlarged by 40% of its width and
    height on every side (clipped to the frame), resized to
    ``img_size`` x ``img_size`` and passed to the WideResNet model. The
    annotated frame is shown in a window named "result". The loop ends when
    ESC (key code 27) is pressed or when a frame can no longer be read.

    Raises:
        RuntimeError: if capture device 0 cannot be opened.
    """
    # for face detection
    detector = dlib.get_frontal_face_detector()

    # load model and weights
    img_size = 64
    model = WideResNet(img_size, depth=16, k=8)()
    model.load_weights(os.path.join("pretrained_models", "weights.18-4.06.hdf5"))

    # Column vector 0..100 used to reduce the second model output to a single
    # age per face via a weighted sum. Built once: it never changes per frame.
    # NOTE(review): presumably results[1] holds per-age probabilities, making
    # this an expected value - confirm against the model definition.
    ages = np.arange(0, 101).reshape(101, 1)

    # capture video
    cap = cv2.VideoCapture(0)

    try:
        if not cap.isOpened():
            raise RuntimeError("could not open video capture device 0")

        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

        while True:
            # get video frame; stop cleanly when the camera delivers nothing
            # instead of passing None on to cvtColor
            ret, img = cap.read()
            if not ret:
                break

            input_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_h, img_w, _ = np.shape(input_img)

            # detect faces using dlib detector
            detected = detector(input_img, 1)

            # Skip prediction entirely for frames without faces so the model
            # is never called on a zero-length batch.
            if len(detected) > 0:
                faces = np.empty((len(detected), img_size, img_size, 3))

                # Take every crop before anything is drawn on the frame, so
                # the overlay never ends up in the network input.
                # NOTE(review): crops come from the BGR frame `img`, not the
                # RGB copy used for detection - presumably this matches how
                # the training images were loaded; confirm.
                for i, d in enumerate(detected):
                    x1, y1 = d.left(), d.top()
                    x2, y2 = d.right() + 1, d.bottom() + 1
                    w, h = d.width(), d.height()
                    xw1 = max(int(x1 - 0.4 * w), 0)
                    yw1 = max(int(y1 - 0.4 * h), 0)
                    xw2 = min(int(x2 + 0.4 * w), img_w - 1)
                    yw2 = min(int(y2 + 0.4 * h), img_h - 1)
                    faces[i, :, :, :] = cv2.resize(
                        img[yw1:yw2 + 1, xw1:xw2 + 1, :], (img_size, img_size))

                # predict ages and genders of the detected faces
                results = model.predict(faces)
                predicted_genders = results[0]
                predicted_ages = results[1].dot(ages).flatten()

                # draw results: all boxes first, then the labels on top
                for d in detected:
                    cv2.rectangle(img, (d.left(), d.top()),
                                  (d.right() + 1, d.bottom() + 1), (255, 0, 0), 2)

                for i, d in enumerate(detected):
                    label = "{}, {}".format(int(predicted_ages[i]),
                                            "F" if predicted_genders[i][0] > 0.5 else "M")
                    draw_label(img, (d.left(), d.top()), label)

            cv2.imshow("result", img)
            key = cv2.waitKey(30)

            if key == 27:  # ESC
                break
    finally:
        # Always give the camera back and close the preview window, even when
        # the loop exits through an exception.
        cap.release()
        cv2.destroyAllWindows()


# Start the demo only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()

0 comments on commit 36c43d8

Please sign in to comment.