# main_select_classification.py
# Forked from MulongXie/UIED.
import multiprocessing
import glob
import time
import json
from tqdm import tqdm
from os.path import join as pjoin, exists
import ip_region_proposal as ip
from CONFIG import Config
def image_stem(path):
    """Return the file name in *path* without its directory or extension.

    Both the forward slash and the backslash are accepted as directory
    separators, so the result does not depend on which separator was used to
    build the path.  The extension is removed with ``splitext`` rather than by
    chopping a fixed number of characters, so extensions of any length work.

    Args:
        path: Path (or bare file name) of an image, e.g. ``'in/65010.jpg'``.

    Returns:
        The file name stem as a string, e.g. ``'65010'``.
    """
    # Local import: the file-level import only brings in ``join`` and ``exists``.
    from os.path import splitext

    file_name = path.replace('\\', '/').rsplit('/', 1)[-1]
    return splitext(file_name)[0]


if __name__ == '__main__':
    # initialization
    C = Config()
    resize_by_height = 800
    input_root = C.ROOT_INPUT
    output_root = C.ROOT_OUTPUT

    # Read the annotation file and build one input path per listed image.
    # The context manager closes the file handle as soon as it is parsed.
    with open('E:\\Mulong\\Datasets\\rico\\instances_test.json', 'r') as annotation_file:
        data = json.load(annotation_file)
    # Only the last '/'-separated component of 'file_name' is kept, then
    # re-rooted under the configured input directory.
    input_paths_img = [pjoin(input_root, img['file_name'].split('/')[-1]) for img in data['images']]
    # Sort numerically by the index encoded in the file name; the early
    # ``break`` in the loop below relies on this ordering.
    input_paths_img = sorted(input_paths_img, key=lambda x: int(image_stem(x)))

    # switches for the individual pipeline stages
    is_ip = True
    is_ocr = False
    is_merge = True

    # switch of the classification func (only consulted when is_ip is set)
    classifier = None
    if is_ip:
        is_clf = True
        if is_clf:
            from CNN import CNN
            classifier = {'Image': CNN('Image'), 'Elements': CNN('Elements')}

    # Import the optional stages once, and only when they are enabled, so a
    # missing module is reported before any image has been processed.
    if is_ocr:
        import ocr_east as ocr
    if is_merge:
        import merge

    # set the range of target inputs' indices
    # ``num`` is passed to ip.compo_detection and incremented once per
    # processed image; its meaning is defined by that function.
    num = 4202
    start_index = 65010  # 61728
    end_index = 100000

    for input_path_img in input_paths_img:
        # Keep the stem as a string for building the JSON file names below;
        # convert to int once for the range checks.
        index = image_stem(input_path_img)
        index_value = int(index)
        if index_value < start_index:
            continue
        if index_value > end_index:
            # Paths are sorted by index, so nothing further can be in range.
            break

        if is_ocr:
            ocr.east(input_path_img, output_root, resize_by_height=None, show=False, write_img=True)

        if is_ip:
            ip.compo_detection(input_path_img, output_root, num, resize_by_height=resize_by_height, show=False, classifier=classifier)

        if is_merge:
            compo_path = pjoin(output_root, 'ip', str(index) + '.json')
            ocr_path = pjoin(output_root, 'ocr', str(index) + '.json')
            merge.incorporate(input_path_img, compo_path, ocr_path, output_root, resize_by_height=resize_by_height, show=False, write_img=True)

        num += 1