Skip to content

Commit

Permalink
Refine image classification prediction code.
Browse files Browse the repository at this point in the history
ISSUE=4596311

git-svn-id: https://svn.baidu.com/idl/trunk/paddle@1431 1ad973e4-5ce8-4261-8a94-b56d1f490c56
  • Loading branch information
qingqing01 committed Aug 30, 2016
1 parent 5db9e59 commit 13f4602
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 35 deletions.
1 change: 0 additions & 1 deletion demo/image_classification/classify.py

This file was deleted.

27 changes: 0 additions & 27 deletions demo/image_classification/image_predictor.py

This file was deleted.

10 changes: 4 additions & 6 deletions demo/image_classification/classify.sh → demo/image_classification/predict.sh
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
# limitations under the License.
set -e

python classify.py \
--job=predict \
--conf=vgg_16_cifar.py \
--model=./cifar_vgg_model/pass-00299 \
--multi_crop \
--data=./example/test.list
model=cifar_vgg_model/pass-00299/
image=data/cifar-out/test/airplane/seaplane_s_000978.png
use_gpu=1
python prediction.py $model $image $use_gpu
155 changes: 155 additions & 0 deletions demo/image_classification/prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os,sys
import numpy as np
import logging
from PIL import Image
from optparse import OptionParser

import paddle.utils.image_util as image_util

from py_paddle import swig_paddle, util
from py_paddle import DataProviderWrapperConverter
from paddle.trainer.PyDataProviderWrapper import DenseSlot
from paddle.trainer.config_parser import parse_config

logging.basicConfig(format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s')
logging.getLogger().setLevel(logging.INFO)

class ImageClassifier():
def __init__(self,
train_conf,
use_gpu=True,
model_dir=None,
resize_dim=None,
crop_dim=None,
mean_file=None,
oversample=False,
is_color=True):
"""
train_conf: network configure.
model_dir: string, directory of model.
resize_dim: int, resized image size.
crop_dim: int, crop size.
mean_file: string, image mean file.
oversample: bool, oversample means multiple crops, namely five
patches (the four corner patches and the center
patch) as well as their horizontal reflections,
ten crops in all.
"""
self.train_conf = train_conf
self.model_dir = model_dir
if model_dir is None:
self.model_dir = os.path.dirname(train_conf)

self.resize_dim = resize_dim
self.crop_dims = [crop_dim, crop_dim]
self.oversample = oversample
self.is_color = is_color

self.transformer = image_util.ImageTransformer(is_color = is_color)
self.transformer.set_transpose((2,0,1))

self.mean_file = mean_file
mean = np.load(self.mean_file)['data_mean']
mean = mean.reshape(3, self.crop_dims[0], self.crop_dims[1])
self.transformer.set_mean(mean) # mean pixel
gpu = 1 if use_gpu else 0
conf_args = "is_test=1,use_gpu=%d,is_predict=1" % (gpu)
conf = parse_config(train_conf, conf_args)
swig_paddle.initPaddle("--use_gpu=%d" % (gpu))
self.network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
assert isinstance(self.network, swig_paddle.GradientMachine)
self.network.loadParameters(self.model_dir)

data_size = 3 * self.crop_dims[0] * self.crop_dims[1]
slots = [DenseSlot(data_size)]
self.converter = util.DataProviderWrapperConverter(False, slots)

def get_data(self, img_path):
"""
1. load image from img_path.
2. resize or oversampling.
3. transformer data: transpose, sub mean.
return K x H x W ndarray.
img_path: image path.
"""
image = image_util.load_image(img_path, self.is_color)
if self.oversample:
# image_util.resize_image: short side is self.resize_dim
image = image_util.resize_image(image, self.resize_dim)
image = np.array(image)
input = np.zeros((1, image.shape[0], image.shape[1], 3),
dtype=np.float32)
input[0] = image.astype(np.float32)
input = image_util.oversample(input, self.crop_dims)
else:
image = image.resize(self.crop_dims, Image.ANTIALIAS)
input = np.zeros((1, self.crop_dims[0], self.crop_dims[1], 3),
dtype=np.float32)
input[0] = np.array(image).astype(np.float32)

data_in = []
for img in input:
img = self.transformer.transformer(img).flatten()
data_in.append([img.tolist()])
return data_in

def forward(self, input_data):
in_arg = self.converter(input_data)
return self.network.forwardTest(in_arg)

def forward(self, data, output_layer):
"""
input_data: py_paddle input data.
output_layer: specify the name of probability, namely the layer with
softmax activation.
return: the predicting probability of each label.
"""
input = self.converter(data)
self.network.forwardTest(input)
output = self.network.getLayerOutputs(output_layer)
# For oversampling, average predictions across crops.
# If not, the shape of output[name]: (1, class_number),
# the mean is also applicable.
return output[output_layer].mean(0)

def predict(self, image=None, output_layer=None):
assert isinstance(image, basestring)
assert isinstance(output_layer, basestring)
data = self.get_data(image)
prob = self.forward(data, output_layer)
lab = np.argsort(-prob)
logging.info("Label of %s is: %d", image, lab[0])

if __name__ == '__main__':
image_size=32
crop_size=32
multi_crop=True
config="vgg_16_cifar.py"
output_layer="__fc_layer_1__"
mean_path="data/cifar-out/batches/batches.meta"
model_path=sys.argv[1]
image=sys.argv[2]
use_gpu=bool(int(sys.argv[3]))

obj = ImageClassifier(train_conf=config,
model_dir=model_path,
resize_dim=image_size,
crop_dim=crop_size,
mean_file=mean_path,
use_gpu=use_gpu,
oversample=multi_crop)
obj.predict(image, output_layer)
13 changes: 12 additions & 1 deletion doc/demo/image_classification/image_classification.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,19 @@ After training finishes, the training and testing error curve will be saved to `
## Prediction
After we train the model, the model file as well as the model parameters are stored in path `./cifar_vgg_model/pass-%05d`. For example, the model of the 300-th pass is stored at `./cifar_vgg_model/pass-00299`.

To make a prediction for an image, such as `test.jpg`, one can run `sh classify.sh ./cifar_vgg_model/pass-00299 test.jpg`. The script will output the label of the classfiication.
To make a prediction for an image, one can run `predict.sh` as follows. The script will output the label of the classfiication.

```
sh predict.sh
```

predict.sh:
```
model=cifar_vgg_model/pass-00299/
image=data/cifar-out/test/airplane/seaplane_s_000978.png
use_gpu=1
python prediction.py $model $image $use_gpu
```

## Exercise
Train a image classification of birds using VGG model and CUB-200 dataset. The birds dataset can be downloaded here. It contains an image dataset with photos of 200 bird species (mostly North American).
Expand Down

0 comments on commit 13f4602

Please sign in to comment.