Skip to content

Commit

Permalink
Update yad2k script to Keras 2 API.
Browse files Browse the repository at this point in the history
  • Loading branch information
allanzelener committed Apr 18, 2017
1 parent d4f632b commit 7560ba5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 25 deletions.
15 changes: 9 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ Original paper: [YOLO9000: Better, Faster, Stronger](https://arxiv.org/abs/1612.
- [Keras](https://github.com/fchollet/keras)
- [Tensorflow](https://www.tensorflow.org/)
- [Numpy](http://www.numpy.org/)
- [h5py](http://www.h5py.org/) (For Keras model serialization.)
- [Pillow](https://pillow.readthedocs.io/) (For rendering test results.)
- [Python 3](https://www.python.org/)
- [pydot-ng](https://github.com/pydot/pydot-ng) (Optional for plotting model.)

### Installation
```bash
Expand All @@ -31,21 +33,21 @@ cd yad2k
conda env create -f environment.yml
source activate yad2k
# [Option 2] Install everything globaly.
pip install numpy

pip install numpy h5py pillow
pip install tensorflow-gpu # CPU-only: conda install -c conda-forge tensorflow
pip install keras # Possibly older release: conda install keras
```

## Quick Start

- Download Darknet model weights from the [official YOLO website](http://pjreddie.com/darknet/yolo/).
- Download Darknet model cfg and weights from the [official YOLO website](http://pjreddie.com/darknet/yolo/).
- Convert the Darknet YOLO_v2 model to a Keras model.
- Test the converted model on the small test set in `images/`.

```bash
wget http://pjreddie.com/media/files/yolo.weights
./yad2k.py cfg/yolo.cfg yolo.weights model_data/yolo.h5
wget https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolo.cfg
./yad2k.py yolo.cfg yolo.weights model_data/yolo.h5
./test_yolo.py model_data/yolo.h5 # output in images/out/
```

Expand All @@ -65,10 +67,11 @@ YAD2K assumes the Keras backend is Tensorflow. In particular for YOLO_v2 models

`yad2k/models` contains reference implementations of Darknet-19 and YOLO_v2.

`train_overfit` is a sample training script that overfits a YOLO_v2 model to a single image from the Pascal VOC dataset.

## Known Issues and TODOs

- Add YOLO_v2 loss function. (In-progress implementation on branch allanzelener/initial_loss_implementation).
- Script to train YOLO_v2 reference model.
- Expand sample training script to train YOLO_v2 reference model on full dataset.
- Support for additional Darknet layer types.
- Tuck away the Tensorflow dependencies with Keras wrappers where possible.
- YOLO_v2 model does not support fully convolutional mode. Current implementation assumes 1:1 aspect ratio images.
Expand Down
37 changes: 18 additions & 19 deletions yad2k.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

import numpy as np
from keras import backend as K
from keras.layers import (Convolution2D, GlobalAveragePooling2D, Input, Lambda,
MaxPooling2D, merge)
from keras.layers import (Conv2D, GlobalAveragePooling2D, Input, Lambda,
MaxPooling2D)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.merge import concatenate
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from keras.regularizers import l2
Expand Down Expand Up @@ -111,8 +112,8 @@ def _main(args):
activation = cfg_parser[section]['activation']
batch_normalize = 'batch_normalize' in cfg_parser[section]

# border_mode='same' is equivalent to Darknet pad=1
border_mode = 'same' if pad == 1 else 'valid'
# padding='same' is equivalent to Darknet pad=1
padding = 'same' if pad == 1 else 'valid'

# Setting weights.
# Darknet serializes convolutional weights as:
Expand Down Expand Up @@ -175,16 +176,14 @@ def _main(args):
activation, section))

# Create Conv2D layer
conv_layer = (Convolution2D(
filters,
size,
size,
border_mode=border_mode,
subsample=(stride, stride),
bias=not batch_normalize,
conv_layer = (Conv2D(
filters, (size, size),
strides=(stride, stride),
kernel_regularizer=l2(weight_decay),
use_bias=not batch_normalize,
weights=conv_weights,
activation=act_fn,
W_regularizer=l2(weight_decay)))(prev_layer)
padding=padding))(prev_layer)

if batch_normalize:
conv_layer = (BatchNormalization(
Expand All @@ -203,9 +202,9 @@ def _main(args):
stride = int(cfg_parser[section]['stride'])
all_layers.append(
MaxPooling2D(
padding='same',
pool_size=(size, size),
strides=(stride, stride),
border_mode='same')(prev_layer))
strides=(stride, stride))(prev_layer))
prev_layer = all_layers[-1]

elif section.startswith('avgpool'):
Expand All @@ -218,10 +217,10 @@ def _main(args):
ids = [int(i) for i in cfg_parser[section]['layers'].split(',')]
layers = [all_layers[i] for i in ids]
if len(layers) > 1:
print('Merging layers:', layers)
merge_layer = merge(layers, mode='concat')
all_layers.append(merge_layer)
prev_layer = merge_layer
print('Concatenating route layers:', layers)
concatenate_layer = concatenate(layers)
all_layers.append(concatenate_layer)
prev_layer = concatenate_layer
else:
skip_layer = layers[0] # only one layer to route
all_layers.append(skip_layer)
Expand Down Expand Up @@ -250,7 +249,7 @@ def _main(args):
'Unsupported section header type: {}'.format(section))

# Create and save model.
model = Model(input=all_layers[0], output=all_layers[-1])
model = Model(inputs=all_layers[0], outputs=all_layers[-1])
print(model.summary())
model.save('{}'.format(output_path))
print('Saved Keras model to {}'.format(output_path))
Expand Down

0 comments on commit 7560ba5

Please sign in to comment.