Skip to content

Commit

Permalink
det.py
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Jul 7, 2020
1 parent cc6c395 commit dd24de4
Showing 1 changed file with 20 additions and 23 deletions.
43 changes: 20 additions & 23 deletions dataset/det.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def load_data(self, json_path):
illegibility_list = []
language_list = []
for annotation in gt['annotations']:
if len(annotation['polygon']) == 0 or len(annotation['text']) == 0:
if len(annotation['polygon']) == 0:
continue
polygons.append(annotation['polygon'])
texts.append(annotation['text'])
Expand All @@ -52,19 +52,16 @@ def load_data(self, json_path):
return d

def __getitem__(self, item):
try:
item_dict = self.data_list[item]
item_dict['img'] = Image.open(item_dict['img_path']).convert('RGB')
item_dict['img'] = self.pre_processing(item_dict)
item_dict['texts'] = self.make_label(item_dict)
# 进行标签制作
if self.transform:
item_dict['img'] = self.transform(item_dict['img'])
if self.target_transform:
item_dict['texts'] = self.target_transform(item_dict['texts'])
return item_dict
except:
return self.__getitem__(np.random.randint(self.__len__()))
item_dict = self.data_list[item]
item_dict['img'] = Image.open(item_dict['img_path']).convert('RGB')
item_dict['img'] = self.pre_processing(item_dict)
item_dict['texts'] = self.make_label(item_dict)
# 进行标签制作
if self.transform:
item_dict['img'] = self.transform(item_dict['img'])
if self.target_transform:
item_dict['texts'] = self.target_transform(item_dict['texts'])
return item_dict

def __len__(self):
return len(self.data_list)
Expand All @@ -86,20 +83,20 @@ def pre_processing(self, item_dict):
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号

json_path = r'E:\\zj\\dataset\\icdar2015 (2)\\detection\\test.json'
json_path = r'D:\dataset\自然场景文字检测挑战赛初赛数据\训练集\\train.json'

dataset = DetDataSet(json_path, transform=transforms.ToTensor())
train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=6)
train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
pbar = tqdm(total=len(train_loader))
tic = time.time()
for i, data in enumerate(train_loader):
pass
# img = data['img'][0].numpy().transpose(1, 2, 0) * 255
# texts = [x[0] for x in data['texts']]
img = data['img'][0].numpy().transpose(1, 2, 0) * 255
texts = [x[0] for x in data['texts']]

# img = show_bbox_on_image(Image.fromarray(img.astype(np.uint8)), data['polygons'][0], label)
# plt.imshow(img)
# plt.show()
# pbar.update(1)
# pbar.close()
img = show_bbox_on_image(Image.fromarray(img.astype(np.uint8)), data['polygons'][0],texts)
plt.imshow(img)
plt.show()
pbar.update(1)
pbar.close()
print(len(train_loader)/(time.time()-tic))

0 comments on commit dd24de4

Please sign in to comment.