forked from qqwweee/keras-yolo3
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add coco_annotation.py to convert 'instances_train2017.json' (mscoco2017 dataset) to 'train.txt' for training.
* Add kmeans for train.txt to calculate anchors on a custom dataset; close qqwweee#68, close qqwweee#71.
- Loading branch information
Showing 2 changed files with 153 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import json | ||
from collections import defaultdict | ||
|
||
# Convert the COCO 2017 instance annotations into 'train.txt', one line per
# image: "<image path> x_min,y_min,x_max,y_max,class x_min,y_min,...".

# The category ids used below are not contiguous. Each (first, last, offset)
# row shifts one contiguous id range so that the listed ranges together map
# onto the dense class indices 0..79.
_CATEGORY_ID_RANGES = (
    (1, 11, 1),
    (13, 25, 2),
    (27, 28, 3),
    (31, 44, 5),
    (46, 65, 6),
    (67, 67, 7),
    (70, 70, 9),
    (72, 82, 10),
    (84, 90, 11),
)


def _category_to_class_index(cat):
    """Map a raw annotation category id to its dense class index.

    Ids that fall outside every range in ``_CATEGORY_ID_RANGES`` are
    returned unchanged.
    """
    for first, last, offset in _CATEGORY_ID_RANGES:
        if first <= cat <= last:
            return cat - offset
    return cat


# image path -> list of [bbox, class index] pairs
name_box_id = defaultdict(list)

# The context manager closes the annotation file as soon as it is parsed.
with open(
        "mscoco2017/annotations/instances_train2017.json",
        encoding='utf-8') as f:
    data = json.load(f)

annotations = data['annotations']
for ant in annotations:
    image_id = ant['image_id']
    name = 'mscoco2017/train2017/%012d.jpg' % image_id
    cat = _category_to_class_index(ant['category_id'])
    name_box_id[name].append([ant['bbox'], cat])

with open('train.txt', 'w') as f:
    for key, box_infos in name_box_id.items():
        f.write(key)
        for bbox, cat in box_infos:
            # bbox is [x, y, width, height]; the output wants corner
            # coordinates, each component truncated to int first.
            x_min = int(bbox[0])
            y_min = int(bbox[1])
            x_max = x_min + int(bbox[2])
            y_max = y_min + int(bbox[3])

            box_info = " %d,%d,%d,%d,%d" % (
                x_min, y_min, x_max, y_max, int(cat))
            f.write(box_info)
        f.write('\n')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import numpy as np | ||
|
||
|
||
class YOLO_Kmeans:
    """Estimate anchor box sizes by k-means clustering of box dimensions.

    Box (width, height) pairs are read from an annotation file whose lines
    look like ``<image path> x_min,y_min,x_max,y_max,class ...`` and are
    clustered using ``1 - IoU`` as the distance.
    """

    def __init__(self, cluster_number, filename):
        """Store the number of clusters and the annotation file path.

        Args:
            cluster_number: number of anchors (clusters) to compute.
            filename: path of the annotation file read by ``txt2boxes``.
        """
        self.cluster_number = cluster_number
        self.filename = filename

    def iou(self, boxes, clusters):  # 1 box -> k clusters
        """Return the IoU of every box against every cluster.

        Boxes and clusters are compared by size only: the intersection is
        ``min(w) * min(h)``.

        Args:
            boxes: array of shape (n, 2) holding (width, height) rows.
            clusters: array of shape (k, 2) holding (width, height) rows.

        Returns:
            Array of shape (n, k); entry [i, j] is the IoU of box i and
            cluster j.
        """
        # Column vectors for boxes and row vectors for clusters broadcast
        # to (n, k); k comes from ``clusters`` itself, not from a global.
        box_w = boxes[:, 0][:, None]
        box_h = boxes[:, 1][:, None]
        cluster_w = clusters[:, 0][None, :]
        cluster_h = clusters[:, 1][None, :]

        inter_area = np.minimum(cluster_w, box_w) * np.minimum(cluster_h, box_h)
        box_area = box_w * box_h
        cluster_area = cluster_w * cluster_h

        return inter_area / (box_area + cluster_area - inter_area)

    def avg_iou(self, boxes, clusters):
        """Return the mean, over all boxes, of each box's best cluster IoU."""
        return np.mean(np.max(self.iou(boxes, clusters), axis=1))

    def kmeans(self, boxes, k, dist=np.median):
        """Cluster box sizes into ``k`` centres using ``1 - IoU`` distance.

        Args:
            boxes: array of shape (n, 2) holding (width, height) rows.
            k: number of clusters; must not exceed the number of boxes.
            dist: reduction used to recompute a centre from its member
                boxes; called as ``dist(members, axis=0)``.

        Returns:
            Array of shape (k, 2) with the final cluster centres.
        """
        box_number = boxes.shape[0]
        # -1 is never a valid cluster index, so the first pass always runs
        # an update instead of stopping when every box lands in cluster 0.
        last_nearest = np.full((box_number,), -1)
        np.random.seed()
        clusters = boxes[np.random.choice(
            box_number, k, replace=False)]  # init k clusters
        while True:
            distances = 1 - self.iou(boxes, clusters)

            current_nearest = np.argmin(distances, axis=1)
            if (last_nearest == current_nearest).all():
                break  # clusters won't change
            for cluster in range(k):
                members = boxes[current_nearest == cluster]
                # An empty cluster keeps its previous centre; reducing an
                # empty array would produce NaN.
                if len(members):
                    clusters[cluster] = dist(members, axis=0)  # update clusters

            last_nearest = current_nearest

        return clusters

    def result2txt(self, data):
        """Write the anchors to 'yolo_anchors.txt' as ``w,h, w,h, ...``."""
        anchors = ", ".join("%d,%d" % (row[0], row[1]) for row in data)
        with open("yolo_anchors.txt", 'w') as f:
            f.write(anchors)

    def txt2boxes(self):
        """Read ``self.filename`` and return all box sizes.

        Returns:
            Array with one (width, height) row per box found in the file.
        """
        data_set = []
        with open(self.filename, 'r') as f:
            for line in f:
                # The first token is the image path; every later token is
                # "x_min,y_min,x_max,y_max,class". split() with no argument
                # ignores the trailing newline and repeated spaces.
                for box in line.split()[1:]:
                    fields = box.split(",")
                    width = int(fields[2]) - int(fields[0])
                    height = int(fields[3]) - int(fields[1])
                    data_set.append([width, height])
        return np.array(data_set)

    def txt2clusters(self):
        """Cluster the boxes in ``self.filename`` and report the anchors.

        Writes the anchors with ``result2txt`` and prints them together
        with the average best IoU.
        """
        all_boxes = self.txt2boxes()
        result = self.kmeans(all_boxes, k=self.cluster_number)
        # Order the anchors by ascending width.
        result = result[np.lexsort(result.T[0, None])]
        self.result2txt(result)
        print("K anchors:\n {}".format(result))
        print("Accuracy: {:.2f}%".format(
            self.avg_iou(all_boxes, result) * 100))
||
|
||
if __name__ == "__main__":
    # Script entry point: cluster the box sizes listed in the annotation
    # file. The module-level names are kept exactly as they were.
    cluster_number, filename = 9, "2012_train.txt"
    kmeans = YOLO_Kmeans(cluster_number, filename)
    kmeans.txt2clusters()