Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu Qian committed Sep 26, 2019
1 parent 3604fcc commit 1c4ae30
Show file tree
Hide file tree
Showing 31 changed files with 927 additions and 0 deletions.
88 changes: 88 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# SRNet - Editing Text in the Wild

 
## Introduction
This is a TensorFlow reproduction of the paper *Editing Text in the Wild*, which aims to replace or modify a word in a source image with another one while maintaining its realistic look.

Original paper: [*Editing Text in the wild*](https://arxiv.org/abs/1908.03047) by Liang Wu, Chengquan Zhang, Jiaming Liu, Junyu Han, Jingtuo Liu, Errui Ding and Xiang Bai.

The model in this project is a result of my experiment and debugging of the details described in the paper.

A pre-trained VGG19 model is used in this SRNet. It was downloaded from [https://github.com/fchollet/deep-learning-models/releases/tag/v0.1](https://github.com/fchollet/deep-learning-models/releases/tag/v0.1) and converted to pb format.

![image](./examples/example/example.png)

 
## Prepare data
Data is completely prepared as described in the paper.

You can refer to and extend the [SynthText](https://github.com/ankush-me/SynthText) project to render styled texts on background images. You also need to save some intermediate results as labels while rendering.

According to the paper, you need to prepare 2 input images (`i_s`, `i_t`) and 4 label images (`t_sk`, `t_t`, `t_b`, `t_f`):

- `i_s`: standard text b rendering on gray background

- `i_t`: styled text a rendering on background image

- `t_sk`: skeletonization of styled text b.

- `t_t`: styled text b rendering on gray background

- `t_b`: background image

- `t_f`: styled text b rendering on background image

In my experiments, I found it easier to train with one additional label image (`mask_t`).

- `mask_t`: the binary mask of styled text b

![image](./examples/example/data.png)

From left to right, from top to bottom are examples of `i_s, i_t, t_sk, t_t, t_b, t_f, mask_t`

 
## Train your own dataset
First clone this project
```bashrc
$ git clone https://github.com/youdao-ai/SRNet.git
```

Once the data is ready, put the images in different directories with the same name.

You can modify the data directories and training parameters in `cfg.py` as you want.

Then run `python3 train.py` to start training.

 
## Predict
You can predict your own data with
```bashrc
$ python3 predict.py --i_s xxx --i_t xxx --save_dir xxx --checkpoint xxx
```
If you want to predict a directory of data, make sure each pair of `i_s` and `i_t` files shares the same prefix, separated from the rest of the name by '_' (for example, `image1_i_s.png` and `image1_i_t.png`). Put them into one directory and run
```bashrc
$ python3 predict.py --input_dir xxx --save_dir xxx --checkpoint xxx
```

Or you can set these paths in `cfg.py` and just run `python3 predict.py`.

 
## Requirements
- Python 3.6

- numpy

- opencv-python

- tensorflow 1.14.0

 
## Reference
- [Editing Text in the Wild](https://arxiv.org/abs/1908.03047)

- [EnsNet: Ensconce Text in the Wild](https://arxiv.org/abs/1812.00723)

- [Synthetic Data for Text Localisation in Natural Images](https://arxiv.org/abs/1604.06646)

- [A fast parallel algorithm for thinning digital patterns](http://www-prima.inrialpes.fr/perso/Tran/Draft/gateway.cfm.pdf)

60 changes: 60 additions & 0 deletions cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
SRNet - Editing Text in the Wild
Some configurations.
Copyright (c) 2019 Netease Youdao Information Technology Co.,Ltd.
Licensed under the GPL License (see LICENSE for details)
Written by Yu Qian
"""

# device
# NOTE(review): presumably the index of the GPU to run on -- confirm against
# the train/predict scripts, which are not part of this file.
gpu = 0

# pretrained vgg
# Path to the VGG19 weights in frozen-graph (.pb) form.
vgg19_weights = 'model_logs/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.pb'

# model parameters
# Loss weights consumed by loss.build_generator_loss:
#   total = lt * l_t + lb * l_b + lf * l_f
#   lt_alpha   scales the dice loss on the skeleton output (part of l_t)
#   lb_beta    scales the L1 loss between t_b and o_b (part of l_b)
#   lf_theta_1 scales the L1 loss between t_f and o_f (part of l_f)
#   lf_theta_2 scales the VGG perceptual loss (part of l_f)
#   lf_theta_3 scales the VGG style loss (part of l_f)
lt = 1.
lt_alpha = 1.
lb = 1.
lb_beta = 10.
lf = 1.
lf_theta_1 = 10.
lf_theta_2 = 1.
lf_theta_3 = 500.
# Lower clipping bound used before tf.log, and smoothing term in the dice loss.
epsilon = 1e-8

# train
# NOTE(review): the values below are not read by any code visible alongside
# this file; the descriptions are inferred from their names -- confirm
# against the training script.
learning_rate = 1e-4 # default 1e-3
decay_rate = 0.9
decay_steps = 10000
staircase = False
beta1 = 0.9 # default 0.9
beta2 = 0.999 # default 0.999
max_iter = 500000
show_loss_interval = 50
write_log_interval = 50
save_ckpt_interval = 10000
gen_example_interval = 1000
checkpoint_savedir = 'model_logs/checkpoints'
tensorboard_dir = 'model_logs/train_logs'
pretrained_ckpt_path = None
train_name = None # used for name examples and tensorboard logdirs, set None to use time

# data
# Number of samples per batch produced by data_gen.srnet_datagen.
batch_size = 8
# [height, width]. data_gen only reads the height (index 0); the width is
# derived per batch / per image and rounded to a multiple of 8.
data_shape = [64, None]
# Root directory of the training set; each *_dir below is a sub-directory of
# it, and the same file name is read from every sub-directory.
# NOTE(review): this is a machine-specific absolute path -- change it locally.
data_dir = '/reserve/qianyu/gpu100/qy/datasets/srnet_data'
i_t_dir = 'i_t_2'
i_s_dir = 'i_s_1'
t_sk_dir = 't_sk_2'
t_t_dir = 't_t_2'
t_b_dir = 't_b'
# NOTE(review): the directory names for t_f ('i_s_2') and mask_t
# ('t_sk_tmp_2') do not match their label names; presumably this reflects the
# author's dataset layout -- verify before reuse.
t_f_dir = 'i_s_2'
mask_t_dir = 't_sk_tmp_2'
# Default directory read by data_gen.get_input_data.
example_data_dir = 'examples/labels'
example_result_dir = 'examples/gen_logs'

# predict
predict_ckpt_path = None
predict_data_dir = None
# NOTE(review): the repository ships a placeholder directory named
# 'examples/results' (plural); confirm whether this path is intentional.
predict_result_dir = 'examples/result'
107 changes: 107 additions & 0 deletions data_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
SRNet - Editing Text in the Wild
Data generator.
Copyright (c) 2019 Netease Youdao Information Technology Co.,Ltd.
Licensed under the GPL License (see LICENSE for details)
Written by Yu Qian
"""

import os
import cv2
import numpy as np
import random
import cfg

def srnet_datagen():
    """Endlessly yield training batches for SRNet.

    File names are listed from the `cfg.t_b_dir` sub-directory of
    `cfg.data_dir`, shuffled once, and then cycled through in that order.
    The same file name is read from every input/label sub-directory.

    Every image of a batch is resized to one common size: the height is
    `cfg.data_shape[0]` and the width is the mean of the aspect-preserving
    widths of the batch's `t_b` images, rounded to a multiple of 8.

    Yields:
        list: `[i_t, i_s, t_sk, t_t, t_b, t_f, mask_t]`, each a float32 array
        stacked along a new leading batch axis. The colour images are mapped
        with `x / 127.5 - 1`; `t_sk` and `mask_t` are converted to a single
        channel, get a trailing axis of size 1 and are mapped with `x / 255`.
    """

    def read_color(sub_dir, file_name):
        # cv2.imread returns None (it does not raise) when the file is unreadable.
        return cv2.imread(os.path.join(cfg.data_dir, sub_dir, file_name))

    def read_gray(sub_dir, file_name):
        return cv2.cvtColor(read_color(sub_dir, file_name), cv2.COLOR_RGB2GRAY)

    def to_signed(images):
        return np.stack(images).astype(np.float32) / 127.5 - 1.

    def to_unit(images):
        return np.expand_dims(np.stack(images), axis = -1).astype(np.float32) / 255.

    file_names = os.listdir(os.path.join(cfg.data_dir, cfg.t_b_dir))
    random.shuffle(file_names)
    n_files = len(file_names)
    cursor = 0

    while True:
        # One tuple per sample, in the order (i_t, i_s, t_sk, t_t, t_b, t_f, mask_t).
        samples = []
        for _ in range(cfg.batch_size):
            file_name = file_names[cursor]
            samples.append((
                read_color(cfg.i_t_dir, file_name),
                read_color(cfg.i_s_dir, file_name),
                read_gray(cfg.t_sk_dir, file_name),
                read_color(cfg.t_t_dir, file_name),
                read_color(cfg.t_b_dir, file_name),
                read_color(cfg.t_f_dir, file_name),
                read_gray(cfg.mask_t_dir, file_name),
            ))
            cursor = (cursor + 1) % n_files

        # Sum of the widths each t_b image would have at the target height.
        width_total = 0
        for sample in samples:
            src_h, src_w = sample[4].shape[:2]
            width_total += int(src_w * (cfg.data_shape[0] / src_h))

        out_h = cfg.data_shape[0]
        out_w = int(round((width_total // cfg.batch_size) / 8)) * 8
        dsize = (out_w, out_h) # w first for cv2

        resized = []
        for i_t, i_s, t_sk, t_t, t_b, t_f, mask_t in samples:
            resized.append((
                cv2.resize(i_t, dsize),
                cv2.resize(i_s, dsize),
                # nearest-neighbour keeps skeleton/mask values unblended
                cv2.resize(t_sk, dsize, interpolation = cv2.INTER_NEAREST),
                cv2.resize(t_t, dsize),
                cv2.resize(t_b, dsize),
                cv2.resize(t_f, dsize),
                cv2.resize(mask_t, dsize, interpolation = cv2.INTER_NEAREST),
            ))
        i_t_all, i_s_all, t_sk_all, t_t_all, t_b_all, t_f_all, mask_t_all = zip(*resized)

        yield [
            to_signed(i_t_all),
            to_signed(i_s_all),
            to_unit(t_sk_all),
            to_signed(t_t_all),
            to_signed(t_b_all),
            to_signed(t_f_all),
            to_unit(mask_t_all),
        ]

def get_input_data(data_dir = cfg.example_data_dir):
    """Load every `<prefix>_i_t.png` / `<prefix>_i_s.png` pair found in a directory.

    The prefix of a file is everything before its first '_'. Each image is
    resized to height `cfg.data_shape[0]` with an aspect-preserving width
    rounded to a multiple of 8 (based on the size of `i_t`), mapped with
    `x / 127.5 - 1`, and given a leading batch axis of size 1.

    Args:
        data_dir: directory that holds the image pairs. Defaults to
            `cfg.example_data_dir`.

    Returns:
        list: one `[i_t, i_s, (w, h), data_name]` entry per prefix, sorted by
        prefix. `(w, h)` is the original size of `i_t` and `data_name` is the
        prefix including its trailing '_'.
    """

    # Both files of a pair collapse to the same "<prefix>_" key. Sorting makes
    # the order reproducible (plain set iteration order is not).
    prefixes = sorted({file_name.split('_')[0] + '_' for file_name in os.listdir(data_dir)})

    res_list = []
    for data_name in prefixes:
        # Read from the directory that was listed. The previous code always
        # read from cfg.example_data_dir, ignoring the data_dir argument.
        i_t = cv2.imread(os.path.join(data_dir, data_name + 'i_t.png'))
        i_s = cv2.imread(os.path.join(data_dir, data_name + 'i_s.png'))
        h, w = i_t.shape[:2]
        scale_ratio = cfg.data_shape[0] / h
        to_h = cfg.data_shape[0]
        to_w = int(round(int(w * scale_ratio) / 8)) * 8
        to_scale = (to_w, to_h) # w first for cv2
        i_t = cv2.resize(i_t, to_scale).astype(np.float32) / 127.5 - 1.
        i_s = cv2.resize(i_s, to_scale).astype(np.float32) / 127.5 - 1.
        i_t = np.expand_dims(i_t, axis = 0)
        i_s = np.expand_dims(i_s, axis = 0)
        res_list.append([i_t, i_s, (w, h), data_name]) # w first for cv2
    return res_list
Binary file added examples/example/data.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/example/example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file added examples/gen_logs/.gitignore
Empty file.
Binary file added examples/labels/001_i_s.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/001_i_t.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/002_i_s.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/002_i_t.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/003_i_s.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/003_i_t.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/004_i_s.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/004_i_t.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/005_i_s.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/005_i_t.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/006_i_s.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/006_i_t.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/007_i_s.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/007_i_t.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/008_i_s.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/labels/008_i_t.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file added examples/results/.gitignore
Empty file.
103 changes: 103 additions & 0 deletions loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""
SRNet - Editing Text in the Wild
Definition of loss functions.
Copyright (c) 2019 Netease Youdao Information Technology Co.,Ltd.
Licensed under the GPL License (see LICENSE for details)
Written by Yu Qian
"""

import tensorflow as tf
import cfg

def build_discriminator_loss(x, name = 'd_loss'):
    """Discriminator loss: -mean(log(first_half) + log(1 - second_half)).

    `x` is split into two equal parts by `tf.split` (along its default axis).
    Presumably the caller concatenates discriminator outputs for real samples
    followed by generated samples -- confirm at the call site. Both log
    arguments are clipped to [cfg.epsilon, 1.0] to avoid log(0).
    """

    real_score, fake_score = tf.split(x, 2, name = name + '_split')
    log_real = tf.log(tf.clip_by_value(real_score, cfg.epsilon, 1.0))
    log_not_fake = tf.log(tf.clip_by_value(1.0 - fake_score, cfg.epsilon, 1.0))
    return -tf.reduce_mean(log_real + log_not_fake)

def build_dice_loss(x_t, x_o, name = 'dice_loss'):
    """Return 1 minus the batch-mean dice coefficient of `x_t` and `x_o`.

    Sums are taken over axes 1, 2 and 3, so one coefficient is produced per
    entry along axis 0; `cfg.epsilon` smooths numerator and denominator.
    `name` is accepted for signature parity with the other losses but unused.
    """

    reduce_axes = [1, 2, 3]
    overlap = tf.reduce_sum(x_t * x_o, axis = reduce_axes)
    total = tf.reduce_sum(x_t, axis = reduce_axes) + tf.reduce_sum(x_o, axis = reduce_axes)
    dice = (2. * overlap + cfg.epsilon) / (total + cfg.epsilon)
    return 1. - tf.reduce_mean(dice, axis = 0)

def build_l1_loss(x_t, x_o, name = 'l1_loss'):
    """Return the mean absolute difference between `x_t` and `x_o`.

    `name` is accepted for signature parity but is not used.
    """

    abs_err = tf.abs(x_t - x_o)
    return tf.reduce_mean(abs_err)

def build_l1_loss_with_mask(x_t, x_o, mask, name = 'l1_loss'):
    """L1 loss in which masked and unmasked regions are re-weighted.

    With r = 1 - sum(mask) / size(mask), the result is
    r * mean(|x_t - x_o| * mask) + (1 - r) * mean(|x_t - x_o| * (1 - mask)),
    so the region covering less of the image receives the larger weight.
    `name` is accepted for signature parity but is not used.
    """

    mask_sum = tf.reduce_sum(mask)
    mask_numel = tf.cast(tf.size(mask), tf.float32)
    unmasked_fraction = 1. - mask_sum / mask_numel
    abs_err = tf.abs(x_t - x_o)
    inside = tf.reduce_mean(abs_err * mask)
    outside = tf.reduce_mean(abs_err * (1. - mask))
    return unmasked_fraction * inside + (1. - unmasked_fraction) * outside

def build_perceptual_loss(x, name = 'per_loss'):
    """Sum the L1 losses of every feature pair in `x`.

    Args:
        x: iterable of pairs; element 0 and element 1 of each pair are
           compared with `build_l1_loss`.
        name: prefix for the names of the stack and sum ops.
    """

    terms = [
        build_l1_loss(pair[0], pair[1], name = name + '_l1_' + str(k))
        for k, pair in enumerate(x, start = 1)
    ]
    stacked = tf.stack(terms, axis = 0, name = name + '_stack')
    return tf.reduce_sum(stacked, name = name + '_sum')

def build_gram_matrix(x, name = 'gram_matrix'):
    """Return the normalized Gram matrix of a feature map.

    `x` is reshaped to [-1, h * w, c], where h, w, c are its dynamic sizes
    along axes 1, 2 and 3; the result is flat^T @ flat / (h * w * c).
    `name` is accepted for signature parity but is not used.
    """

    dims = tf.shape(x)
    height, width, channels = dims[1], dims[2], dims[3]
    flat = tf.reshape(x, shape = [-1, height * width, channels])
    normalizer = tf.cast(height * width * channels, tf.float32)
    return tf.matmul(flat, flat, transpose_a = True) / normalizer

def build_style_loss(x, name = 'style_loss'):
    """Sum the size-normalized L1 distances between Gram matrices.

    For every pair in `x`, the Gram matrices of element 0 and element 1 are
    compared with `build_l1_loss`, and the result is divided by the number of
    elements in element 0. The per-pair terms are then summed.
    """

    terms = []
    for k, pair in enumerate(x, start = 1):
        inv_size = 1. / tf.cast(tf.size(pair[0]), tf.float32)
        gram_true = build_gram_matrix(pair[0], name = name + '_gram_true_' + str(k))
        gram_pred = build_gram_matrix(pair[1], name = name + '_gram_pred_' + str(k))
        gram_l1 = build_l1_loss(gram_true, gram_pred, name = name + '_l1_' + str(k))
        terms.append(inv_size * gram_l1)
    stacked = tf.stack(terms, axis = 0, name = name + '_stack')
    return tf.reduce_sum(stacked, name = name + '_sum')

def build_vgg_loss(x, name = 'vgg_loss'):
    """Compute the perceptual and style losses from a list of feature tensors.

    Each tensor in `x` is split into two equal halves with `tf.split`; the
    resulting pairs are passed to `build_perceptual_loss` and
    `build_style_loss`.

    Returns:
        tuple: (perceptual loss, style loss).
    """

    pairs = [
        tf.split(feature, 2, name = name + '_split_' + str(k))
        for k, feature in enumerate(x, start = 1)
    ]
    perceptual = build_perceptual_loss(pairs, name = name + '_per')
    style = build_style_loss(pairs, name = name + '_style')
    return perceptual, style

def build_gan_loss(x, name = 'gan_loss'):
    """Generator-side adversarial loss: -mean(log(second half of `x`)).

    `x` is split into two equal parts with `tf.split`; only the second part
    is used, clipped to [cfg.epsilon, 1.0] before the log.
    """

    halves = tf.split(x, 2, name = name + '_split')
    pred_score = halves[1]
    log_pred = tf.log(tf.clip_by_value(pred_score, cfg.epsilon, 1.0))
    return -tf.reduce_mean(log_pred)

def build_generator_loss(out_g, out_d, out_vgg, labels, name = 'g_loss'):
    """Assemble the total generator loss and its individual terms.

    Args:
        out_g: generator outputs plus mask, `(o_sk, o_t, o_b, o_f, mask_t)`.
        out_d: discriminator outputs `(o_db, o_df)`, fed to `build_gan_loss`.
        out_vgg: list of feature tensors, fed to `build_vgg_loss`.
        labels: ground-truth tensors `(t_sk, t_t, t_b, t_f)`.
        name: prefix passed down to the individual loss builders.

    Returns:
        tuple: `(total, parts)` where
        total = cfg.lt * l_t + cfg.lb * l_b + cfg.lf * l_f and parts is
        `[l_t_sk, l_t_l1, l_b_gan, l_b_l1, l_f_gan, l_f_l1, l_f_vgg_per,
        l_f_vgg_style]`, each already scaled by its own cfg weight.
    """

    o_sk, o_t, o_b, o_f, mask_t = out_g
    o_db, o_df = out_d
    t_sk, t_t, t_b, t_f = labels

    # l_t: dice loss on the skeleton plus mask-weighted L1 on o_t
    sk_dice = cfg.lt_alpha * build_dice_loss(t_sk, o_sk, name = name + '_dice_loss')
    t_l1 = build_l1_loss_with_mask(t_t, o_t, mask_t, name = name + '_lt_l1_loss')
    loss_t = t_l1 + sk_dice

    # l_b: adversarial term plus weighted L1 on o_b
    b_gan = build_gan_loss(o_db, name = name + '_lb_gan_loss')
    b_l1 = cfg.lb_beta * build_l1_loss(t_b, o_b, name = name + '_lb_l1_loss')
    loss_b = b_gan + b_l1

    # l_f: adversarial, L1 and VGG (perceptual + style) terms on o_f
    f_gan = build_gan_loss(o_df, name = name + '_lf_gan_loss')
    f_l1 = cfg.lf_theta_1 * build_l1_loss(t_f, o_f, name = name + '_lf_l1_loss')
    vgg_per, vgg_style = build_vgg_loss(out_vgg, name = name + '_lf_vgg_loss')
    f_vgg_per = cfg.lf_theta_2 * vgg_per
    f_vgg_style = cfg.lf_theta_3 * vgg_style
    loss_f = f_gan + f_l1 + f_vgg_per + f_vgg_style

    total = cfg.lt * loss_t + cfg.lb * loss_b + cfg.lf * loss_f
    return total, [sk_dice, t_l1, b_gan, b_l1, f_gan, f_l1, f_vgg_per, f_vgg_style]
Loading

0 comments on commit 1c4ae30

Please sign in to comment.