ImageAI provides the simplest and most powerful approach to training custom image prediction models
using the state-of-the-art SqueezeNet, ResNet50, InceptionV3 and DenseNet architectures,
which you can load into the imageai.Prediction.Custom.CustomImagePrediction
class. This allows
you to train your own model on any set of images corresponding to any type of object or person.
The training process generates a JSON file that maps the object types in your image dataset
and saves a model after each experiment. You then pick the model with the highest accuracy and perform custom
image prediction using that model and the generated JSON file.
- 🔳 Custom Model Training Prediction
- 🔳 Saving Full Custom Model (NEW)
- 🔳 Training on the IdenProf Dataset
- 🔳 Continuous Model Training (NEW)
- 🔳 Transfer Learning (Training from a pre-trained model) (NEW)
Because model training is a compute-intensive task, we strongly advise that you perform this experiment on a computer with an NVIDIA GPU and the GPU version of TensorFlow installed. Performing model training on a CPU can take hours or days; on an NVIDIA GPU-powered system it will take only a few hours. You can use Google Colab for this experiment, as it provides an NVIDIA K80 GPU.
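If you are not sure whether your TensorFlow installation can see your GPU, a quick check like the one below can help. This is a minimal sketch and not part of ImageAI; `tf.test.is_gpu_available()` exists in TensorFlow 1.x and early 2.x releases, and newer releases may require `tf.config.list_physical_devices('GPU')` instead.

```python
import tensorflow as tf

# Prints True if TensorFlow was built with GPU support and a GPU device is visible.
# On newer TensorFlow versions, use tf.config.list_physical_devices('GPU') instead.
print("GPU available:", tf.test.is_gpu_available())
```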
To train a custom prediction model, you need to prepare the images you want to use to train the model. You will prepare the images as follows:
- Create a dataset folder with the name you would like your dataset to be called (e.g. pets)
- In the dataset folder, create a folder named train
- In the dataset folder, create a folder named test
- In the train folder, create a folder for each object you want the model to predict, and give each folder a name that corresponds to the respective object (e.g. dog, cat, squirrel, snake)
- In the test folder, create a folder for each object you want the model to predict, and give each folder a name that corresponds to the respective object (e.g. dog, cat, squirrel, snake)
- In each folder present in the train folder, put the images of the respective object. These images are the ones used to train the model. To produce a model that performs well in practical applications, we recommend about 500 or more images per object; 1,000 images per object is even better
- In each folder present in the test folder, put about 100 to 200 images of the respective object. These images are the ones used to test the model as it trains
- Once you have done this, the structure of your image dataset folder should look like the example below:
```
pets//train//dog//dog-train-images
pets//train//cat//cat-train-images
pets//train//squirrel//squirrel-train-images
pets//train//snake//snake-train-images
pets//test//dog//dog-test-images
pets//test//cat//cat-test-images
pets//test//squirrel//squirrel-test-images
pets//test//snake//snake-test-images
```
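If you prefer to create and check this folder layout programmatically, a small helper like the one below can do it. This is a minimal sketch and not part of ImageAI; the dataset name and class names shown are the illustrative ones from this tutorial.

```python
import os

DATASET = "pets"                                   # illustrative dataset name
CLASSES = ["dog", "cat", "squirrel", "snake"]      # illustrative class names

# Create the train/test sub-folders for every class.
for split in ("train", "test"):
    for cls in CLASSES:
        os.makedirs(os.path.join(DATASET, split, cls), exist_ok=True)

# Report how many images each class folder currently holds.
for split in ("train", "test"):
    for cls in CLASSES:
        folder = os.path.join(DATASET, split, cls)
        print(split, cls, ":", len(os.listdir(folder)), "images")
```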
- Then your training code goes as follows:
```python
from imageai.Prediction.Custom import ModelTraining
model_trainer = ModelTraining()
model_trainer.setModelTypeAsResNet()
model_trainer.setDataDirectory("pets")
model_trainer.trainModel(num_objects=4, num_experiments=100, enhance_data=True, batch_size=32, show_network_summary=True)
```
Yes! Just 5 lines of code and you can train any of the 4 available state-of-the-art deep learning architectures on your custom dataset. Now let's take a look at how the code above works.
```python
from imageai.Prediction.Custom import ModelTraining

model_trainer = ModelTraining()
model_trainer.setModelTypeAsResNet()
model_trainer.setDataDirectory("pets")
```
In the first line, we import the ImageAI model training class; in the second line, we create the model trainer; in the third line, we set the network type; and in the fourth line, we set the path to the image dataset we want to train the network on.
```python
model_trainer.trainModel(num_objects=4, num_experiments=100, enhance_data=True, batch_size=32, show_network_summary=True)
```
In the code above, we start the training process. The parameters passed to the function are described below:
- num_objects : this states the number of object types (classes) in the image dataset
- num_experiments : this states the number of times the network will train over all the training images; this is also known as the number of epochs
- enhance_data (optional) : this states whether the network should generate modified (augmented) copies of the training images for better performance
- batch_size : this states the number of images the network will process at once. The images are processed in batches until they are exhausted in each experiment
- show_network_summary : this states whether the network should print the structure of the training network to the console
When you start the training, you should see something like this in the console:
```
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_2 (InputLayer) (None, 224, 224, 3) 0
____________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 112, 112, 64) 9472 input_2[0][0]
____________________________________________________________________________________________________
batch_normalization_1 (BatchNorm (None, 112, 112, 64) 256 conv2d_1[0][0]
____________________________________________________________________________________________________
activation_1 (Activation) (None, 112, 112, 64) 0 batch_normalization_1[0][0]
____________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 55, 55, 64) 0 activation_1[0][0]
____________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 55, 55, 64) 4160 max_pooling2d_1[0][0]
____________________________________________________________________________________________________
batch_normalization_3 (BatchNorm (None, 55, 55, 64) 256 conv2d_3[0][0]
____________________________________________________________________________________________________
activation_2 (Activation) (None, 55, 55, 64) 0 batch_normalization_3[0][0]
____________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 55, 55, 64) 36928 activation_2[0][0]
____________________________________________________________________________________________________
batch_normalization_4 (BatchNorm (None, 55, 55, 64) 256 conv2d_4[0][0]
____________________________________________________________________________________________________
activation_3 (Activation) (None, 55, 55, 64) 0 batch_normalization_4[0][0]
____________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 55, 55, 256) 16640 activation_3[0][0]
____________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 55, 55, 256) 16640 max_pooling2d_1[0][0]
____________________________________________________________________________________________________
batch_normalization_5 (BatchNorm (None, 55, 55, 256) 1024 conv2d_5[0][0]
____________________________________________________________________________________________________
batch_normalization_2 (BatchNorm (None, 55, 55, 256) 1024 conv2d_2[0][0]
____________________________________________________________________________________________________
add_1 (Add) (None, 55, 55, 256) 0 batch_normalization_5[0][0]
batch_normalization_2[0][0]
____________________________________________________________________________________________________
activation_4 (Activation) (None, 55, 55, 256) 0 add_1[0][0]
____________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 55, 55, 64) 16448 activation_4[0][0]
____________________________________________________________________________________________________
batch_normalization_6 (BatchNorm (None, 55, 55, 64) 256 conv2d_6[0][0]
____________________________________________________________________________________________________
activation_5 (Activation) (None, 55, 55, 64) 0 batch_normalization_6[0][0]
____________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 55, 55, 64) 36928 activation_5[0][0]
____________________________________________________________________________________________________
batch_normalization_7 (BatchNorm (None, 55, 55, 64) 256 conv2d_7[0][0]
____________________________________________________________________________________________________
activation_6 (Activation) (None, 55, 55, 64) 0 batch_normalization_7[0][0]
____________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 55, 55, 256) 16640 activation_6[0][0]
____________________________________________________________________________________________________
batch_normalization_8 (BatchNorm (None, 55, 55, 256) 1024 conv2d_8[0][0]
____________________________________________________________________________________________________
add_2 (Add) (None, 55, 55, 256) 0 batch_normalization_8[0][0]
activation_4[0][0]
____________________________________________________________________________________________________
activation_7 (Activation) (None, 55, 55, 256) 0 add_2[0][0]
____________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 55, 55, 64) 16448 activation_7[0][0]
____________________________________________________________________________________________________
batch_normalization_9 (BatchNorm (None, 55, 55, 64) 256 conv2d_9[0][0]
____________________________________________________________________________________________________
activation_8 (Activation) (None, 55, 55, 64) 0 batch_normalization_9[0][0]
____________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 55, 55, 64) 36928 activation_8[0][0]
____________________________________________________________________________________________________
batch_normalization_10 (BatchNor (None, 55, 55, 64) 256 conv2d_10[0][0]
____________________________________________________________________________________________________
activation_9 (Activation) (None, 55, 55, 64) 0 batch_normalization_10[0][0]
____________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 55, 55, 256) 16640 activation_9[0][0]
____________________________________________________________________________________________________
batch_normalization_11 (BatchNor (None, 55, 55, 256) 1024 conv2d_11[0][0]
____________________________________________________________________________________________________
add_3 (Add) (None, 55, 55, 256) 0 batch_normalization_11[0][0]
activation_7[0][0]
____________________________________________________________________________________________________
activation_10 (Activation) (None, 55, 55, 256) 0 add_3[0][0]
____________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 28, 28, 128) 32896 activation_10[0][0]
____________________________________________________________________________________________________
batch_normalization_13 (BatchNor (None, 28, 28, 128) 512 conv2d_13[0][0]
____________________________________________________________________________________________________
activation_11 (Activation) (None, 28, 28, 128) 0 batch_normalization_13[0][0]
____________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 28, 28, 128) 147584 activation_11[0][0]
____________________________________________________________________________________________________
batch_normalization_14 (BatchNor (None, 28, 28, 128) 512 conv2d_14[0][0]
____________________________________________________________________________________________________
activation_12 (Activation) (None, 28, 28, 128) 0 batch_normalization_14[0][0]
____________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 28, 28, 512) 66048 activation_12[0][0]
____________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 28, 28, 512) 131584 activation_10[0][0]
____________________________________________________________________________________________________
batch_normalization_15 (BatchNor (None, 28, 28, 512) 2048 conv2d_15[0][0]
____________________________________________________________________________________________________
batch_normalization_12 (BatchNor (None, 28, 28, 512) 2048 conv2d_12[0][0]
____________________________________________________________________________________________________
add_4 (Add) (None, 28, 28, 512) 0 batch_normalization_15[0][0]
batch_normalization_12[0][0]
____________________________________________________________________________________________________
activation_13 (Activation) (None, 28, 28, 512) 0 add_4[0][0]
____________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 28, 28, 128) 65664 activation_13[0][0]
____________________________________________________________________________________________________
batch_normalization_16 (BatchNor (None, 28, 28, 128) 512 conv2d_16[0][0]
____________________________________________________________________________________________________
activation_14 (Activation) (None, 28, 28, 128) 0 batch_normalization_16[0][0]
____________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 28, 28, 128) 147584 activation_14[0][0]
____________________________________________________________________________________________________
batch_normalization_17 (BatchNor (None, 28, 28, 128) 512 conv2d_17[0][0]
____________________________________________________________________________________________________
activation_15 (Activation) (None, 28, 28, 128) 0 batch_normalization_17[0][0]
____________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 28, 28, 512) 66048 activation_15[0][0]
____________________________________________________________________________________________________
batch_normalization_18 (BatchNor (None, 28, 28, 512) 2048 conv2d_18[0][0]
____________________________________________________________________________________________________
add_5 (Add) (None, 28, 28, 512) 0 batch_normalization_18[0][0]
activation_13[0][0]
____________________________________________________________________________________________________
activation_16 (Activation) (None, 28, 28, 512) 0 add_5[0][0]
____________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 28, 28, 128) 65664 activation_16[0][0]
____________________________________________________________________________________________________
batch_normalization_19 (BatchNor (None, 28, 28, 128) 512 conv2d_19[0][0]
____________________________________________________________________________________________________
activation_17 (Activation) (None, 28, 28, 128) 0 batch_normalization_19[0][0]
____________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, 28, 28, 128) 147584 activation_17[0][0]
____________________________________________________________________________________________________
batch_normalization_20 (BatchNor (None, 28, 28, 128) 512 conv2d_20[0][0]
____________________________________________________________________________________________________
activation_18 (Activation) (None, 28, 28, 128) 0 batch_normalization_20[0][0]
____________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, 28, 28, 512) 66048 activation_18[0][0]
____________________________________________________________________________________________________
batch_normalization_21 (BatchNor (None, 28, 28, 512) 2048 conv2d_21[0][0]
____________________________________________________________________________________________________
add_6 (Add) (None, 28, 28, 512) 0 batch_normalization_21[0][0]
activation_16[0][0]
____________________________________________________________________________________________________
activation_19 (Activation) (None, 28, 28, 512) 0 add_6[0][0]
____________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, 28, 28, 128) 65664 activation_19[0][0]
____________________________________________________________________________________________________
batch_normalization_22 (BatchNor (None, 28, 28, 128) 512 conv2d_22[0][0]
____________________________________________________________________________________________________
activation_20 (Activation) (None, 28, 28, 128) 0 batch_normalization_22[0][0]
____________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, 28, 28, 128) 147584 activation_20[0][0]
____________________________________________________________________________________________________
batch_normalization_23 (BatchNor (None, 28, 28, 128) 512 conv2d_23[0][0]
____________________________________________________________________________________________________
activation_21 (Activation) (None, 28, 28, 128) 0 batch_normalization_23[0][0]
____________________________________________________________________________________________________
conv2d_24 (Conv2D) (None, 28, 28, 512) 66048 activation_21[0][0]
____________________________________________________________________________________________________
batch_normalization_24 (BatchNor (None, 28, 28, 512) 2048 conv2d_24[0][0]
____________________________________________________________________________________________________
add_7 (Add) (None, 28, 28, 512) 0 batch_normalization_24[0][0]
activation_19[0][0]
____________________________________________________________________________________________________
activation_22 (Activation) (None, 28, 28, 512) 0 add_7[0][0]
____________________________________________________________________________________________________
conv2d_26 (Conv2D) (None, 14, 14, 256) 131328 activation_22[0][0]
____________________________________________________________________________________________________
batch_normalization_26 (BatchNor (None, 14, 14, 256) 1024 conv2d_26[0][0]
____________________________________________________________________________________________________
activation_23 (Activation) (None, 14, 14, 256) 0 batch_normalization_26[0][0]
____________________________________________________________________________________________________
conv2d_27 (Conv2D) (None, 14, 14, 256) 590080 activation_23[0][0]
____________________________________________________________________________________________________
batch_normalization_27 (BatchNor (None, 14, 14, 256) 1024 conv2d_27[0][0]
____________________________________________________________________________________________________
activation_24 (Activation) (None, 14, 14, 256) 0 batch_normalization_27[0][0]
____________________________________________________________________________________________________
conv2d_28 (Conv2D) (None, 14, 14, 1024) 263168 activation_24[0][0]
____________________________________________________________________________________________________
conv2d_25 (Conv2D) (None, 14, 14, 1024) 525312 activation_22[0][0]
____________________________________________________________________________________________________
batch_normalization_28 (BatchNor (None, 14, 14, 1024) 4096 conv2d_28[0][0]
____________________________________________________________________________________________________
batch_normalization_25 (BatchNor (None, 14, 14, 1024) 4096 conv2d_25[0][0]
____________________________________________________________________________________________________
add_8 (Add) (None, 14, 14, 1024) 0 batch_normalization_28[0][0]
batch_normalization_25[0][0]
____________________________________________________________________________________________________
activation_25 (Activation) (None, 14, 14, 1024) 0 add_8[0][0]
____________________________________________________________________________________________________
conv2d_29 (Conv2D) (None, 14, 14, 256) 262400 activation_25[0][0]
____________________________________________________________________________________________________
batch_normalization_29 (BatchNor (None, 14, 14, 256) 1024 conv2d_29[0][0]
____________________________________________________________________________________________________
activation_26 (Activation) (None, 14, 14, 256) 0 batch_normalization_29[0][0]
____________________________________________________________________________________________________
conv2d_30 (Conv2D) (None, 14, 14, 256) 590080 activation_26[0][0]
____________________________________________________________________________________________________
batch_normalization_30 (BatchNor (None, 14, 14, 256) 1024 conv2d_30[0][0]
____________________________________________________________________________________________________
activation_27 (Activation) (None, 14, 14, 256) 0 batch_normalization_30[0][0]
____________________________________________________________________________________________________
conv2d_31 (Conv2D) (None, 14, 14, 1024) 263168 activation_27[0][0]
____________________________________________________________________________________________________
batch_normalization_31 (BatchNor (None, 14, 14, 1024) 4096 conv2d_31[0][0]
____________________________________________________________________________________________________
add_9 (Add) (None, 14, 14, 1024) 0 batch_normalization_31[0][0]
activation_25[0][0]
____________________________________________________________________________________________________
activation_28 (Activation) (None, 14, 14, 1024) 0 add_9[0][0]
____________________________________________________________________________________________________
conv2d_32 (Conv2D) (None, 14, 14, 256) 262400 activation_28[0][0]
____________________________________________________________________________________________________
batch_normalization_32 (BatchNor (None, 14, 14, 256) 1024 conv2d_32[0][0]
____________________________________________________________________________________________________
activation_29 (Activation) (None, 14, 14, 256) 0 batch_normalization_32[0][0]
____________________________________________________________________________________________________
conv2d_33 (Conv2D) (None, 14, 14, 256) 590080 activation_29[0][0]
____________________________________________________________________________________________________
batch_normalization_33 (BatchNor (None, 14, 14, 256) 1024 conv2d_33[0][0]
____________________________________________________________________________________________________
activation_30 (Activation) (None, 14, 14, 256) 0 batch_normalization_33[0][0]
____________________________________________________________________________________________________
conv2d_34 (Conv2D) (None, 14, 14, 1024) 263168 activation_30[0][0]
____________________________________________________________________________________________________
batch_normalization_34 (BatchNor (None, 14, 14, 1024) 4096 conv2d_34[0][0]
____________________________________________________________________________________________________
add_10 (Add) (None, 14, 14, 1024) 0 batch_normalization_34[0][0]
activation_28[0][0]
____________________________________________________________________________________________________
activation_31 (Activation) (None, 14, 14, 1024) 0 add_10[0][0]
____________________________________________________________________________________________________
conv2d_35 (Conv2D) (None, 14, 14, 256) 262400 activation_31[0][0]
____________________________________________________________________________________________________
batch_normalization_35 (BatchNor (None, 14, 14, 256) 1024 conv2d_35[0][0]
____________________________________________________________________________________________________
activation_32 (Activation) (None, 14, 14, 256) 0 batch_normalization_35[0][0]
____________________________________________________________________________________________________
conv2d_36 (Conv2D) (None, 14, 14, 256) 590080 activation_32[0][0]
____________________________________________________________________________________________________
batch_normalization_36 (BatchNor (None, 14, 14, 256) 1024 conv2d_36[0][0]
____________________________________________________________________________________________________
activation_33 (Activation) (None, 14, 14, 256) 0 batch_normalization_36[0][0]
____________________________________________________________________________________________________
conv2d_37 (Conv2D) (None, 14, 14, 1024) 263168 activation_33[0][0]
____________________________________________________________________________________________________
batch_normalization_37 (BatchNor (None, 14, 14, 1024) 4096 conv2d_37[0][0]
____________________________________________________________________________________________________
add_11 (Add) (None, 14, 14, 1024) 0 batch_normalization_37[0][0]
activation_31[0][0]
____________________________________________________________________________________________________
activation_34 (Activation) (None, 14, 14, 1024) 0 add_11[0][0]
____________________________________________________________________________________________________
conv2d_38 (Conv2D) (None, 14, 14, 256) 262400 activation_34[0][0]
____________________________________________________________________________________________________
batch_normalization_38 (BatchNor (None, 14, 14, 256) 1024 conv2d_38[0][0]
____________________________________________________________________________________________________
activation_35 (Activation) (None, 14, 14, 256) 0 batch_normalization_38[0][0]
____________________________________________________________________________________________________
conv2d_39 (Conv2D) (None, 14, 14, 256) 590080 activation_35[0][0]
____________________________________________________________________________________________________
batch_normalization_39 (BatchNor (None, 14, 14, 256) 1024 conv2d_39[0][0]
____________________________________________________________________________________________________
activation_36 (Activation) (None, 14, 14, 256) 0 batch_normalization_39[0][0]
____________________________________________________________________________________________________
conv2d_40 (Conv2D) (None, 14, 14, 1024) 263168 activation_36[0][0]
____________________________________________________________________________________________________
batch_normalization_40 (BatchNor (None, 14, 14, 1024) 4096 conv2d_40[0][0]
____________________________________________________________________________________________________
add_12 (Add) (None, 14, 14, 1024) 0 batch_normalization_40[0][0]
activation_34[0][0]
____________________________________________________________________________________________________
activation_37 (Activation) (None, 14, 14, 1024) 0 add_12[0][0]
____________________________________________________________________________________________________
conv2d_41 (Conv2D) (None, 14, 14, 256) 262400 activation_37[0][0]
____________________________________________________________________________________________________
batch_normalization_41 (BatchNor (None, 14, 14, 256) 1024 conv2d_41[0][0]
____________________________________________________________________________________________________
activation_38 (Activation) (None, 14, 14, 256) 0 batch_normalization_41[0][0]
____________________________________________________________________________________________________
conv2d_42 (Conv2D) (None, 14, 14, 256) 590080 activation_38[0][0]
____________________________________________________________________________________________________
batch_normalization_42 (BatchNor (None, 14, 14, 256) 1024 conv2d_42[0][0]
____________________________________________________________________________________________________
activation_39 (Activation) (None, 14, 14, 256) 0 batch_normalization_42[0][0]
____________________________________________________________________________________________________
conv2d_43 (Conv2D) (None, 14, 14, 1024) 263168 activation_39[0][0]
____________________________________________________________________________________________________
batch_normalization_43 (BatchNor (None, 14, 14, 1024) 4096 conv2d_43[0][0]
____________________________________________________________________________________________________
add_13 (Add) (None, 14, 14, 1024) 0 batch_normalization_43[0][0]
activation_37[0][0]
____________________________________________________________________________________________________
activation_40 (Activation) (None, 14, 14, 1024) 0 add_13[0][0]
____________________________________________________________________________________________________
conv2d_45 (Conv2D) (None, 7, 7, 512) 524800 activation_40[0][0]
____________________________________________________________________________________________________
batch_normalization_45 (BatchNor (None, 7, 7, 512) 2048 conv2d_45[0][0]
____________________________________________________________________________________________________
activation_41 (Activation) (None, 7, 7, 512) 0 batch_normalization_45[0][0]
____________________________________________________________________________________________________
conv2d_46 (Conv2D) (None, 7, 7, 512) 2359808 activation_41[0][0]
____________________________________________________________________________________________________
batch_normalization_46 (BatchNor (None, 7, 7, 512) 2048 conv2d_46[0][0]
____________________________________________________________________________________________________
activation_42 (Activation) (None, 7, 7, 512) 0 batch_normalization_46[0][0]
____________________________________________________________________________________________________
conv2d_47 (Conv2D) (None, 7, 7, 2048) 1050624 activation_42[0][0]
____________________________________________________________________________________________________
conv2d_44 (Conv2D) (None, 7, 7, 2048) 2099200 activation_40[0][0]
____________________________________________________________________________________________________
batch_normalization_47 (BatchNor (None, 7, 7, 2048) 8192 conv2d_47[0][0]
____________________________________________________________________________________________________
batch_normalization_44 (BatchNor (None, 7, 7, 2048) 8192 conv2d_44[0][0]
____________________________________________________________________________________________________
add_14 (Add) (None, 7, 7, 2048) 0 batch_normalization_47[0][0]
batch_normalization_44[0][0]
____________________________________________________________________________________________________
activation_43 (Activation) (None, 7, 7, 2048) 0 add_14[0][0]
____________________________________________________________________________________________________
conv2d_48 (Conv2D) (None, 7, 7, 512) 1049088 activation_43[0][0]
____________________________________________________________________________________________________
batch_normalization_48 (BatchNor (None, 7, 7, 512) 2048 conv2d_48[0][0]
____________________________________________________________________________________________________
activation_44 (Activation) (None, 7, 7, 512) 0 batch_normalization_48[0][0]
____________________________________________________________________________________________________
conv2d_49 (Conv2D) (None, 7, 7, 512) 2359808 activation_44[0][0]
____________________________________________________________________________________________________
batch_normalization_49 (BatchNor (None, 7, 7, 512) 2048 conv2d_49[0][0]
____________________________________________________________________________________________________
activation_45 (Activation) (None, 7, 7, 512) 0 batch_normalization_49[0][0]
____________________________________________________________________________________________________
conv2d_50 (Conv2D) (None, 7, 7, 2048) 1050624 activation_45[0][0]
____________________________________________________________________________________________________
batch_normalization_50 (BatchNor (None, 7, 7, 2048) 8192 conv2d_50[0][0]
____________________________________________________________________________________________________
add_15 (Add) (None, 7, 7, 2048) 0 batch_normalization_50[0][0]
activation_43[0][0]
____________________________________________________________________________________________________
activation_46 (Activation) (None, 7, 7, 2048) 0 add_15[0][0]
____________________________________________________________________________________________________
conv2d_51 (Conv2D) (None, 7, 7, 512) 1049088 activation_46[0][0]
____________________________________________________________________________________________________
batch_normalization_51 (BatchNor (None, 7, 7, 512) 2048 conv2d_51[0][0]
____________________________________________________________________________________________________
activation_47 (Activation) (None, 7, 7, 512) 0 batch_normalization_51[0][0]
____________________________________________________________________________________________________
conv2d_52 (Conv2D) (None, 7, 7, 512) 2359808 activation_47[0][0]
____________________________________________________________________________________________________
batch_normalization_52 (BatchNor (None, 7, 7, 512) 2048 conv2d_52[0][0]
____________________________________________________________________________________________________
activation_48 (Activation) (None, 7, 7, 512) 0 batch_normalization_52[0][0]
____________________________________________________________________________________________________
conv2d_53 (Conv2D) (None, 7, 7, 2048) 1050624 activation_48[0][0]
____________________________________________________________________________________________________
batch_normalization_53 (BatchNor (None, 7, 7, 2048) 8192 conv2d_53[0][0]
____________________________________________________________________________________________________
add_16 (Add) (None, 7, 7, 2048) 0 batch_normalization_53[0][0]
activation_46[0][0]
____________________________________________________________________________________________________
activation_49 (Activation) (None, 7, 7, 2048) 0 add_16[0][0]
____________________________________________________________________________________________________
global_avg_pooling (GlobalAverag (None, 2048) 0 activation_49[0][0]
____________________________________________________________________________________________________
dense_1 (Dense) (None, 10) 20490 global_avg_pooling[0][0]
____________________________________________________________________________________________________
activation_50 (Activation) (None, 10) 0 dense_1[0][0]
====================================================================================================
Total params: 23,608,202
Trainable params: 23,555,082
Non-trainable params: 53,120
____________________________________________________________________________________________________
Using Enhanced Data Generation
Found 4000 images belonging to 4 classes.
Found 800 images belonging to 4 classes.
JSON Mapping for the model classes saved to C:\Users\User\PycharmProjects\ImageAITest\pets\json\model_class.json
Number of experiments (Epochs) : 100
```
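The model_class.json file mentioned in the output above is the JSON mapping of your dataset's classes that you will later use for prediction. A quick way to inspect it is shown below; this is a minimal sketch, with the path taken from the console output above, and the exact contents may vary with your dataset.

```python
import json

# Path follows the layout created by ModelTraining: <dataset>/json/model_class.json
with open("pets/json/model_class.json") as f:
    class_mapping = json.load(f)

# Expected to map class indices to the class (folder) names,
# e.g. {"0": "cat", "1": "dog", ...} for the pets dataset used above.
print(class_mapping)
```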
As the training progresses, you will see results like the following in the console:
```
Epoch 1/100
1/25 [>.............................] - ETA: 52s - loss: 2.3026 - acc: 0.2500
2/25 [=>............................] - ETA: 41s - loss: 2.3027 - acc: 0.1250
3/25 [==>...........................] - ETA: 37s - loss: 2.2961 - acc: 0.1667
4/25 [===>..........................] - ETA: 36s - loss: 2.2980 - acc: 0.1250
5/25 [=====>........................] - ETA: 33s - loss: 2.3178 - acc: 0.1000
6/25 [======>.......................] - ETA: 31s - loss: 2.3214 - acc: 0.0833
7/25 [=======>......................] - ETA: 30s - loss: 2.3202 - acc: 0.0714
8/25 [========>.....................] - ETA: 29s - loss: 2.3207 - acc: 0.0625
9/25 [=========>....................] - ETA: 27s - loss: 2.3191 - acc: 0.0556
10/25 [===========>..................] - ETA: 25s - loss: 2.3167 - acc: 0.0750
11/25 [============>.................] - ETA: 23s - loss: 2.3162 - acc: 0.0682
12/25 [=============>................] - ETA: 21s - loss: 2.3143 - acc: 0.0833
13/25 [==============>...............] - ETA: 20s - loss: 2.3135 - acc: 0.0769
14/25 [===============>..............] - ETA: 18s - loss: 2.3132 - acc: 0.0714
15/25 [=================>............] - ETA: 16s - loss: 2.3128 - acc: 0.0667
16/25 [==================>...........] - ETA: 15s - loss: 2.3121 - acc: 0.0781
17/25 [===================>..........] - ETA: 13s - loss: 2.3116 - acc: 0.0735
18/25 [====================>.........] - ETA: 12s - loss: 2.3114 - acc: 0.0694
19/25 [=====================>........] - ETA: 10s - loss: 2.3112 - acc: 0.0658
20/25 [=======================>......] - ETA: 8s - loss: 2.3109 - acc: 0.0625
21/25 [========================>.....] - ETA: 7s - loss: 2.3107 - acc: 0.0595
22/25 [=========================>....] - ETA: 5s - loss: 2.3104 - acc: 0.0568
23/25 [==========================>...] - ETA: 3s - loss: 2.3101 - acc: 0.0543
24/25 [===========================>..] - ETA: 1s - loss: 2.3097 - acc: 0.0625Epoch 00000: saving model to C:\Users\Moses\Documents\Moses\W7\AI\Custom Datasets\IDENPROF\idenprof-small-test\idenprof\models\model_ex-000_acc-0.100000.h5
25/25 [==============================] - 51s - loss: 2.3095 - acc: 0.0600 - val_loss: 2.3026 - val_acc: 0.1000
```
Let us explain the details shown above:
- The line Epoch 1/100 means the network is performing the first experiment (epoch) of the targeted 100
- The line 1/25 [>.............................] - ETA: 52s - loss: 2.3026 - acc: 0.2500 shows the number of batches that have been trained so far in the present experiment
- The line Epoch 00000: saving model to C:\Users\User\PycharmProjects\ImageAITest\pets\models\model_ex-000_acc-0.100000.h5 refers to the model saved after the present experiment. The ex-000 part represents the experiment at this stage, while acc-0.100000 and val_acc: 0.1000 represent the accuracy of the model on the test images after the present experiment (the maximum possible value of accuracy is 1.0). This result helps you identify the best-performing model to use for custom image prediction.
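Since a model file is saved for each experiment, you may end up with dozens of .h5 files. A small helper like the one below can pick the file with the highest accuracy encoded in its name; this is a minimal sketch assuming the default model_ex-XXX_acc-Y.YYYYYY.h5 naming shown above, and the accuracy_from_name helper and the pets path are illustrative, not part of ImageAI.

```python
import os

models_dir = os.path.join("pets", "models")  # folder where ModelTraining saves models

def accuracy_from_name(filename):
    # File names look like "model_ex-012_acc-0.812500.h5"; pull out the accuracy part.
    return float(filename.split("acc-")[1].replace(".h5", ""))

best_model = max(os.listdir(models_dir), key=accuracy_from_name)
print("Best model:", best_model, "accuracy:", accuracy_from_name(best_model))
```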
Once you are done training your custom model, you can use the "CustomImagePrediction" class to perform image prediction with your model. Simply follow the link below. imageai/Prediction/CUSTOMPREDICTION.md
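For quick reference, a typical custom prediction call looks roughly like the sketch below; see CUSTOMPREDICTION.md for the authoritative example. The model file name, JSON path and image name here are illustrative.

```python
from imageai.Prediction.Custom import CustomImagePrediction
import os

execution_path = os.getcwd()

prediction = CustomImagePrediction()
prediction.setModelTypeAsResNet()  # must match the network type used for training
prediction.setModelPath(os.path.join(execution_path, "model_ex-092_acc-0.913750.h5"))  # illustrative file name
prediction.setJsonPath(os.path.join(execution_path, "model_class.json"))
prediction.loadModel(num_objects=4)  # 4 classes in the pets example above

predictions, probabilities = prediction.predictImage(os.path.join(execution_path, "sample.jpg"), result_count=4)
for name, probability in zip(predictions, probabilities):
    print(name, ":", probability)
```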
ImageAI now allows you to save your custom model in full during training, which means you can perform custom prediction without necessarily specifying the network type. All you need to do is set the parameter save_full_model to True in your trainModel() function. See the example code below.
```python
from imageai.Prediction.Custom import ModelTraining
import os

trainer = ModelTraining()
trainer.setModelTypeAsDenseNet()
trainer.setDataDirectory("idenprof")
trainer.trainModel(num_objects=10, num_experiments=50, enhance_data=True, batch_size=16, show_network_summary=True, save_full_model=True)
```
A sample from the IdenProf dataset, used to train a model for predicting professionals.
Below we provide sample code to train on IdenProf, a dataset containing images of 10 uniformed professionals. The code will download the dataset and initiate the training:
```python
from io import open
import requests
import shutil
from zipfile import ZipFile
import os
from imageai.Prediction.Custom import ModelTraining

execution_path = os.getcwd()

TRAIN_ZIP_ONE = os.path.join(execution_path, "idenprof-train1.zip")
TRAIN_ZIP_TWO = os.path.join(execution_path, "idenprof-train2.zip")
TEST_ZIP = os.path.join(execution_path, "idenprof-test.zip")

DATASET_DIR = os.path.join(execution_path, "idenprof")
DATASET_TRAIN_DIR = os.path.join(DATASET_DIR, "train")
DATASET_TEST_DIR = os.path.join(DATASET_DIR, "test")

if(os.path.exists(DATASET_DIR) == False):
    os.mkdir(DATASET_DIR)

if(os.path.exists(DATASET_TRAIN_DIR) == False):
    os.mkdir(DATASET_TRAIN_DIR)

if(os.path.exists(DATASET_TEST_DIR) == False):
    os.mkdir(DATASET_TEST_DIR)

if(len(os.listdir(DATASET_TRAIN_DIR)) < 10):
    if(os.path.exists(TRAIN_ZIP_ONE) == False):
        print("Downloading idenprof-train1.zip")
        data = requests.get("https://github.com/OlafenwaMoses/IdenProf/releases/download/v1.0/idenprof-train1.zip", stream=True)
        with open(TRAIN_ZIP_ONE, "wb") as file:
            shutil.copyfileobj(data.raw, file)
        del data
    if(os.path.exists(TRAIN_ZIP_TWO) == False):
        print("Downloading idenprof-train2.zip")
        data = requests.get("https://github.com/OlafenwaMoses/IdenProf/releases/download/v1.0/idenprof-train2.zip", stream=True)
        with open(TRAIN_ZIP_TWO, "wb") as file:
            shutil.copyfileobj(data.raw, file)
        del data
    print("Extracting idenprof-train1.zip")
    extract1 = ZipFile(TRAIN_ZIP_ONE)
    extract1.extractall(DATASET_TRAIN_DIR)
    extract1.close()
    print("Extracting idenprof-train2.zip")
    extract2 = ZipFile(TRAIN_ZIP_TWO)
    extract2.extractall(DATASET_TRAIN_DIR)
    extract2.close()

if(len(os.listdir(DATASET_TEST_DIR)) < 10):
    if(os.path.exists(TEST_ZIP) == False):
        print("Downloading idenprof-test.zip")
        data = requests.get("https://github.com/OlafenwaMoses/IdenProf/releases/download/v1.0/idenprof-test.zip", stream=True)
        with open(TEST_ZIP, "wb") as file:
            shutil.copyfileobj(data.raw, file)
        del data
    print("Extracting idenprof-test.zip")
    extract = ZipFile(TEST_ZIP)
    extract.extractall(DATASET_TEST_DIR)
    extract.close()

model_trainer = ModelTraining()
model_trainer.setModelTypeAsResNet()
model_trainer.setDataDirectory(DATASET_DIR)
model_trainer.trainModel(num_objects=10, num_experiments=100, enhance_data=True, batch_size=32, show_network_summary=True)
```
ImageAI now allows you to continue training your custom model from a previously saved model. This is useful in cases where training was incomplete due to compute time limits or the large size of the dataset, or when you intend to train your model further. Kindly note that continuous training is for using a previously saved model to train on the same dataset the model was originally trained on. All you need to do is set the continue_from_model parameter to the path of the previously saved model in your trainModel() function. See the example code below.
```python
from imageai.Prediction.Custom import ModelTraining
import os

trainer = ModelTraining()
trainer.setModelTypeAsDenseNet()
trainer.setDataDirectory("idenprof")
trainer.trainModel(num_objects=10, num_experiments=50, enhance_data=True, batch_size=8, show_network_summary=True, continue_from_model="idenprof_densenet-0.763500.h5")
```
From the feedback we have received over the past months, we discovered that most custom models trained with ImageAI were based on datasets with only a small number of images, falling short of the minimum recommendation of 500 images per object class needed to achieve a viable accuracy.
To ensure you can still train accurate custom models with a small number of images, ImageAI now allows you to train by leveraging transfer learning. This means you can take a pre-trained SqueezeNet, ResNet50, InceptionV3 or DenseNet121 model trained on a larger dataset and use it to kickstart your custom model training.
All you need to do is set the transfer_from_model parameter to the path of the pre-trained model and the initial_num_objects parameter to the number of objects in the dataset the pre-trained model was trained on, all in your trainModel() function. See the example code below, showing how to perform transfer learning from a ResNet50 model trained on the ImageNet dataset.
```python
from imageai.Prediction.Custom import ModelTraining
import os

trainer = ModelTraining()
trainer.setModelTypeAsResNet()
trainer.setDataDirectory("idenprof")
trainer.trainModel(num_objects=10, num_experiments=50, enhance_data=True, batch_size=32, show_network_summary=True, transfer_from_model="resnet50_weights_tf_dim_ordering_tf_kernels.h5", initial_num_objects=1000)
```
We are providing an opportunity for anyone that uses ImageAI to train a model to submit the model and its JSON mapping file and have it listed in this repository. Reach out via the contact details below should you intend to share your trained model in this repository.
- Moses Olafenwa
- Email: [email protected]
- Website: https://moses.aicommons.science
- Twitter: @OlafenwaMoses
- Medium: @guymodscientist
- Facebook: moses.olafenwa
We have provided full documentation for all ImageAI classes and functions in 3 major languages. Find links below:
- Documentation - English Version https://imageai.readthedocs.io
- Documentation - Chinese Version https://imageai-cn.readthedocs.io
- Documentation - French Version https://imageai-fr.readthedocs.io