Commit
Yu Qian committed Sep 26, 2019
1 parent 3604fcc, commit 1c4ae30
Showing 31 changed files with 927 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@
# SRNet - Editing Text in the Wild

## Introduction
This is a TensorFlow reproduction of the paper *Editing Text in the Wild*, which aims to replace or modify a word in a source image with another one while maintaining its realistic look.

Original paper: [*Editing Text in the Wild*](https://arxiv.org/abs/1908.03047) by Liang Wu, Chengquan Zhang, Jiaming Liu, Junyu Han, Jingtuo Liu, Errui Ding and Xiang Bai.

The model in this project is the result of my own experiments and debugging, based on the details described in the paper.

This SRNet uses a pre-trained VGG19 model, downloaded from [https://github.com/fchollet/deep-learning-models/releases/tag/v0.1](https://github.com/fchollet/deep-learning-models/releases/tag/v0.1) and converted to pb format.

![image](./examples/example/example.png)

## Prepare data
The data is prepared exactly as described in the paper.

You can refer to and extend the [SynthText](https://github.com/ankush-me/SynthText) project to render styled text on background images. You also need to save some intermediate results as labels while rendering.

According to the paper, you need to prepare 2 input images (`i_s`, `i_t`) and 4 label images (`t_sk`, `t_t`, `t_b`, `t_f`):

- `i_s`: standard text b rendered on a gray background

- `i_t`: styled text a rendered on the background image

- `t_sk`: skeletonization of styled text b

- `t_t`: styled text b rendered on a gray background

- `t_b`: background image

- `t_f`: styled text b rendered on the background image

In my experiments, I found it easier to train with one more label (`mask_t`):

- `mask_t`: the binary mask of styled text b

![image](./examples/example/data.png)

From left to right and top to bottom: examples of `i_s, i_t, t_sk, t_t, t_b, t_f, mask_t`.

## Train your own dataset
First clone this project:
```bash
$ git clone https://github.com/youdao-ai/SRNet.git
```

Once the data is ready, put each kind of image in its own directory, using the same file name for all images that belong to one sample.

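
Since every sample is looked up by the same file name in each directory, it may help to check the layout before training. The following is a minimal sketch and not part of this project: the function `find_incomplete_samples` and the directory names in `LABEL_DIRS` are hypothetical placeholders (the real directory names are whatever you set in `cfg.py`).

```python
import os
import tempfile

# Hypothetical directory names for this sketch; the real ones are set in cfg.py.
LABEL_DIRS = ["i_s", "i_t", "t_sk", "t_t", "t_b", "t_f", "mask_t"]


def find_incomplete_samples(data_dir, sub_dirs=LABEL_DIRS):
    """Return {file_name: [sub_dirs where the file is missing]}.

    A sample is complete only if a file with the same name exists
    in every sub-directory; complete samples are not reported.
    """
    names_per_dir = {
        d: set(os.listdir(os.path.join(data_dir, d))) for d in sub_dirs
    }
    all_names = set().union(*names_per_dir.values())

    missing = {}
    for name in sorted(all_names):
        absent = [d for d in sub_dirs if name not in names_per_dir[d]]
        if absent:
            missing[name] = absent
    return missing
```

Calling `find_incomplete_samples('/path/to/data')` returns an empty dict when every file name appears in all directories.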

You can modify the data directories and training parameters in `cfg.py` as needed.

Then run `python3 train.py` to start training.

## Predict
You can run prediction on your own data with
```bash
$ python3 predict.py --i_s xxx --i_t xxx --save_dir xxx --checkpoint xxx
```
To predict a whole directory of data, make sure each `i_s` / `i_t` pair shares the same prefix, separated from the rest of the file name by `_` (for example, `image1_i_s.png` and `image1_i_t.png`), put all the files into one directory, and run
```bash
$ python3 predict.py --input_dir xxx --save_dir xxx --checkpoint xxx
```

Alternatively, set these paths in `cfg.py` and simply run `python3 predict.py`.
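
The naming convention for directory prediction can be illustrated with a small sketch. This is an illustration only, not the project's loader: `pair_inputs` is a hypothetical helper, and it matches on the `_i_s.png` / `_i_t.png` suffixes, which is an assumption made here for clarity.

```python
def pair_inputs(file_names):
    """Group '<prefix>_i_s.png' and '<prefix>_i_t.png' names by prefix.

    Returns {prefix: {'i_s': name, 'i_t': name}} and keeps only the
    prefixes for which both files are present.
    """
    pairs = {}
    for name in file_names:
        for kind in ("i_s", "i_t"):
            suffix = "_" + kind + ".png"
            if name.endswith(suffix):
                prefix = name[: -len(suffix)]
                pairs.setdefault(prefix, {})[kind] = name
    return {p: k for p, k in pairs.items() if len(k) == 2}
```

For example, `pair_inputs(['image1_i_s.png', 'image1_i_t.png', 'image2_i_s.png'])` keeps only the `image1` pair, because `image2` has no matching `i_t` file. Keeping the prefix itself free of `_` is the safest choice.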

## Requirements
- Python 3.6

- numpy

- opencv-python

- tensorflow 1.14.0

## Reference
- [Editing Text in the Wild](https://arxiv.org/abs/1908.03047)

- [EnsNet: Ensconce Text in the Wild](https://arxiv.org/abs/1812.00723)

- [Synthetic Data for Text Localisation in Natural Images](https://arxiv.org/abs/1604.06646)

- [A fast parallel algorithm for thinning digital patterns](http://www-prima.inrialpes.fr/perso/Tran/Draft/gateway.cfm.pdf)

@@ -0,0 +1,60 @@
"""
SRNet - Editing Text in the Wild
Some configurations.
Copyright (c) 2019 Netease Youdao Information Technology Co.,Ltd.
Licensed under the GPL License (see LICENSE for details)
Written by Yu Qian
"""

# device
gpu = 0

# pretrained vgg
vgg19_weights = 'model_logs/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.pb'

# model parameters
lt = 1.
lt_alpha = 1.
lb = 1.
lb_beta = 10.
lf = 1.
lf_theta_1 = 10.
lf_theta_2 = 1.
lf_theta_3 = 500.
epsilon = 1e-8

# train
learning_rate = 1e-4 # default 1e-3
decay_rate = 0.9
decay_steps = 10000
staircase = False
beta1 = 0.9 # default 0.9
beta2 = 0.999 # default 0.999
max_iter = 500000
show_loss_interval = 50
write_log_interval = 50
save_ckpt_interval = 10000
gen_example_interval = 1000
checkpoint_savedir = 'model_logs/checkpoints'
tensorboard_dir = 'model_logs/train_logs'
pretrained_ckpt_path = None
train_name = None # used for name examples and tensorboard logdirs, set None to use time

# data
batch_size = 8
data_shape = [64, None]
data_dir = '/reserve/qianyu/gpu100/qy/datasets/srnet_data'
i_t_dir = 'i_t_2'
i_s_dir = 'i_s_1'
t_sk_dir = 't_sk_2'
t_t_dir = 't_t_2'
t_b_dir = 't_b'
t_f_dir = 'i_s_2'
mask_t_dir = 't_sk_tmp_2'
example_data_dir = 'examples/labels'
example_result_dir = 'examples/gen_logs'

# predict
predict_ckpt_path = None
predict_data_dir = None
predict_result_dir = 'examples/result'
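
The `learning_rate`, `decay_rate`, `decay_steps` and `staircase` settings above look like the parameters of an exponential learning-rate schedule. The training script is not shown here, so that reading is an assumption; under it, the effective learning rate at a given step could be sketched as follows (`decayed_learning_rate` is a hypothetical helper, not a function of this project).

```python
def decayed_learning_rate(step, base_lr=1e-4, decay_rate=0.9,
                          decay_steps=10000, staircase=False):
    """Exponential decay: base_lr * decay_rate ** (step / decay_steps).

    With staircase=True the exponent is floored, so the rate drops
    in discrete jumps every `decay_steps` steps instead of smoothly.
    Defaults mirror the values in the configuration above.
    """
    if staircase:
        exponent = step // decay_steps
    else:
        exponent = step / decay_steps
    return base_lr * decay_rate ** exponent


if __name__ == "__main__":
    for step in (0, 10000, 100000, 500000):
        print(step, decayed_learning_rate(step))
```

With the defaults, the rate is multiplied by 0.9 for every 10000 steps, so it shrinks gradually over the 500000 iterations configured by `max_iter`.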
@@ -0,0 +1,107 @@
"""
SRNet - Editing Text in the Wild
Data generator.
Copyright (c) 2019 Netease Youdao Information Technology Co.,Ltd.
Licensed under the GPL License (see LICENSE for details)
Written by Yu Qian
"""

import os
import cv2
import numpy as np
import random
import cfg

def srnet_datagen():

    # generate SRNet data for training
    name_list = os.listdir(os.path.join(cfg.data_dir, cfg.t_b_dir))
    random.shuffle(name_list)
    name_num = len(name_list)
    idx = 0

    while True:
        i_t_batch, i_s_batch = [], []
        t_sk_batch, t_t_batch, t_b_batch, t_f_batch = [], [], [], []
        mask_t_batch = []

        # read one sample per iteration; the same file name is used in every directory
        for _ in range(cfg.batch_size):
            name = name_list[idx]

            i_t = cv2.imread(os.path.join(cfg.data_dir, cfg.i_t_dir, name))
            i_s = cv2.imread(os.path.join(cfg.data_dir, cfg.i_s_dir, name))
            t_sk = cv2.imread(os.path.join(cfg.data_dir, cfg.t_sk_dir, name))
            # cv2.imread returns BGR; for gray-valued label images the result is unchanged
            t_sk = cv2.cvtColor(t_sk, cv2.COLOR_BGR2GRAY)
            t_t = cv2.imread(os.path.join(cfg.data_dir, cfg.t_t_dir, name))
            t_b = cv2.imread(os.path.join(cfg.data_dir, cfg.t_b_dir, name))
            t_f = cv2.imread(os.path.join(cfg.data_dir, cfg.t_f_dir, name))
            mask_t = cv2.imread(os.path.join(cfg.data_dir, cfg.mask_t_dir, name))
            mask_t = cv2.cvtColor(mask_t, cv2.COLOR_BGR2GRAY)

            i_t_batch.append(i_t)
            i_s_batch.append(i_s)
            t_sk_batch.append(t_sk)
            t_t_batch.append(t_t)
            t_b_batch.append(t_b)
            t_f_batch.append(t_f)
            mask_t_batch.append(mask_t)
            idx = (idx + 1) % name_num

        # average width of the batch after scaling every image to the target height
        w_sum = 0
        for t_b in t_b_batch:
            h, w = t_b.shape[:2]
            scale_ratio = cfg.data_shape[0] / h
            w_sum += int(w * scale_ratio)

        to_h = cfg.data_shape[0]
        to_w = w_sum // cfg.batch_size
        to_w = int(round(to_w / 8)) * 8
        to_scale = (to_w, to_h) # w first for cv2
        for i in range(cfg.batch_size):
            i_t_batch[i] = cv2.resize(i_t_batch[i], to_scale)
            i_s_batch[i] = cv2.resize(i_s_batch[i], to_scale)
            t_sk_batch[i] = cv2.resize(t_sk_batch[i], to_scale, interpolation=cv2.INTER_NEAREST)
            t_t_batch[i] = cv2.resize(t_t_batch[i], to_scale)
            t_b_batch[i] = cv2.resize(t_b_batch[i], to_scale)
            t_f_batch[i] = cv2.resize(t_f_batch[i], to_scale)
            mask_t_batch[i] = cv2.resize(mask_t_batch[i], to_scale, interpolation=cv2.INTER_NEAREST)

        i_t_batch = np.stack(i_t_batch)
        i_s_batch = np.stack(i_s_batch)
        t_sk_batch = np.expand_dims(np.stack(t_sk_batch), axis = -1)
        t_t_batch = np.stack(t_t_batch)
        t_b_batch = np.stack(t_b_batch)
        t_f_batch = np.stack(t_f_batch)
        mask_t_batch = np.expand_dims(np.stack(mask_t_batch), axis = -1)

        # color images are scaled to [-1, 1], skeleton and mask to [0, 1]
        i_t_batch = i_t_batch.astype(np.float32) / 127.5 - 1.
        i_s_batch = i_s_batch.astype(np.float32) / 127.5 - 1.
        t_sk_batch = t_sk_batch.astype(np.float32) / 255.
        t_t_batch = t_t_batch.astype(np.float32) / 127.5 - 1.
        t_b_batch = t_b_batch.astype(np.float32) / 127.5 - 1.
        t_f_batch = t_f_batch.astype(np.float32) / 127.5 - 1.
        mask_t_batch = mask_t_batch.astype(np.float32) / 255.

        yield [i_t_batch, i_s_batch, t_sk_batch, t_t_batch, t_b_batch, t_f_batch, mask_t_batch]

def get_input_data(data_dir = cfg.example_data_dir):

    # get input data from dir
    data_list = os.listdir(data_dir)
    data_list = [data_name.split('_')[0] + '_' for data_name in data_list]
    data_list = list(set(data_list))
    res_list = []
    for data_name in data_list:
        # read from the data_dir argument rather than always from cfg.example_data_dir
        i_t = cv2.imread(os.path.join(data_dir, data_name + 'i_t.png'))
        i_s = cv2.imread(os.path.join(data_dir, data_name + 'i_s.png'))
        h, w = i_t.shape[:2]
        scale_ratio = cfg.data_shape[0] / h
        to_h = cfg.data_shape[0]
        to_w = int(round(int(w * scale_ratio) / 8)) * 8
        to_scale = (to_w, to_h) # w first for cv2
        i_t = cv2.resize(i_t, to_scale).astype(np.float32) / 127.5 - 1.
        i_s = cv2.resize(i_s, to_scale).astype(np.float32) / 127.5 - 1.
        i_t = np.expand_dims(i_t, axis = 0)
        i_s = np.expand_dims(i_s, axis = 0)
        res_list.append([i_t, i_s, (w, h), data_name]) # w first for cv2
    return res_list
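
The resizing rule used by the training generator above can be isolated into a small, dependency-free sketch: every image is scaled to the target height, the scaled widths are averaged over the batch, and the result is rounded to a multiple of 8. The helper name `batch_target_size` and its parameters are hypothetical; only the arithmetic follows the generator.

```python
def batch_target_size(shapes, target_h=64, multiple=8):
    """Compute the common (width, height) for a batch of images.

    shapes: list of (height, width) tuples, one per image.
    Returns (to_w, to_h), width first as expected by cv2.resize.
    """
    w_sum = 0
    for h, w in shapes:
        scale_ratio = target_h / h          # scale that maps h to target_h
        w_sum += int(w * scale_ratio)       # width after that scaling

    to_w = w_sum // len(shapes)             # average scaled width
    to_w = int(round(to_w / multiple)) * multiple
    return to_w, target_h


if __name__ == "__main__":
    print(batch_target_size([(32, 100), (64, 130)]))
```

Because all images in a batch share one width, images whose aspect ratio is far from the batch average are stretched or squeezed horizontally.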
@@ -0,0 +1,103 @@
"""
SRNet - Editing Text in the Wild
Definition of loss functions.
Copyright (c) 2019 Netease Youdao Information Technology Co.,Ltd.
Licensed under the GPL License (see LICENSE for details)
Written by Yu Qian
"""

import tensorflow as tf
import cfg

def build_discriminator_loss(x, name = 'd_loss'):

    # x is split in two along axis 0: first half x_true, second half x_pred
    x_true, x_pred = tf.split(x, 2, name = name + '_split')
    d_loss = -tf.reduce_mean(tf.log(tf.clip_by_value(x_true, cfg.epsilon, 1.0)) \
                + tf.log(tf.clip_by_value(1.0 - x_pred, cfg.epsilon, 1.0)))
    return d_loss

def build_dice_loss(x_t, x_o, name = 'dice_loss'):

    # per-sample dice coefficient, averaged over the batch
    intersection = tf.reduce_sum(x_t * x_o, axis = [1,2,3])
    union = tf.reduce_sum(x_t, axis = [1,2,3]) + tf.reduce_sum(x_o, axis = [1,2,3])
    return 1. - tf.reduce_mean((2. * intersection + cfg.epsilon)/(union + cfg.epsilon), axis = 0)

def build_l1_loss(x_t, x_o, name = 'l1_loss'):

    return tf.reduce_mean(tf.abs(x_t - x_o))

def build_l1_loss_with_mask(x_t, x_o, mask, name = 'l1_loss'):

    # mask_ratio is the fraction of pixels outside the mask; it weights the error inside the mask
    mask_ratio = 1. - tf.reduce_sum(mask) / tf.cast(tf.size(mask), tf.float32)
    l1 = tf.abs(x_t - x_o)
    return mask_ratio * tf.reduce_mean(l1 * mask) + (1. - mask_ratio) * tf.reduce_mean(l1 * (1. - mask))

def build_perceptual_loss(x, name = 'per_loss'):

    # sum of L1 distances between the paired feature maps in x
    l = []
    for i, f in enumerate(x):
        l.append(build_l1_loss(f[0], f[1], name = name + '_l1_' + str(i + 1)))
    l = tf.stack(l, axis = 0, name = name + '_stack')
    l = tf.reduce_sum(l, name = name + '_sum')
    return l

def build_gram_matrix(x, name = 'gram_matrix'):

    # flatten the spatial dimensions, then compute channel-by-channel correlations
    x_shape = tf.shape(x)
    h, w, c = x_shape[1], x_shape[2], x_shape[3]
    matrix = tf.reshape(x, shape = [-1, h * w, c])
    gram = tf.matmul(matrix, matrix, transpose_a = True) / tf.cast(h * w * c, tf.float32)
    return gram

def build_style_loss(x, name = 'style_loss'):

    l = []
    for i, f in enumerate(x):
        f_shape = tf.size(f[0])
        f_norm = 1. / tf.cast(f_shape, tf.float32)
        gram_true = build_gram_matrix(f[0], name = name + '_gram_true_' + str(i + 1))
        gram_pred = build_gram_matrix(f[1], name = name + '_gram_pred_' + str(i + 1))
        l.append(f_norm * (build_l1_loss(gram_true, gram_pred, name = name + '_l1_' + str(i + 1))))
    l = tf.stack(l, axis = 0, name = name + '_stack')
    l = tf.reduce_sum(l, name = name + '_sum')
    return l

def build_vgg_loss(x, name = 'vgg_loss'):

    # each feature map in x is split in two along axis 0 before comparison
    splited = []
    for i, f in enumerate(x):
        splited.append(tf.split(f, 2, name = name + '_split_' + str(i + 1)))
    l_per = build_perceptual_loss(splited, name = name + '_per')
    l_style = build_style_loss(splited, name = name + '_style')
    return l_per, l_style

def build_gan_loss(x, name = 'gan_loss'):

    x_true, x_pred = tf.split(x, 2, name = name + '_split')
    gan_loss = -tf.reduce_mean(tf.log(tf.clip_by_value(x_pred, cfg.epsilon, 1.0)))
    return gan_loss

def build_generator_loss(out_g, out_d, out_vgg, labels, name = 'g_loss'):

    o_sk, o_t, o_b, o_f, mask_t = out_g
    o_db, o_df = out_d
    o_vgg = out_vgg
    t_sk, t_t, t_b, t_f = labels

    l_t_sk = cfg.lt_alpha * build_dice_loss(t_sk, o_sk, name = name + '_dice_loss')
    l_t_l1 = build_l1_loss_with_mask(t_t, o_t, mask_t, name = name + '_lt_l1_loss')
    l_t = l_t_l1 + l_t_sk

    l_b_gan = build_gan_loss(o_db, name = name + '_lb_gan_loss')
    l_b_l1 = cfg.lb_beta * build_l1_loss(t_b, o_b, name = name + '_lb_l1_loss')
    l_b = l_b_gan + l_b_l1

    l_f_gan = build_gan_loss(o_df, name = name + '_lf_gan_loss')
    l_f_l1 = cfg.lf_theta_1 * build_l1_loss(t_f, o_f, name = name + '_lf_l1_loss')
    l_f_vgg_per, l_f_vgg_style = build_vgg_loss(o_vgg, name = name + '_lf_vgg_loss')
    l_f_vgg_per = cfg.lf_theta_2 * l_f_vgg_per
    l_f_vgg_style = cfg.lf_theta_3 * l_f_vgg_style
    l_f = l_f_gan + l_f_l1 + l_f_vgg_per + l_f_vgg_style

    l = cfg.lt * l_t + cfg.lb * l_b + cfg.lf * l_f
    return l, [l_t_sk, l_t_l1, l_b_gan, l_b_l1, l_f_gan, l_f_l1, l_f_vgg_per, l_f_vgg_style]
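
To make the arithmetic of the dice loss and the mask-weighted L1 loss above easier to follow, here is a plain-Python sketch on flat lists of numbers. It is a simplification, not the project's code: the function names `dice_loss` and `masked_l1_loss` are hypothetical, tensors are replaced by flat lists, and the mask is assumed to have the same number of elements as the images (the TensorFlow version relies on broadcasting instead).

```python
EPSILON = 1e-8  # same value as `epsilon` in the configuration


def dice_loss(targets, outputs, eps=EPSILON):
    """1 - mean dice coefficient over the batch.

    targets / outputs: one flat list of values in [0, 1] per sample.
    """
    scores = []
    for t, o in zip(targets, outputs):
        intersection = sum(a * b for a, b in zip(t, o))
        union = sum(t) + sum(o)
        scores.append((2.0 * intersection + eps) / (union + eps))
    return 1.0 - sum(scores) / len(scores)


def masked_l1_loss(target, output, mask):
    """L1 loss that re-weights the regions inside and outside the mask.

    The error inside the mask is weighted by the fraction of pixels
    outside it (and vice versa), so a small text region is not
    drowned out by a large background.
    """
    n = len(mask)
    mask_ratio = 1.0 - sum(mask) / n
    l1 = [abs(t - o) for t, o in zip(target, output)]
    inside = sum(d * m for d, m in zip(l1, mask)) / n
    outside = sum(d * (1.0 - m) for d, m in zip(l1, mask)) / n
    return mask_ratio * inside + (1.0 - mask_ratio) * outside
```

A perfect skeleton prediction gives a dice loss of about 0, and completely disjoint prediction and target give a loss of about 1.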