Skip to content

Commit

Permalink
Implement Kitti generator (fizyr#288)
Browse files Browse the repository at this point in the history
* Code styling

* Code styling

* Rebased on origin master

* Fix pep8 violations

* Added missing kitti_parser in the train script (it dissapered after rebasing)
  • Loading branch information
lvaleriu authored and hgaiser committed Mar 15, 2018
1 parent 08421ca commit be03c24
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 1 deletion.
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,21 @@ retinanet-train oid /path/to/OID
keras_retinanet/bin/train.py oid /path/to/OID --labels_filter=Helmet,Tree
```

For training on a custom dataset, a CSV file can be used as a way to pass the data.

For training on [KITTI](http://www.cvlibs.net/datasets/kitti/eval_object.php), run:
```shell
# Running directly from the repository:
keras_retinanet/bin/train.py kitti /path/to/KITTI

# Using the installed script:
retinanet-train kitti /path/to/KITTI

If you want to prepare the dataset you can use the following script:
https://github.com/NVIDIA/DIGITS/blob/master/examples/object-detection/prepare_kitti_data.py
```


For training on a [custom dataset], a CSV file can be used as a way to pass the data.
See below for more details on the format of these CSV files.
To train using your CSV, run:
```shell
Expand Down
11 changes: 11 additions & 0 deletions keras_retinanet/bin/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
# Change these to absolute imports if you copy this script outside the keras_retinanet package.
from ..preprocessing.pascal_voc import PascalVocGenerator
from ..preprocessing.csv_generator import CSVGenerator
from ..preprocessing.kitti import KittiGenerator
from ..preprocessing.open_images import OpenImagesGenerator
from ..utils.transform import random_transform_generator
from ..utils.visualization import draw_annotations, draw_boxes
Expand Down Expand Up @@ -82,6 +83,12 @@ def create_generator(args):
annotation_cache_dir=args.annotation_cache_dir,
transform_generator=transform_generator
)
elif args.dataset_type == 'kitti':
generator = KittiGenerator(
args.kitti_path,
subset=args.subset,
transform_generator=transform_generator
)
else:
raise ValueError('Invalid data type received: {}'.format(args.dataset_type))

Expand All @@ -101,6 +108,10 @@ def parse_args(args):
pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).')
pascal_parser.add_argument('--pascal-set', help='Name of the set to show (defaults to test).', default='test')

kitti_parser = subparsers.add_parser('kitti')
kitti_parser.add_argument('kitti_path', help='Path to dataset directory (ie. /tmp/kitti).')
kitti_parser.add_argument('subset', help='Argument for loading a subset from train/val.')

def csv_list(string):
return string.split(',')

Expand Down
20 changes: 20 additions & 0 deletions keras_retinanet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ..callbacks.eval import Evaluate
from ..preprocessing.pascal_voc import PascalVocGenerator
from ..preprocessing.csv_generator import CSVGenerator
from ..preprocessing.kitti import KittiGenerator
from ..preprocessing.open_images import OpenImagesGenerator
from ..utils.transform import random_transform_generator
from ..utils.keras_version import check_keras_version
Expand Down Expand Up @@ -220,6 +221,22 @@ def create_generators(args):
)
else:
validation_generator = None
elif args.dataset_type == 'kitti':
train_generator = KittiGenerator(
args.kitti_path,
subset='train',
transform_generator=transform_generator,
batch_size=args.batch_size
)

if args.val_annotations:
validation_generator = KittiGenerator(
args.kitti_path,
subset='val',
batch_size=args.batch_size
)
else:
validation_generator = None
else:
raise ValueError('Invalid data type received: {}'.format(args.dataset_type))

Expand Down Expand Up @@ -272,6 +289,9 @@ def parse_args(args):
pascal_parser = subparsers.add_parser('pascal')
pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).')

kitti_parser = subparsers.add_parser('kitti')
kitti_parser.add_argument('kitti_path', help='Path to dataset directory (ie. /tmp/kitti).')

def csv_list(string):
return string.split(',')

Expand Down
126 changes: 126 additions & 0 deletions keras_retinanet/preprocessing/kitti.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""
Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import csv
import os.path

import numpy as np
from PIL import Image

from .generator import Generator
from ..utils.image import read_image_bgr

kitti_classes = {
'Car': 0,
'Van': 1,
'Truck': 2,
'Pedestrian': 3,
'Person_sitting': 4,
'Cyclist': 5,
'Tram': 6,
'Misc': 7,
'DontCare': 7
}


class KittiGenerator(Generator):
def __init__(
self,
base_dir,
subset='train',
**kwargs
):
self.base_dir = base_dir

label_dir = os.path.join(self.base_dir, subset, 'labels')
image_dir = os.path.join(self.base_dir, subset, 'images')

"""
1 type Describes the type of object: 'Car', 'Van', 'Truck',
'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram',
'Misc' or 'DontCare'
1 truncated Float from 0 (non-truncated) to 1 (truncated), where
truncated refers to the object leaving image boundaries
1 occluded Integer (0,1,2,3) indicating occlusion state:
0 = fully visible, 1 = partly occluded
2 = largely occluded, 3 = unknown
1 alpha Observation angle of object, ranging [-pi..pi]
4 bbox 2D bounding box of object in the image (0-based index):
contains left, top, right, bottom pixel coordinates
3 dimensions 3D object dimensions: height, width, length (in meters)
3 location 3D object location x,y,z in camera coordinates (in meters)
1 rotation_y Rotation ry around Y-axis in camera coordinates [-pi..pi]
"""

self.id_to_labels = {}
for label, id in kitti_classes.items():
self.id_to_labels[id] = label

self.image_data = dict()
self.images = []
for i, fn in enumerate(os.listdir(label_dir)):
label_fp = os.path.join(label_dir, fn)
image_fp = os.path.join(image_dir, fn.replace('.txt', '.png'))

self.images.append(image_fp)

fieldnames = ['type', 'truncated', 'occluded', 'alpha', 'left', 'top', 'right', 'bottom', 'dh', 'dw', 'dl',
'lx', 'ly', 'lz', 'ry']
with open(label_fp, 'r') as csv_file:
reader = csv.DictReader(csv_file, delimiter=' ', fieldnames=fieldnames)
boxes = []
for line, row in enumerate(reader):
label = row['type']
cls_id = kitti_classes[label]

annotation = {'cls_id': cls_id, 'x1': row['left'], 'x2': row['right'], 'y2': row['bottom'], 'y1': row['top']}
boxes.append(annotation)

self.image_data[i] = boxes

super(KittiGenerator, self).__init__(**kwargs)

def size(self):
return len(self.images)

def num_classes(self):
return max(kitti_classes.values()) + 1

def name_to_label(self, name):
raise NotImplementedError()

def label_to_name(self, label):
return self.id_to_labels[label]

def image_aspect_ratio(self, image_index):
# PIL is fast for metadata
image = Image.open(self.images[image_index])
return float(image.width) / float(image.height)

def load_image(self, image_index):
return read_image_bgr(self.images[image_index])

def load_annotations(self, image_index):
annotations = self.image_data[image_index]

boxes = np.zeros((len(annotations), 5))
for idx, ann in enumerate(annotations):
boxes[idx, 0] = float(ann['x1'])
boxes[idx, 1] = float(ann['y1'])
boxes[idx, 2] = float(ann['x2'])
boxes[idx, 3] = float(ann['y2'])
boxes[idx, 4] = int(ann['cls_id'])
return boxes

0 comments on commit be03c24

Please sign in to comment.