Add a tool to convert detection datasets to LMDB
WenmuZhou committed Apr 4, 2020
1 parent f1001e7 commit 426c4fe
Showing 3 changed files with 227 additions and 15 deletions.
111 changes: 111 additions & 0 deletions dataset/convert_det2lmdb.py
@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
# @Time : 2020/4/2 14:19
# @Author : zhoujun

import os
import argparse
import shutil

import cv2
import lmdb
import numpy as np

from convert.utils import load_gt


def checkImageIsValid(imageBin):
    if imageBin is None:
        return False

    try:
        # np.frombuffer replaces the deprecated np.fromstring
        imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
        imgH, imgW = img.shape[0], img.shape[1]
    except Exception:
        return False
    else:
        if imgH * imgW == 0:
            return False

    return True


def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            if isinstance(k, str):
                k = k.encode()
            if isinstance(v, str):
                v = v.encode()
            txn.put(k, v)


def createDataset(outputPath, data_dict, map_size=79951162, checkValid=True):
    """
    Create an LMDB dataset for detection training.
    ARGS:
        outputPath : LMDB output path
        data_dict  : a dict mapping img_path to its texts and text polygons
        checkValid : if True, check the validity of every image
    """
    # If the lmdb directory already exists, remove it; otherwise the new
    # data would be appended to the old database.
    if os.path.exists(outputPath):
        shutil.rmtree(outputPath)
    os.makedirs(outputPath)

    nSamples = len(data_dict)
    env = lmdb.open(outputPath, map_size=map_size)
    cache = {}
    cnt = 1
    for img_path, data in data_dict.items():
        if not os.path.exists(img_path):
            print('%s does not exist' % img_path)
            continue
        with open(img_path, 'rb') as f:
            imageBin = f.read()
        if checkValid and not checkImageIsValid(imageBin):
            print('%s is not a valid image' % img_path)
            continue

        imageKey = 'image-%09d' % cnt
        polygonsKey = 'polygons-%09d' % cnt
        textsKey = 'texts-%09d' % cnt
        illegibilityKey = 'illegibility-%09d' % cnt
        languageKey = 'language-%09d' % cnt
        cache[imageKey] = imageBin
        # tobytes() replaces the deprecated tostring()
        cache[polygonsKey] = np.array(data['polygons']).tobytes()
        cache[textsKey] = '\t'.join(data['texts'])
        cache[illegibilityKey] = '\t'.join(str(x) for x in data['illegibility_list'])
        cache[languageKey] = '\t'.join(data['language_list'])
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'] = str(nSamples)
    writeCache(env, cache)
    env.close()
    print('Created dataset with %d samples' % nSamples)


def show_demo(demo_number, image_path_list, label_list):
    print('\nShow a few samples to check the lmdb data before training')
    print('The first line is the path to the image and the second line is its label')
    for i in range(demo_number):
        print('image: %s\nlabel: %s\n' % (image_path_list[i], label_list[i]))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--json_path', type=str, default=r'E:\zj\dataset\icdar2015 (2)\detection\test.json',
                        help='path to gt json')
    parser.add_argument('--save_folder', type=str, default=r'E:\zj\dataset\icdar2015 (2)',
                        help='path to save lmdb')
    args = parser.parse_args()

    data_dict = load_gt(args.json_path)
    out_lmdb = os.path.join(args.save_folder, 'train')
    createDataset(out_lmdb, data_dict, map_size=79951162)
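
A minimal sketch of driving createDataset directly, for reference. The image path, output path, and dict contents below are placeholders, not part of this commit; in real use load_gt(json_path) builds this structure from the annotation json. Note that map_size is the maximum database size in bytes, so the roughly 80 MB default above can overflow on larger datasets.

# Hypothetical hand-built data_dict with a single sample, mirroring the
# keys createDataset reads: polygons, texts, illegibility_list, language_list.
sample_dict = {
    'path/to/img_1.jpg': {  # placeholder; must exist on disk to be written
        'polygons': [[[0, 0], [100, 0], [100, 30], [0, 30]]],  # one 4-point box
        'texts': ['hello'],
        'illegibility_list': [False],
        'language_list': ['Latin'],
    }
}
# Oversizing map_size is cheap on most platforms, since lmdb reserves
# address space rather than allocating the whole file up front.
createDataset('path/to/lmdb_out', sample_dict, map_size=1 << 30)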
37 changes: 22 additions & 15 deletions dataset/det.py
@@ -2,13 +2,16 @@
 # @Time : 2020/3/24 11:36
 # @Author : zhoujun
 import os
+import sys
+
+project = 'OCR_DataSet'  # project root directory
+sys.path.append(os.getcwd().split(project)[0] + project)
 import numpy as np
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader

 from convert.utils import load, show_bbox_on_image


 class DetDataSet(Dataset):
     def __init__(self, json_path, transform=None, target_transform=None):
         self.data_list = self.load_data(json_path)
@@ -43,22 +46,22 @@ def load_data(self, json_path):
                     texts.append(char_annotation['char'])
                     illegibility_list.append(char_annotation['illegibility'])
                     language_list.append(char_annotation['language'])
-            d.append({'img_path': img_path, 'polygons': polygons, 'texts': texts,
-                      'illegibility_list': illegibility_list,
-                      'language_list': language_list})
+            d.append({'img_path': img_path, 'polygons': np.array(polygons), 'texts': texts,
+                      'illegibility': illegibility_list,
+                      'language': language_list})
         return d

     def __getitem__(self, item):
         try:
             item_dict = self.data_list[item]
             item_dict['img'] = Image.open(item_dict['img_path']).convert('RGB')
             item_dict['img'] = self.pre_processing(item_dict)
-            item_dict['label'] = self.make_label(item_dict)
+            item_dict['texts'] = self.make_label(item_dict)
             # build the label here
             if self.transform:
                 item_dict['img'] = self.transform(item_dict['img'])
             if self.target_transform:
-                item_dict['label'] = self.target_transform(item_dict['label'])
+                item_dict['texts'] = self.target_transform(item_dict['texts'])
             return item_dict
         except:
             return self.__getitem__(np.random.randint(self.__len__()))
@@ -74,6 +77,7 @@ def pre_processing(self, item_dict):


 if __name__ == '__main__':
+    import time
     from tqdm import tqdm
     from torchvision import transforms
     from matplotlib import pyplot as plt
@@ -82,17 +86,20 @@ def pre_processing(self, item_dict):
     plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
     plt.rcParams['axes.unicode_minus'] = False  # display minus signs correctly

-    json_path = r'D:\dataset\icdar2017rctw\detection\train.json'
+    json_path = r'E:\zj\dataset\icdar2015 (2)\detection\test.json'

     dataset = DetDataSet(json_path, transform=transforms.ToTensor())
-    train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
+    train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=6)
     pbar = tqdm(total=len(train_loader))
+    tic = time.time()
     for i, data in enumerate(train_loader):
-        img = data['img'][0].numpy().transpose(1, 2, 0) * 255
-        label = [x[0] for x in data['label']]
+        pass
+        # img = data['img'][0].numpy().transpose(1, 2, 0) * 255
+        # texts = [x[0] for x in data['texts']]

-        img = show_bbox_on_image(Image.fromarray(img.astype(np.uint8)), data['polygons'], label)
-        plt.imshow(img)
-        plt.show()
-        pbar.update(1)
-        pbar.close()
+        # img = show_bbox_on_image(Image.fromarray(img.astype(np.uint8)), data['polygons'][0], label)
+        # plt.imshow(img)
+        # plt.show()
+        # pbar.update(1)
+        # pbar.close()
+    print(len(train_loader) / (time.time() - tic))
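
One design point in the benchmark loop above: batch_size stays at 1 because each sample carries a variable number of polygons and texts, which the default DataLoader collate function cannot stack. Below is a sketch of a custom collate_fn that would permit larger batches by keeping the variable-length fields as plain lists; det_collate_fn is an illustrative name, not part of this commit, and it assumes pre_processing resizes all images to a common shape.

import torch

def det_collate_fn(batch):
    # Stack the fixed-shape image tensors; keep ragged fields as lists.
    collated = {'img': torch.stack([item['img'] for item in batch], dim=0)}
    for key in ('polygons', 'texts', 'illegibility', 'language'):
        collated[key] = [item[key] for item in batch]
    return collated

# train_loader = DataLoader(dataset=dataset, batch_size=8, shuffle=True,
#                           num_workers=6, collate_fn=det_collate_fn)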
94 changes: 94 additions & 0 deletions dataset/det_lmdb.py
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# @Time : 2020/4/2 18:41
# @Author : zhoujun
import sys

import lmdb
import six
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader


class lmdbDataset(Dataset):
    def __init__(self, lmdb_path=None, transform=None, target_transform=None):
        self.env = lmdb.open(lmdb_path, max_readers=12, readonly=True, lock=False, readahead=False, meminit=False)

        if not self.env:
            print('cannot open lmdb dataset at %s' % lmdb_path)
            sys.exit(1)

        with self.env.begin(write=False) as txn:
            nSamples = int(txn.get('num-samples'.encode('utf-8')))
            self.nSamples = nSamples

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        index += 1  # lmdb keys are 1-based
        item = {}
        with self.env.begin(write=False) as txn:
            img_key = 'image-%09d' % index
            imgbuf = txn.get(img_key.encode('utf-8'))

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('RGB')
            except IOError:
                print('Corrupted image for %d' % index)
                # fall back to the next sample, wrapping around at the end
                return self[index % self.nSamples]

            if self.transform is not None:
                img = self.transform(img)
            item['img'] = img
            polygonsKey = 'polygons-%09d' % index
            textsKey = 'texts-%09d' % index
            illegibilityKey = 'illegibility-%09d' % index
            languageKey = 'language-%09d' % index
            polygons = txn.get(polygonsKey.encode('utf-8'))
            # assumes the polygons were serialized as float64, numpy's default
            item['polygons'] = np.frombuffer(polygons).reshape(-1, 4, 2)

            item['texts'] = txn.get(textsKey.encode('utf-8')).decode().split('\t')
            illegibility = txn.get(illegibilityKey.encode('utf-8')).decode().split('\t')
            item['illegibility'] = [x.lower() == 'true' for x in illegibility]
            item['language'] = txn.get(languageKey.encode('utf-8')).decode().split('\t')

        if self.target_transform is not None:
            item['texts'] = self.target_transform(item['texts'])

        return item


if __name__ == '__main__':
    import time
    from tqdm import tqdm
    from torchvision import transforms
    from matplotlib import pyplot as plt

    # make matplotlib render Chinese labels and minus signs correctly
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    lmdb_path = r'E:\zj\dataset\icdar2015 (2)\train'

    dataset = lmdbDataset(lmdb_path, transform=transforms.ToTensor())
    train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
    pbar = tqdm(total=len(train_loader))
    tic = time.time()
    for i, data in enumerate(train_loader):
        pass
        # img = data['img'][0].numpy().transpose(1, 2, 0) * 255
        # label = [x[0] for x in data['texts']]
        #
        # img = show_bbox_on_image(Image.fromarray(img.astype(np.uint8)), data['polygons'][0], label)
        # plt.imshow(img)
        # plt.show()
        # pbar.update(1)
        # pbar.close()
    print(len(train_loader) / (time.time() - tic))
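
A caveat in the polygon round trip: convert_det2lmdb.py serializes polygons with np.array(...).tobytes(), whose dtype depends on the annotation values, while np.frombuffer above decodes with numpy's default float64. Here is a small sketch of a round trip that pins the dtype on both sides; the float32 choice is an assumption for illustration, not something the commit enforces.

import numpy as np

polys = np.array([[[0, 0], [100, 0], [100, 30], [0, 30]]], dtype=np.float32)

# writer side: pin the dtype before serializing
buf = polys.tobytes()

# reader side: decode with the same dtype, then restore the (N, 4, 2) shape
restored = np.frombuffer(buf, dtype=np.float32).reshape(-1, 4, 2)
assert np.array_equal(polys, restored)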
