Skip to content

Commit

Permalink
Multi scale (PaddlePaddle#9837)
Browse files Browse the repository at this point in the history
* update for multi scale

* update for multi scale

* update for multi scale

* rm notes
  • Loading branch information
tink2123 authored Apr 28, 2023
1 parent ded3740 commit b306681
Show file tree
Hide file tree
Showing 7 changed files with 315 additions and 20 deletions.
24 changes: 17 additions & 7 deletions configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Global:
save_epoch_step: 10
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
Expand Down Expand Up @@ -73,7 +73,8 @@ Metric:

Train:
dataset:
name: SimpleDataSet
name: MultiScaleDataSet
ds_width: false
data_dir: ./train_data/
ext_op_transform_idx: 1
label_file_list:
Expand All @@ -90,20 +91,25 @@ Train:
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
sampler:
name: MultiScaleSampler
scales: [[320, 32], [320, 48], [320, 64]]
first_bs: &bs 128
fix_bs: false
divided_factor: [8, 16] # w, h
is_training: True
loader:
shuffle: true
batch_size_per_card: 128
batch_size_per_card: *bs
drop_last: true
num_workers: 4
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
Expand All @@ -115,9 +121,13 @@ Eval:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
max_text_length: 100
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
use_space_char: True
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
eval_mode: True
- KeepKeys:
keep_keys:
- image
Expand All @@ -128,5 +138,5 @@ Eval:
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
batch_size_per_card: 1
num_workers: 4
18 changes: 14 additions & 4 deletions configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_hgnet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ Metric:

Train:
dataset:
name: SimpleDataSet
name: MultiScaleDataSet
ds_width: false
data_dir: ./train_data/
ext_op_transform_idx: 1
label_file_list:
- ./train_data/train_list.txt
transforms:
Expand All @@ -89,15 +89,21 @@ Train:
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
sampler:
name: MultiScaleSampler
scales: [[320, 32], [320, 48], [320, 64]]
# divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple
first_bs: &bs 128
fix_bs: false
divided_factor: [8, 16] # w, h
is_training: True
loader:
shuffle: true
batch_size_per_card: 128
Expand All @@ -114,9 +120,13 @@ Eval:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
max_text_length: 100
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
use_space_char: True
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
eval_mode: True
- KeepKeys:
keep_keys:
- image
Expand Down
20 changes: 13 additions & 7 deletions ppocr/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@
import paddle.distributed as dist

from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.simple_dataset import SimpleDataSet, MultiScaleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR, LMDBDataSetTableMaster
from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
from ppocr.data.multi_scale_sampler import MultiScaleSampler

__all__ = ['build_dataloader', 'transform', 'create_operators']

Expand All @@ -55,7 +56,7 @@ def build_dataloader(config, mode, device, logger, seed=None):

support_dict = [
'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet',
'LMDBDataSetSR', 'LMDBDataSetTableMaster'
'LMDBDataSetSR', 'LMDBDataSetTableMaster', 'MultiScaleDataSet'
]
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
Expand All @@ -76,11 +77,16 @@ def build_dataloader(config, mode, device, logger, seed=None):

if mode == "Train":
# Distribute data to multiple cards
batch_sampler = DistributedBatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last)
if 'sampler' in config[mode]:
config_sampler = config[mode]['sampler']
sampler_name = config_sampler.pop("name")
batch_sampler = eval(sampler_name)(dataset, **config_sampler)
else:
batch_sampler = DistributedBatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last)
else:
# Distribute data to single card
batch_sampler = BatchSampler(
Expand Down
2 changes: 1 addition & 1 deletion ppocr/data/imaug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def create_operators(op_param_list, global_config=None):
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
if global_config is not None and "max_text_length" not in param:
param.update(global_config)
op = eval(op_name)(**param)
ops.append(op)
Expand Down
5 changes: 4 additions & 1 deletion ppocr/data/imaug/rec_img_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,20 @@ class RecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
eval_mode=False,
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.eval_mode = eval_mode
self.character_dict_path = character_dict_path
self.padding = padding

def __call__(self, data):
img = data['image']
if self.infer_mode and self.character_dict_path is not None:
if self.eval_mode or (self.infer_mode and
self.character_dict_path is not None):
norm_img, valid_ratio = resize_norm_img_chinese(img,
self.image_shape)
else:
Expand Down
171 changes: 171 additions & 0 deletions ppocr/data/multi_scale_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
from paddle.io import Sampler
import paddle.distributed as dist

import numpy as np
import random
import math


class MultiScaleSampler(Sampler):
def __init__(self,
data_source,
scales,
first_bs=128,
fix_bs=True,
divided_factor=[8, 16],
is_training=True,
ratio_wh=0.8,
max_w=480.,
seed=None):
"""
multi scale samper
Args:
data_source(dataset)
scales(list): several scales for image resolution
first_bs(int): batch size for the first scale in scales
divided_factor(list[w, h]): ImageNet models down-sample images by a factor, ensure that width and height dimensions are multiples are multiple of devided_factor.
is_training(boolean): mode
"""
# min. and max. spatial dimensions
self.data_source = data_source
self.data_idx_order_list = np.array(data_source.data_idx_order_list)
self.ds_width = data_source.ds_width
self.seed = data_source.seed
if self.ds_width:
self.wh_ratio = data_source.wh_ratio
self.wh_ratio_sort = data_source.wh_ratio_sort
self.n_data_samples = len(self.data_source)
self.ratio_wh = ratio_wh
self.max_w = max_w

if isinstance(scales[0], list):
width_dims = [i[0] for i in scales]
height_dims = [i[1] for i in scales]
elif isinstance(scales[0], int):
width_dims = scales
height_dims = scales
base_im_w = width_dims[0]
base_im_h = height_dims[0]
base_batch_size = first_bs

# Get the GPU and node related information
num_replicas = dist.get_world_size()
rank = dist.get_rank()
# adjust the total samples to avoid batch dropping
num_samples_per_replica = int(
math.ceil(self.n_data_samples * 1.0 / num_replicas))

img_indices = [idx for idx in range(self.n_data_samples)]

self.shuffle = False
if is_training:
# compute the spatial dimensions and corresponding batch size
# ImageNet models down-sample images by a factor of 32.
# Ensure that width and height dimensions are multiples are multiple of 32.
width_dims = [
int((w // divided_factor[0]) * divided_factor[0])
for w in width_dims
]
height_dims = [
int((h // divided_factor[1]) * divided_factor[1])
for h in height_dims
]

img_batch_pairs = list()
base_elements = base_im_w * base_im_h * base_batch_size
for (h, w) in zip(height_dims, width_dims):
if fix_bs:
batch_size = base_batch_size
else:
batch_size = int(max(1, (base_elements / (h * w))))
img_batch_pairs.append((w, h, batch_size))
self.img_batch_pairs = img_batch_pairs
self.shuffle = True
else:
self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]

self.img_indices = img_indices
self.n_samples_per_replica = num_samples_per_replica
self.epoch = 0
self.rank = rank
self.num_replicas = num_replicas

self.batch_list = []
self.current = 0
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]
while self.current < self.n_samples_per_replica:
curr_w, curr_h, curr_bsz = random.choice(self.img_batch_pairs)

end_index = min(self.current + curr_bsz, self.n_samples_per_replica)

batch_ids = indices_rank_i[self.current:end_index]
n_batch_samples = len(batch_ids)
if n_batch_samples != curr_bsz:
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
self.current += curr_bsz

if len(batch_ids) > 0:
batch = [curr_w, curr_h, len(batch_ids)]
self.batch_list.append(batch)
self.length = len(self.batch_list)
self.batchs_in_one_epoch = self.iter()
self.batchs_in_one_epoch_id = [
i for i in range(len(self.batchs_in_one_epoch))
]

def __iter__(self):
if self.seed is None:
random.seed(self.epoch)
self.epoch += 1
else:
random.seed(self.seed)
random.shuffle(self.batchs_in_one_epoch_id)
for batch_tuple_id in self.batchs_in_one_epoch_id:
yield self.batchs_in_one_epoch[batch_tuple_id]

def iter(self):
if self.shuffle:
if self.seed is not None:
random.seed(self.seed)
else:
random.seed(self.epoch)
if not self.ds_width:
random.shuffle(self.img_indices)
random.shuffle(self.img_batch_pairs)
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]
else:
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]

start_index = 0
batchs_in_one_epoch = []
for batch_tuple in self.batch_list:
curr_w, curr_h, curr_bsz = batch_tuple
end_index = min(start_index + curr_bsz, self.n_samples_per_replica)
batch_ids = indices_rank_i[start_index:end_index]
n_batch_samples = len(batch_ids)
if n_batch_samples != curr_bsz:
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
start_index += curr_bsz

if len(batch_ids) > 0:
if self.ds_width:
wh_ratio_current = self.wh_ratio[self.wh_ratio_sort[
batch_ids]]
ratio_current = wh_ratio_current.mean()
ratio_current = ratio_current if ratio_current * curr_h < self.max_w else self.max_w / curr_h
else:
ratio_current = None
batch = [(curr_w, curr_h, b_id, ratio_current)
for b_id in batch_ids]
# yield batch
batchs_in_one_epoch.append(batch)
return batchs_in_one_epoch

def set_epoch(self, epoch: int):
self.epoch = epoch

def __len__(self):
return self.length
Loading

0 comments on commit b306681

Please sign in to comment.