diff --git a/tools/check_gt_cord.py b/tools/check_gt_cord.py index 8cd60b4..7b8aa51 100644 --- a/tools/check_gt_cord.py +++ b/tools/check_gt_cord.py @@ -6,6 +6,7 @@ import sys import cv2 import glob +import shutil import pandas as pd import numpy as np from tqdm import tqdm @@ -30,6 +31,10 @@ def parse_xml(file): box_all = [] pts = ['xmin', 'ymin', 'xmax', 'ymax'] + # size + location = xml.find('location') + width = int(location.find('xmax').text) - int(location.find('xmin').text) + # bounding boxes for obj in xml.iter('object'): bbox = obj.find('bndbox') @@ -39,7 +44,7 @@ def parse_xml(file): cur_pt = int(bbox.find(pt).text) - 1 bndbox.append(cur_pt) box_all += [bndbox] - return box_all + return box_all, width def get_box_label(label_df, im_name): boxes = [] @@ -58,28 +63,56 @@ def get_box_label(label_df, im_name): def _boxvis(img, gt_box_list): img1 = img.copy() for box in gt_box_list: - cv2.rectangle(img1, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 1) + cv2.rectangle(img1, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2) plt.subplot(1, 1, 1); plt.imshow(img1[:, :, [2,1,0]]) plt.show() - cv2.waitKey() - + cv2.waitKey(0) -def main(): - # df = pd.read_csv(src_annotation) +def _originvis(name, bbox): + img = cv2.imread(os.path.join(src_traindir, name)) + for box in bbox: + cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 1) + plt.subplot(1, 1, 1); plt.imshow(img[:, :, [2,1,0]]) + plt.show() + cv2.waitKey(0) - # for name in df.filename: - # img = cv2.imread(os.path.join(src_traindir, name)) - # box, _ = get_box_label(df, name) - # _boxvis(img, box) + +if __name__ == '__main__': + vis = False + + df = pd.read_csv(src_annotation) with open(os.path.join(list_dir, 'train.txt'), 'r') as f: img_list = [x.strip() for x in f.readlines()] - for name in img_list: - img = cv2.imread(os.path.join(image_dir, name+'.jpg')) - box = parse_xml(os.path.join(anno_dir, name+'.xml')) - _boxvis(img, box) + filter_list = [] + for i, name in enumerate(img_list): + sys.stdout.write('\rsearch: {:d}/{:d} {:s}' + .format(i + 1, len(img_list), img_list[i])) + sys.stdout.flush() -if __name__ == '__main__': - main() + orgin_name = name.split('_')[0]+'.jpg' + orgin_box, _ = get_box_label(df, orgin_name) + + label_w = orgin_box[0][2] - orgin_box[0][0] + if label_w > 140 and label_w < 145: + if vis: + img = cv2.imread(os.path.join(image_dir, name+'.jpg')) + box, width = parse_xml(os.path.join(anno_dir, name+'.xml')) + _boxvis(img, box) + _originvis(orgin_name, orgin_box) + else: + filter_list.append(orgin_name) + + if not vis: + filter_list = list(set(filter_list)) + if os.path.exists('temp'): + shutil.rmtree('temp') + os.mkdir('temp') + for name in filter_list: + img = cv2.imread(os.path.join(src_traindir, name)) + orgin_box, _ = get_box_label(df, name) + for box in orgin_box: + cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2) + cv2.imwrite('temp/'+name, img)