
Commit

add the user handwriting generation file
dai gang committed Jan 7, 2024
1 parent 4e452b6 commit af218ed
Showing 4 changed files with 145 additions and 2 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -2,4 +2,7 @@ __pycache__
data/*
Saved/*
model_zoo/*.pth
auto_*
.vscode
Generated
style_samples
33 changes: 33 additions & 0 deletions configs/CHINESE_USER.yml
@@ -0,0 +1,33 @@
MODEL:
  ENCODER_LAYERS: 2
  WRI_DEC_LAYERS: 2
  GLY_DEC_LAYERS: 2
  NUM_HEAD_LAYERS: 1
  NUM_IMGS: 15
  NUM_GPUS: 1 # TODO, support multi GPUs
SOLVER:
  BASE_LR: 0.0002
  MAX_ITER: 200000
  WARMUP_ITERS: 20000
  TYPE: Adam # TODO, support optional optimizer
  GRAD_L2_CLIP: 5.0
TRAIN:
  ISTRAIN: True
  IMS_PER_BATCH: 64
  SNAPSHOT_BEGIN: 2000
  SNAPSHOT_ITERS: 4000
  VALIDATE_ITERS: 2000
  VALIDATE_BEGIN: 2000
  SEED: 1001
  IMG_H: 64
  IMG_W: 64
TEST:
  ISTRAIN: False
  IMG_H: 64
  IMG_W: 64
DATA_LOADER:
  NUM_THREADS: 8
  CONCAT_GRID: True
  TYPE: UserDataset
  PATH: data
  DATASET: CHINESE
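
For reference, the repo itself loads this file through parse_config.cfg_from_file (see user_generate.py below). A minimal sketch for inspecting the raw values directly — reading it with PyYAML is my assumption, not how the repo consumes it:

import yaml  # assumption: PyYAML is installed; the repo uses parse_config.cfg_from_file instead

with open('configs/CHINESE_USER.yml') as f:
    cfg = yaml.safe_load(f)

print(cfg['MODEL']['NUM_IMGS'])       # 15 style reference images
print(cfg['DATA_LOADER']['TYPE'])     # UserDataset
print(cfg['TRAIN']['IMS_PER_BATCH'])  # 64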
31 changes: 30 additions & 1 deletion data_loader/loader.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,8 @@
import lmdb
from utils.util import corrds2xys
import codecs
import glob
import cv2

transform_data = transforms.Compose([
    transforms.ToTensor(),
@@ -196,4 +198,31 @@ def collate_fn_(self, batch_data):
            output['coords_len'][i], output['len_gt'][i] = s, h
            output['character_id'][i] = batch_data[i]['character_id']
            output['writer_id'][i] = batch_data[i]['writer_id']
        return output


class UserDataset(Dataset):
    def __init__(self, root='data', dataset='CHINESE', style_path='style_samples'):
        data_path = os.path.join(root, script[dataset][0])
        self.content = pickle.load(open(os.path.join(data_path, script[dataset][1]), 'rb'))  # content samples
        self.char_dict = pickle.load(open(os.path.join(data_path, 'character_dict.pkl'), 'rb'))
        self.style_path = glob.glob(style_path + '/*.[jp][pn]g')  # user-provided .jpg/.png style images

    def __len__(self):
        return len(self.char_dict)

    def __getitem__(self, index):
        char = self.char_dict[index]  # content sample: the character to render
        char_img = self.content[char]
        char_img = char_img / 255.  # normalize pixel values to [0.0, 1.0]
        img_list = []
        for idx in range(len(self.style_path)):
            style_img = cv2.imread(self.style_path[idx], flags=0)  # read as grayscale
            style_img = cv2.resize(style_img, (64, 64))
            style_img = style_img / 255.
            img_list.append(style_img)
        img_list = np.expand_dims(np.array(img_list), 1)  # (N, 1, 64, 64)

        return {'char_img': torch.Tensor(char_img).unsqueeze(0),
                'img_list': torch.Tensor(img_list),
                'char': char}
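
A quick sanity check for the new dataset — a minimal sketch, assuming data/ is laid out as the repo expects (script[dataset] resolving to the dataset subfolder and pickles) and style_samples/ holds a few .jpg/.png handwriting crops:

from data_loader.loader import UserDataset

dataset = UserDataset(root='data', dataset='CHINESE', style_path='style_samples')
print(len(dataset))              # one entry per character in character_dict.pkl
sample = dataset[0]
print(sample['char'])            # the character to synthesize
print(sample['char_img'].shape)  # (1, H, W) content glyph
print(sample['img_list'].shape)  # (N, 1, 64, 64) style references, N = number of style images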
78 changes: 78 additions & 0 deletions user_generate.py
@@ -0,0 +1,78 @@
import argparse
import os
from parse_config import cfg, cfg_from_file, assert_and_infer_cfg
import torch
from data_loader.loader import UserDataset
import pickle
from models.model import SDT_Generator
import tqdm
from utils.util import writeCache, dxdynp_to_list, coords_render
import lmdb


def main(opt):
    """load config file into cfg"""
    cfg_from_file(opt.cfg_file)
    assert_and_infer_cfg()

    """setup data_loader instances"""
    test_dataset = UserDataset(
        cfg.DATA_LOADER.PATH, cfg.DATA_LOADER.DATASET, opt.style_path)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=cfg.TRAIN.IMS_PER_BATCH,
                                              shuffle=True,
                                              sampler=None,
                                              drop_last=False,
                                              num_workers=cfg.DATA_LOADER.NUM_THREADS)

    os.makedirs(os.path.join(opt.save_dir), exist_ok=True)

    """build model architecture"""
    model = SDT_Generator(num_encoder_layers=cfg.MODEL.ENCODER_LAYERS,
                          num_head_layers=cfg.MODEL.NUM_HEAD_LAYERS,
                          wri_dec_layers=cfg.MODEL.WRI_DEC_LAYERS,
                          gly_dec_layers=cfg.MODEL.GLY_DEC_LAYERS).to('cuda')
    if len(opt.pretrained_model) > 0:
        model_weight = torch.load(opt.pretrained_model)
        model.load_state_dict(model_weight)
        print('load pretrained model from {}'.format(opt.pretrained_model))
    else:
        raise IOError('please provide a valid checkpoint path')
    model.eval()

    """setup the dataloader"""
    batch_samples = len(test_loader)
    data_iter = iter(test_loader)
    with torch.no_grad():
        for _ in tqdm.tqdm(range(batch_samples)):
            data = next(data_iter)
            # prepare input
            img_list, char_img, char = data['img_list'].cuda(), \
                data['char_img'].cuda(), data['char']
            preds = model.inference(img_list, char_img, 120)
            bs = char_img.shape[0]
            SOS = torch.tensor(bs * [[0, 0, 1, 0, 0]]).unsqueeze(1).to(preds)
            preds = torch.cat((SOS, preds), 1)  # add the SOS token like GT
            preds = preds.detach().cpu().numpy()

            for i, pred in enumerate(preds):
                """render the character image by connecting the coordinates"""
                sk_pil = coords_render(pred, split=True, width=256, height=256, thickness=8, board=1)
                save_path = os.path.join(opt.save_dir, char[i] + '.png')
                try:
                    sk_pil.save(save_path)
                except Exception:
                    print('error. %s, %s' % (save_path, char[i]))


if __name__ == '__main__':
    """parse input arguments"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', dest='cfg_file', default='configs/CHINESE_USER.yml',
                        help='Config file for generation')
    parser.add_argument('--dir', dest='save_dir', default='Generated/Chinese_User',
                        help='target dir for storing the generated characters')
    parser.add_argument('--pretrained_model', dest='pretrained_model', default='', required=True,
                        help='path to the pretrained model checkpoint')
    parser.add_argument('--style_path', dest='style_path', default='style_samples',
                        help='dir of style samples')
    opt = parser.parse_args()
    main(opt)
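
Usage note: judging by the SOS token [0, 0, 1, 0, 0], each decoded step appears to be a 5-dim vector — a 2-D pen offset plus a 3-way pen-state one-hot — which coords_render connects into strokes; this is an inference from the code above, not confirmed against models/model.py. With a checkpoint in hand (the filename below is illustrative), generation is a one-liner:

python user_generate.py --pretrained_model model_zoo/checkpoint-iter199999.pth --style_path style_samples --dir Generated/Chinese_User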
