Skip to content

Commit

Permalink
Merge pull request WenmuZhou#73 from WenmuZhou/dev
Browse files Browse the repository at this point in the history
添加不在词典的标注过滤
  • Loading branch information
WenmuZhou authored Jul 13, 2020
2 parents 866651e + 0b66833 commit 1827b04
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
7 changes: 5 additions & 2 deletions tools/rec_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,6 @@ def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_us

from torchocr.metrics import RecMetric
from torchocr.utils import CTCLabelConverter
with open(cfg.dataset.alphabet, 'r', encoding='utf-8') as file:
cfg.dataset.alphabet = ''.join([s.strip('\n') for s in file.readlines()])
converter = CTCLabelConverter(cfg.dataset.alphabet)
train_options = cfg.train_options
metric = RecMetric(converter)
Expand Down Expand Up @@ -316,8 +314,13 @@ def main():
loss_func = build_loss(cfg['loss'])
loss_func = loss_func.to(to_use_device)

with open(cfg.dataset.alphabet, 'r', encoding='utf-8') as file:
cfg.dataset.alphabet = ''.join([s.strip('\n') for s in file.readlines()])

# ===> data loader
cfg.dataset.train.dataset.alphabet = cfg.dataset.alphabet
train_loader = build_dataloader(cfg.dataset.train)
cfg.dataset.eval.dataset.alphabet = cfg.dataset.alphabet
eval_loader = build_dataloader(cfg.dataset.eval)

# ===> train
Expand Down
10 changes: 7 additions & 3 deletions torchocr/datasets/RecDataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ def __init__(self, config):
"""
self.augmentation = config.augmentation
self.process = RecDataProcess(config)

self.str2idx = {c: i for i, c in enumerate(config.alphabet)}
self.labels = []
with open(config.file, 'r', encoding='utf-8') as f_reader:
for m_line in f_reader.readlines():
params = m_line.strip().split('\t')
if len(params) == 2:
m_image_name, m_gt_text = params
if True in [c not in self.str2idx for c in m_gt_text]:
continue
self.labels.append((m_image_name, m_gt_text))

def _find_max_length(self):
Expand Down Expand Up @@ -79,22 +81,24 @@ def __init__(self, config):
self.process = RecDataProcess(config)
self.filtered_index_list = []
self.labels = []
self.str2idx = {c: i for i, c in enumerate(config.alphabet)}
with self.env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'.encode()))
self.nSamples = nSamples
for index in range(self.nSamples):
index += 1 # lmdb starts with 1
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
self.labels.append(label)
# todo 添加 过滤最长
# if len(label) > config.max_len:
# # print(f'The length of the label is longer than max_length: length
# # {len(label)}, {label} in dataset {self.root}')
# continue

if True in [c not in self.str2idx for c in label]:
continue
# By default, images containing characters which are not in opt.character are filtered.
# You can add [UNK] token to `opt.character` in utils.py instead of this filtering.
self.labels.append(label)
self.filtered_index_list.append(index)

def _find_max_length(self):
Expand Down

0 comments on commit 1827b04

Please sign in to comment.