Merge pull request OlafenwaMoses#2 from rola93/multi_gpu_and_TFBoard
Add support for TensorBoard and allow multiple gpus
rola93 authored Aug 19, 2019
2 parents 119ac64 + b4e533e commit dd6f91d
Showing 2 changed files with 39 additions and 39 deletions.
66 changes: 34 additions & 32 deletions imageai/Detection/Custom/__init__.py
@@ -6,23 +6,21 @@
from imageai.Detection.YOLOv3.models import yolo_main
from imageai.Detection.Custom.generator import BatchGenerator
from imageai.Detection.Custom.utils.utils import normalize, evaluate, makedirs
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.callbacks import ReduceLROnPlateau
from keras.optimizers import Adam
from imageai.Detection.Custom.callbacks import CustomModelCheckpoint, CustomTensorBoard
from imageai.Detection.Custom.utils.multi_gpu_model import multi_gpu_model
from imageai.Detection.Custom.gen_anchors import generateAnchors
import tensorflow as tf
import keras
from keras.preprocessing.image import load_img, img_to_array
from keras.preprocessing.image import img_to_array
from keras.models import load_model, Input
from keras.callbacks import TensorBoard
from PIL import Image
import matplotlib.image as pltimage
import cv2

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"



class DetectionModelTrainer:

"""
@@ -54,23 +52,21 @@ def __init__(self):
self.__train_warmup_epochs = 3
self.__train_ignore_treshold = 0.5
self.__train_gpus = "0,1"
self.__train_grid_scales = [1,1,1]
self.__train_grid_scales = [1, 1, 1]
self.__train_obj_scale = 5
self.__train_noobj_scale = 1
self.__train_xywh_scale = 1
self.__train_class_scale = 1
self.__model_directory = ""
self.__train_weights_name = ""
self.__train_debug = True
self.__logs_directory = ""

self.__validation_images_folder = ""
self.__validation_annotations_folder = ""
self.__validation_cache_file = ""
self.__validation_times = 1




def setModelTypeAsYOLOv3(self):
"""
'setModelTypeAsYOLOv3()' is used to set the model type to the YOLOv3 model
@@ -131,11 +127,31 @@ def setDataDirectory(self, data_directory):
if os.path.exists(os.path.join(data_directory, "json")) == False:
os.makedirs(os.path.join(data_directory, "json"))

if os.path.exists(os.path.join(data_directory, "logs")) == False:
os.makedirs(os.path.join(data_directory, "logs"))

self.__model_directory = os.path.join(data_directory, "models")
self.__train_weights_name = os.path.join(self.__model_directory, "detection_model-")
self.__json_directory = os.path.join(data_directory, "json")
self.__logs_directory = os.path.join(data_directory, "logs")
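
Under this change, setDataDirectory() also prepares a logs folder next to the existing models and json folders. A minimal sketch, assuming a hypothetical dataset directory named "hololens":

    import os
    from imageai.Detection.Custom import DetectionModelTrainer

    trainer = DetectionModelTrainer()
    trainer.setDataDirectory(data_directory="hololens")
    # models/ and json/ existed before this patch; logs/ is new and feeds the TensorBoard callback
    assert os.path.isdir(os.path.join("hololens", "logs"))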


def setGpuUsage(self, train_gpus):
"""
'setGpuUsage' function allows you to set the GPUs to be used while training
train_gpus can be:
- an integer, indicating the number of GPUs to use
- a list of integers, indicating the ids of the GPUs to be used
- a string with the ids of the GPUs to be used, separated by commas
:param train_gpus: the GPUs to run training on
:return:
"""
# train_gpus may be a comma-separated string, a list of ints, or the number of GPUs to be used
if type(train_gpus) == str:
train_gpus = train_gpus.split(',')
if type(train_gpus) == int:
train_gpus = range(train_gpus)
# keep it as a comma-separated string
self.__train_gpus = ','.join([str(gpu) for gpu in train_gpus])
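
A short usage sketch of the new setGpuUsage(); only setGpuUsage() itself comes from this patch, the trainer instance is assumed from the library's existing API:

    from imageai.Detection.Custom import DetectionModelTrainer

    trainer = DetectionModelTrainer()
    # The three calls below are equivalent and all select GPUs 0 and 1:
    trainer.setGpuUsage("0,1")    # comma-separated string of GPU ids
    trainer.setGpuUsage([0, 1])   # list of GPU ids
    trainer.setGpuUsage(2)        # integer: use the first 2 GPUs (ids 0 and 1)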

def setTrainConfig(self, object_names_array, batch_size= 4, num_experiments=100, train_from_pretrained_model=""):

@@ -155,7 +171,6 @@ def setTrainConfig(self, object_names_array, batch_size= 4, num_experiments=100
:return:
"""


self.__model_anchors, self.__inference_anchors = generateAnchors(self.__train_annotations_folder,
self.__train_images_folder,
self.__train_cache_file, self.__model_labels)
@@ -249,7 +264,7 @@ def trainModel(self):
warmup_batches = self.__train_warmup_epochs * (self.__train_times * len(train_generator))

os.environ['CUDA_VISIBLE_DEVICES'] = self.__train_gpus
multi_gpu = len(self.__train_gpus.split(','))
multi_gpu = [int(gpu) for gpu in self.__train_gpus.split(',')]

train_model, infer_model = self._create_model(
nb_class=len(labels),
@@ -375,7 +390,7 @@ def evaluateModel(self, model_path, json_path, batch_size=4, iou_threshold=0.5,
norm=normalize
)

multi_gpu = len(self.__train_gpus.split(','))
multi_gpu = [int(gpu) for gpu in self.__train_gpus.split(',')]
warmup_batches = self.__train_warmup_epochs * (self.__train_times * len(train_generator))

train_model, infer_model = self._create_model(
@@ -395,7 +410,6 @@ def evaluateModel(self, model_path, json_path, batch_size=4, iou_threshold=0.5,
class_scale=self.__train_class_scale,
)


if(os.path.isfile(model_path)):
if(str(model_path).endswith(".h5")):
try:
@@ -421,7 +435,6 @@ def evaluateModel(self, model_path, json_path, batch_size=4, iou_threshold=0.5,
except:
None


elif(os.path.isdir(model_path)):
model_files = os.listdir(model_path)

@@ -449,11 +462,6 @@ def evaluateModel(self, model_path, json_path, batch_size=4, iou_threshold=0.5,
except:
continue






def _create_training_instances(self,
train_annot_folder,
train_image_folder,
@@ -522,8 +530,10 @@ def _create_callbacks(self, saved_weights_name, model_to_save):
cooldown=0,
min_lr=0
)

return [checkpoint, reduce_on_plateau]
tensor_board = TensorBoard(
log_dir=self.__logs_directory
)
return [checkpoint, reduce_on_plateau, tensor_board]
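
With the TensorBoard callback now wired to the logs directory created by setDataDirectory(), training curves can be inspected while a model trains. A minimal end-to-end sketch; the dataset folder, class name, and pretrained weights file are assumptions for illustration:

    from imageai.Detection.Custom import DetectionModelTrainer

    trainer = DetectionModelTrainer()
    trainer.setModelTypeAsYOLOv3()
    trainer.setDataDirectory(data_directory="hololens")   # also creates hololens/logs
    trainer.setGpuUsage("0,1")                             # train on GPUs 0 and 1
    trainer.setTrainConfig(object_names_array=["hololens"], batch_size=4,
                           num_experiments=100,
                           train_from_pretrained_model="pretrained-yolov3.h5")
    trainer.trainModel()
    # While training runs, view the logged losses with:
    #   tensorboard --logdir hololens/logs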

def _create_model(
self,
@@ -541,7 +551,7 @@ def _create_model(
xywh_scale,
class_scale
):
if multi_gpu > 1:
if len(multi_gpu) > 1:
with tf.device('/cpu:0'):
template_model, infer_model = create_yolov3_model(
nb_class=nb_class,
@@ -583,9 +593,7 @@ def _create_model(
print("Pre-trained Model not provided. Transfer learning not in use.")
print("Training will start with 3 warmup experiments")



if multi_gpu > 1:
if len(multi_gpu) > 1:
train_model = multi_gpu_model(template_model, gpus=multi_gpu)
else:
train_model = template_model
@@ -1383,9 +1391,3 @@ def draw_boxes_and_caption(self, image_frame, v_boxes, v_labels, v_scores, show_
cv2.putText(image_frame, label, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 2)

return image_frame






12 changes: 5 additions & 7 deletions imageai/Detection/Custom/utils/multi_gpu_model.py
@@ -2,6 +2,7 @@
from keras.models import Model
import tensorflow as tf


def multi_gpu_model(model, gpus):
if isinstance(gpus, (list, tuple)):
num_gpus = len(gpus)
@@ -37,10 +38,8 @@ def get_slice(data, i, parts):
# Retrieve a slice of the input.
for x in model.inputs:
input_shape = tuple(x.get_shape().as_list())[1:]
slice_i = Lambda(get_slice,
output_shape=input_shape,
arguments={'i': i,
'parts': num_gpus})(x)
slice_i = Lambda(get_slice, output_shape=input_shape,
arguments={'i': i, 'parts': num_gpus})(x)
inputs.append(slice_i)

# Apply model on slice
@@ -57,6 +56,5 @@ def get_slice(data, i, parts):
with tf.device('/cpu:0'):
merged = []
for name, outputs in zip(model.output_names, all_outputs):
merged.append(concatenate(outputs,
axis=0, name=name))
return Model(model.inputs, merged)
merged.append(concatenate(outputs, axis=0, name=name))
return Model(model.inputs, merged)
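
A minimal sketch of how the updated wrapper is driven after this change; a toy Keras model stands in for the YOLO template model, and only the list-of-GPU-ids semantics of the gpus argument come from the patch:

    from keras.layers import Dense, Input
    from keras.models import Model
    from imageai.Detection.Custom.utils.multi_gpu_model import multi_gpu_model

    # Toy single-input, single-output model standing in for the YOLO template model.
    inputs = Input(shape=(8,))
    outputs = Dense(1)(inputs)
    template_model = Model(inputs, outputs)

    gpus = [0, 1]              # same list trainModel() now builds from the GPU id string
    if len(gpus) > 1:          # mirrors the len(multi_gpu) > 1 check in _create_model()
        train_model = multi_gpu_model(template_model, gpus=gpus)
    else:
        train_model = template_model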
