From bed212b211dff68219b6f50f20a49256e391abee Mon Sep 17 00:00:00 2001
From: yu20103983
Date: Fri, 19 Apr 2019 21:46:14 +0800
Subject: [PATCH] write some comment

---
 Augment.py          |  44 ++++++
 FeatureCNN.py       | 189 +++++++++++++++++++++
 GRU_Att_Cov.py      | 154 +++++++++++++++
 README.md           |   0
 WAP_Model.py        | 330 ++++++++++++++++++++++++++++++++
 config.py           |  59 ++++++
 create_tfrecords.py | 163 +++++++++++++++
 iterator_utils.py   | 162 +++++++++++++++
 read_tf_records.py  | 233 +++++++++++++++++++
 utils.py            |  65 ++++++
 vocab_utils.py      |  80 ++++++++
 11 files changed, 1479 insertions(+)
 create mode 100644 Augment.py
 create mode 100644 FeatureCNN.py
 create mode 100644 GRU_Att_Cov.py
 create mode 100644 README.md
 create mode 100644 WAP_Model.py
 create mode 100644 config.py
 create mode 100644 create_tfrecords.py
 create mode 100644 iterator_utils.py
 create mode 100644 read_tf_records.py
 create mode 100644 utils.py
 create mode 100644 vocab_utils.py

diff --git a/Augment.py b/Augment.py
new file mode 100644
index 0000000..171e1a7
--- /dev/null
+++ b/Augment.py
@@ -0,0 +1,44 @@
#coding=utf-8
import tensorflow as tf

def nonlinear(imageList, lower, upper):
    with tf.name_scope('nonlinear') as scope:
        factor = tf.random_uniform([], lower, upper)

        res = []
        for i in imageList:
            res.append(tf.pow(i, factor))

        return res


def randomNormal(imageList, stddev):
    with tf.name_scope('randomNormal') as scope:
        factor = tf.random_uniform([], 0, stddev)

        res = []
        for i in imageList:
            res.append(i + tf.random_normal(tf.shape(i), mean=0.0, stddev=factor))

        return res


def mirror(image):
    uniform_random = tf.random_uniform([], 0, 1.0)
    return tf.cond(uniform_random < 0.5, lambda: image, lambda: tf.reverse(image, axis=[2]))


def augment(image):
    with tf.name_scope('augmentation') as scope:
        image = nonlinear([image], 0.8, 1.2)[0]  # raise each pixel to a random power (gamma jitter)

        # image = mirror(image)  # mirror flip

        image = tf.image.random_contrast(image, lower=0.3, upper=1.3)  # randomly adjust contrast
        image = tf.image.random_brightness(image, max_delta=0.3)  # randomly adjust brightness

        image = randomNormal([image], 0.025)[0]  # add random noise

        image = tf.clip_by_value(image, 0, 1.0)  # clip back to [0, 1]

        return image
\ No newline at end of file
diff --git a/FeatureCNN.py b/FeatureCNN.py
new file mode 100644
index 0000000..57bf08d
--- /dev/null
+++ b/FeatureCNN.py
@@ -0,0 +1,189 @@
# -*- coding: utf-8 -*-
import tensorflow as tf
from config import cfg
from tensorflow.python.training.moving_averages import assign_moving_average


def tfVariable(dtype, shape, name, trainable=True):
    return tf.Variable(tf.truncated_normal(dtype=dtype, shape=shape, mean=0, stddev=0.01), name=name,
                       trainable=trainable)


def tfVariable_ones(dtype, shape, name, trainable=True):
    return tf.Variable(tf.ones(dtype=dtype, shape=shape), name=name, trainable=trainable)


def tfVariable_zeros(dtype, shape, name, trainable=True):
    return tf.Variable(tf.zeros(dtype=dtype, shape=shape), name=name, trainable=trainable)


class FeatureCNN():
    def __init__(self):
        with tf.variable_scope('featureCNN'):
            # Block 1
            self.conv1_1_w = tfVariable(tf.float32, (5, 5, 1, 16), 'conv1_1_w')
            self.scale1_1 = tfVariable_ones(tf.float32, 16, 'scale1_1')
            self.shift1_1 = tfVariable_zeros(tf.float32, 16, 'shift1_1')
            self.var1_1 = tfVariable_ones(tf.float32, 16, 'var1_1', False)
            self.mean1_1 = tfVariable_zeros(tf.float32, 16, 'mean1_1', False)
            conv1_1 = (self.conv1_1_w, self.scale1_1, self.shift1_1, self.var1_1, self.mean1_1)
            # self.conv1_2_w = tfVariable(tf.float32, (3, 3, 32, 32), 'conv1_2_w')
            # self.scale1_2 = tfVariable_ones(tf.float32, 32, 'scale1_2')
            # self.shift1_2 = tfVariable_zeros(tf.float32, 32, 'shift1_2')
            # self.var1_2 = tfVariable_ones(tf.float32, 32, 'var1_2', False)
            # self.mean1_2 = tfVariable_zeros(tf.float32, 32, 'mean1_2', False)
            # conv1_2 = (self.conv1_2_w, self.scale1_2, self.shift1_2)
            # self.conv1_3_w = tfVariable(tf.float32, (3, 3, 32, 32), 'conv1_3_w')
            # self.scale1_3 = tfVariable_ones(tf.float32, 32, 'scale1_3')
            # self.shift1_3 = tfVariable_zeros(tf.float32, 32, 'shift1_3')
            # conv1_3 = (self.conv1_3_w, self.scale1_3, self.shift1_3)
            # self.conv1_4_w = tfVariable(tf.float32, (3, 3, 32, 32), 'conv1_4_w')
            # self.scale1_4 = tfVariable_ones(tf.float32, 32, 'scale1_4')
            # self.shift1_4 = tfVariable_zeros(tf.float32, 32, 'shift1_4')
            # conv1_4 = (self.conv1_4_w, self.scale1_4, self.shift1_4)
            # self.conv1 = (conv1_1, conv1_2, conv1_3, conv1_4)
            # self.conv1 = (conv1_1, conv1_2)
            self.conv1 = (conv1_1,)

            # Block 2
            self.conv2_1_w = tfVariable(tf.float32, (3, 3, 16, 16), 'conv2_1_w')
            self.scale2_1 = tfVariable_ones(tf.float32, 16, 'scale2_1')
            self.shift2_1 = tfVariable_zeros(tf.float32, 16, 'shift2_1')
            self.var2_1 = tfVariable_ones(tf.float32, 16, 'var2_1', False)
            self.mean2_1 = tfVariable_zeros(tf.float32, 16, 'mean2_1', False)
            conv2_1 = (self.conv2_1_w, self.scale2_1, self.shift2_1, self.var2_1, self.mean2_1)
            # self.conv2_2_w = tfVariable(tf.float32, (3, 3, 64, 64), 'conv2_2_w')
            # self.scale2_2 = tfVariable_ones(tf.float32, 64, 'scale2_2')
            # self.shift2_2 = tfVariable_zeros(tf.float32, 64, 'shift2_2')
            # conv2_2 = (self.conv2_2_w, self.scale2_2, self.shift2_2)
            # self.conv2_3_w = tfVariable(tf.float32, (3, 3, 64, 64), 'conv2_3_w')
            # self.scale2_3 = tfVariable_ones(tf.float32, 64, 'scale2_3')
            # self.shift2_3 = tfVariable_zeros(tf.float32, 64, 'shift2_3')
            # conv2_3 = (self.conv2_3_w, self.scale2_3, self.shift2_3)
            # self.conv2_4_w = tfVariable(tf.float32, (3, 3, 64, 64), 'conv2_4_w')
            # self.scale2_4 = tfVariable_ones(tf.float32, 64, 'scale2_4')
            # self.shift2_4 = tfVariable_zeros(tf.float32, 64, 'shift2_4')
            # conv2_4 = (self.conv2_4_w, self.scale2_4, self.shift2_4)
            # self.conv2 = (conv2_1, conv2_2, conv2_3, conv2_4)
            # self.conv2 = (conv2_1, conv2_2)
            self.conv2 = (conv2_1,)

            # Block 3
            self.conv3_1_w = tfVariable(tf.float32, (3, 3, 16, 32), 'conv3_1_w')
            self.scale3_1 = tfVariable_ones(tf.float32, 32, 'scale3_1')
            self.shift3_1 = tfVariable_zeros(tf.float32, 32, 'shift3_1')
            self.var3_1 = tfVariable_ones(tf.float32, 32, 'var3_1', False)
            self.mean3_1 = tfVariable_zeros(tf.float32, 32, 'mean3_1', False)
            conv3_1 = (self.conv3_1_w, self.scale3_1, self.shift3_1, self.var3_1, self.mean3_1)
            # self.conv3_2_w = tfVariable(tf.float32, (3, 3, 64, 64), 'conv3_2_w')
            # self.scale3_2 = tfVariable_ones(tf.float32, 64, 'scale3_2')
            # self.shift3_2 = tfVariable_zeros(tf.float32, 64, 'shift3_2')
            # conv3_2 = (self.conv3_2_w, self.scale3_2, self.shift3_2)
            # self.conv3_3_w = tfVariable(tf.float32, (3, 3, 64, 64), 'conv3_3_w')
            # self.scale3_3 = tfVariable_ones(tf.float32, 64, 'scale3_3')
            # self.shift3_3 = tfVariable_zeros(tf.float32, 64, 'shift3_3')
            # conv3_3 = (self.conv3_3_w, self.scale3_3, self.shift3_3)
            # self.conv3_4_w = tfVariable(tf.float32, (3, 3, 64, 64), 'conv3_4_w')
            # self.scale3_4 = tfVariable_ones(tf.float32, 64, 'scale3_4')
            # self.shift3_4 = tfVariable_zeros(tf.float32, 64, 'shift3_4')
            # conv3_4 = (self.conv3_4_w, self.scale3_4, self.shift3_4)
            # self.conv3 = (conv3_1, conv3_2, conv3_3, conv3_4)
            # self.conv3 = (conv3_1, conv3_2)
            self.conv3 = (conv3_1,)

            # Block 4
            self.conv4_1_w = tfVariable(tf.float32, (3, 3, 32, cfg.rnn_input_dimensions), 'conv4_1_w')
            self.scale4_1 = tfVariable_ones(tf.float32, cfg.rnn_input_dimensions, 'scale4_1')
            self.shift4_1 = tfVariable_zeros(tf.float32, cfg.rnn_input_dimensions, 'shift4_1')
            self.var4_1 = tfVariable_ones(tf.float32, cfg.rnn_input_dimensions, 'var4_1', False)
            self.mean4_1 = tfVariable_zeros(tf.float32, cfg.rnn_input_dimensions, 'mean4_1', False)
            conv4_1 = (self.conv4_1_w, self.scale4_1, self.shift4_1, self.var4_1, self.mean4_1)
            # self.conv4_2_w = tfVariable(tf.float32, (3, 3, cfg.rnn_input_dimensions, cfg.rnn_input_dimensions), 'conv4_2_w')
            # self.scale4_2 = tfVariable_ones(tf.float32, cfg.rnn_input_dimensions, 'scale4_2')
            # self.shift4_2 = tfVariable_zeros(tf.float32, cfg.rnn_input_dimensions, 'shift4_2')
            # conv4_2 = (self.conv4_2_w, self.scale4_2, self.shift4_2)
            # self.conv4_3_w = tfVariable(tf.float32, (3, 3, cfg.rnn_input_dimensions, cfg.rnn_input_dimensions), 'conv4_3_w')
            # self.scale4_3 = tfVariable_ones(tf.float32, cfg.rnn_input_dimensions, 'scale4_3')
            # self.shift4_3 = tfVariable_zeros(tf.float32, cfg.rnn_input_dimensions, 'shift4_3')
            # conv4_3 = (self.conv4_3_w, self.scale4_3, self.shift4_3)
            # self.conv4_4_w = tfVariable(tf.float32, (3, 3, cfg.rnn_input_dimensions, cfg.rnn_input_dimensions), 'conv4_4_w')
            # self.scale4_4 = tfVariable_ones(tf.float32, cfg.rnn_input_dimensions, 'scale4_4')
            # self.shift4_4 = tfVariable_zeros(tf.float32, cfg.rnn_input_dimensions, 'shift4_4')
            # conv4_4 = (self.conv4_4_w, self.scale4_4, self.shift4_4)
            # self.conv4 = (conv4_1, conv4_2, conv4_3, conv4_4)
            # self.conv4 = (conv4_1, conv4_2)
            self.conv4 = (conv4_1,)

    def __call__(self, input, is_training=True, bn_mv=True):
        # bn_mv=True: normalize with the stored moving statistics;
        # bn_mv=False: normalize with batch statistics and update the moving averages.
        def Batch_Norn(s_input, scale, shift, moving_variance, moving_mean, axis=[0, 1, 2], eps=1e-05, decay=0.9,
                       name=None):
            def mean_var_with_update():
                means, variances = tf.nn.moments(s_input, axes=axis, name='moments')
                with tf.variable_scope('ass_m_a_%s' % name if name else 'ass_m_a', reuse=bn_mv):
                    with tf.control_dependencies([assign_moving_average(moving_mean, means, decay, zero_debias=False),
                                                  assign_moving_average(moving_variance, variances, decay, zero_debias=False)]):
                        return tf.identity(means), tf.identity(variances)

            if bn_mv:
                mean = moving_mean
                var = moving_variance
            else:
                mean, var = mean_var_with_update()

            return tf.nn.batch_normalization(s_input, mean, var, shift, scale, eps, name=name)

        def CNNBlock(s_input, W, pool_size=[2, 2], bn=True, dropout=None, name=None):
            out = [s_input]
            for i, w_s_s in enumerate(W):
                if bn:
                    w, scale, shift, var, mean = w_s_s
                else:
                    w, b = w_s_s  # each entry is a (weight, bias) pair; w_s_s[0] would have unpacked only the weight
                conv = tf.nn.conv2d(out[-1], w, [1, 1, 1, 1], "SAME", name=(name + "_%d_c2d" % i if name else name))
                out.append(conv)

                conv = tf.nn.relu(conv, name=(name + "_%d_relu" % i if name else name))
                out.append(conv)
                if bn:
                    conv = Batch_Norn(conv, scale, shift, var, mean, name=(name + "_%d_BN" % i if name else name))
                    # conv = tf.nn.bias_add(conv, shift, name=(name + "_%d_Bias" % i if name else name))
                else:
                    conv = tf.nn.bias_add(conv, b, name=(name + "_%d_Bias" % i if name else name))
                out.append(conv)

                if is_training and dropout:
                    conv = tf.nn.dropout(conv, keep_prob=dropout, name=(name + "_%d_dp" % i if name else name))
                    out.append(conv)
            pool = tf.nn.max_pool(out[-1], [1, pool_size[0], 
pool_size[1], 1], [1, pool_size[0], pool_size[1], 1], + "SAME", name=(name + "_pool" if name else name)) + out.append(pool) + return out + + def FCBlock(s_input, W, bn=True, act=True, dropout=None, name=None): + out = [s_input] + for i, w_s_s in enumerate(W): + if bn: + w, scale, shift, var, mean = w_s_s + else: + w, b = w_s_s + + fc = tf.matmul(out[-1], w, name=(name + "_%d_fc" % i if name else name)) + out.append(fc) + if act: fc = tf.nn.relu(fc, name=(name + "_%d_relu" % i if name else name)) + out.append(fc) + if bn: + fc = Batch_Norn(fc, scale, shift, var, mean, axis=[0], name=(name + "_%d_BN" % i if name else name)) + else: + fc = tf.nn.bias_add(fc, b, name=(name + "_%d_Bias" % i if name else name)) + out.append(fc) + + if is_training and dropout: + fc = tf.nn.dropout(fc, keep_prob=dropout, name=(name + "_%d_dp" % i if name else name)) + out.append(fc) + return out + + cnn_1 = CNNBlock(input, self.conv1, pool_size=[3, 3], name='b1')[-1] + cnn_2 = CNNBlock(cnn_1, self.conv2, pool_size=[2, 2], name='b2')[-1] + cnn_3 = CNNBlock(cnn_2, self.conv3, pool_size=[2, 2], name='b3')[-1] + cnn_4 = CNNBlock(cnn_3, self.conv4, pool_size=[2, 2], dropout=cfg.keep_prob, name='b4')[-1] + return cnn_4 diff --git a/GRU_Att_Cov.py b/GRU_Att_Cov.py new file mode 100644 index 0000000..a072ef7 --- /dev/null +++ b/GRU_Att_Cov.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +import tensorflow as tf +from config import cfg + +def tfVariable(dtype, shape, name): + return tf.Variable(tf.truncated_normal(dtype=dtype, shape=shape, mean=0, stddev=0.01), name=name) +class GRU_Att_Cov(): + def __init__(self, tgt_table_size): + self.input_dimensions = cfg.rnn_input_dimensions + self.hidden_size = cfg.rnn_hidden_size + self.attention_size = cfg.attention_size + self.coverage_size = cfg.coverage_size + self.embedding_size = cfg.tgt_embedding_size + self.project_size = tgt_table_size + with tf.variable_scope('tgt_embedding'): + self.embedding = tfVariable(tf.float32, (self.project_size, self.embedding_size), 'embedding') + with tf.variable_scope('gru') as scope: + # Weights for input vectors of shape (input_dimensions, hidden_size) + self.wr = tfVariable(tf.float32, (self.input_dimensions, self.hidden_size), 'wr') + self.wz = tfVariable(tf.float32, (self.input_dimensions, self.hidden_size), 'wz') + self.wh = tfVariable(tf.float32, (self.input_dimensions, self.hidden_size), 'wh') + # Weights for hidden vectors of shape (hidden_size, hidden_size) + self.ur = tfVariable(tf.float32, (self.hidden_size, self.hidden_size), 'ur') + self.uz = tfVariable(tf.float32, (self.hidden_size, self.hidden_size), 'uz') + self.uh = tfVariable(tf.float32, (self.hidden_size, self.hidden_size), 'uh') + # Weights for embedding vectors of shape (hidden_size, hidden_size) + self.yr = tfVariable(tf.float32, (self.embedding_size, self.hidden_size), 'yr') + self.yz = tfVariable(tf.float32, (self.embedding_size, self.hidden_size), 'yz') + self.yh = tfVariable(tf.float32, (self.embedding_size, self.hidden_size), 'yh') + # Biases for hidden vectors of shape (hidden_size,) + self.br = tfVariable(tf.float32, (self.hidden_size,), 'br') + self.bz = tfVariable(tf.float32, (self.hidden_size,), 'bz') + self.bh = tfVariable(tf.float32, (self.hidden_size,), 'bh') + with tf.variable_scope('att'): + # Weights for attention mechanism hidden vectors of shape (hidden_size, attention_size) + self.wa = tfVariable(tf.float32, (self.hidden_size, self.attention_size), 'wa') + # Weights for attention mechanism hidden input of shape (input_dimensions, attention_size) + 
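            # Together with wa, va, and uc (defined under the coverage scope below),
            # this implements the coverage-attention energy of Eq. (14) in the WAP
            # paper, computed per position by att_Clcu in __call__:
            #   e_ti = va^T * tanh(wa * h_{t-1} + ua * a_i + uc * f_i)
            # where a_i is the CNN annotation vector at position i and f_i is its
            # coverage feature.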
self.ua = tfVariable(tf.float32, (self.input_dimensions, self.attention_size), 'ua') + # Weights for attention mechanism of shape (attention_size) + self.va = tfVariable(tf.float32, (self.attention_size,), 'va') + with tf.variable_scope('coverage'): + conv1_w = tfVariable(tf.float32, (11, 11, 1, self.coverage_size), 'conv1_w') + conv1_b = tfVariable(tf.float32, (self.coverage_size,), 'conv1_b') + def ConvCoverage(input): + conv1 = tf.nn.conv2d(input, conv1_w, [1, 1, 1, 1], "SAME", name='conv1') + conv1_addb = tf.nn.bias_add(conv1, conv1_b) + conv1_act = tf.nn.relu(conv1_addb) + return conv1_act + # Conv weights for Coverage mechanism of shape () + self.convCoverage = ConvCoverage + self.uc = tfVariable(tf.float32, (self.coverage_size, self.attention_size), 'uc') + with tf.variable_scope('project'): + self.ph = tfVariable(tf.float32, (self.hidden_size, self.embedding_size), 'ph') + self.pc = tfVariable(tf.float32, (self.input_dimensions, self.embedding_size), 'pc') + self.po = tfVariable(tf.float32, (self.embedding_size, self.project_size), 'po') + self.pb = tfVariable(tf.float32, (self.project_size,), 'pb') + + def __call__(self, cnn_out, iterator, is_training=True, tgt_sos_id = None): + source = cnn_out + h = tf.shape(source, name='shape_h')[1] + w = tf.shape(source, name='shape_w')[2] + batch_size = tf.shape(source, name='shape_batch_size')[0] + source_reshape = tf.reshape(source, (batch_size, -1, self.input_dimensions), name='reshape_1') + source_dist_major = tf.transpose(source_reshape, (1, 0 ,2), name='transpose_1') + + # A little hack (to obtain the same shape as the input matrix) to define the initial hidden state h_0 and initial cover state cover_0 + h_0 = tf.matmul(source_dist_major[0], tf.zeros((self.input_dimensions, self.hidden_size), tf.float32), name='matmul_1') + cover_0 = tf.zeros((batch_size, h, w, 1), tf.float32) + attn_dist_0 = tf.zeros((batch_size, h, w, 1), tf.float32) + inds_t_1 = tf.fill((batch_size,), tgt_sos_id) + logits_0 = tf.zeros((batch_size, self.project_size), tf.float32) + + t_0_state = (h_0, cover_0, logits_0, inds_t_1, attn_dist_0) + + padding_mask = tf.sequence_mask(tf.cast(tf.ceil(tf.divide(iterator.source_sequence_length, cfg.cnn_strike)), tf.int32), w, name='sequence_mask') + padding_mask_tile = tf.tile(padding_mask, [1, h], name='tile_1') + padding_mask_tile = tf.cast(padding_mask_tile, tf.float32) + + def masked_attention(e): + attn_dist = tf.nn.softmax(e, name='softmax_1') + attn_dist *= padding_mask_tile # apply mask + masked_sums = tf.reduce_sum(attn_dist, axis=1, name='reduce_sum_1') + return attn_dist / tf.reshape(masked_sums, [-1, 1], name='reshape_2') # re-normalize + + def forward_pass(t_1_state, t_input): + h_t_1, cover_t_1, logits_t_1, inds_t_1, _ = t_1_state # + # WAP (13) + F = self.convCoverage(cover_t_1) + F_flatten = tf.reshape(F, (batch_size, -1, self.coverage_size), name='reshape_3') + F_flatten_dist_major = tf.transpose(F_flatten, (1, 0, 2), name='transpose_3') + # (14) + + def att_Clcu(a_t_1_state, a_t_input): + slice_source, slice_F = a_t_input + ha = tf.matmul(h_t_1, self.wa, name='matmul_2') + \ + tf.matmul(slice_source, self.ua, name='matmul_3') + tf.matmul(slice_F, self.uc, name='matmul_4') + a_t = tf.matmul(tf.tanh(ha), tf.expand_dims(self.va, -1), name='matmul_5') + return tf.reshape(a_t, [-1]) + + e_is = tf.scan(fn=att_Clcu, elems=[source_dist_major, F_flatten_dist_major], initializer=tf.zeros((batch_size,), tf.float32), name='scan_1') + + e_is = tf.stack(e_is, name='stack_1') + e_is_transpose = tf.transpose(e_is) + + # (9) 
- (10) + # masked attention + attn_dist_t_flatten = masked_attention(e_is_transpose) + attn_dist_t = tf.reshape(attn_dist_t_flatten, (-1, h, w, 1), name='reshape_4') + c_t = tf.reduce_sum(source_reshape * tf.expand_dims(attn_dist_t_flatten, -1), axis=1) + + if is_training: + Ey_t_1 = tf.nn.embedding_lookup(self.embedding, t_input) + else: + Ey_t_1 = tf.nn.embedding_lookup(self.embedding, inds_t_1) + + # GRU WAP (4)-(7) + z_t = tf.nn.sigmoid(tf.matmul(Ey_t_1, self.yz, name='matmul_6') + + tf.matmul(h_t_1, self.uz, name='matmul_7') + tf.matmul(c_t, self.wz, name='matmul_8') + self.bz) + r_t = tf.nn.sigmoid(tf.matmul(Ey_t_1, self.yr, name='matmul_9') + + tf.matmul(h_t_1, self.ur, name='matmul_10') + tf.matmul(c_t, self.wr, name='matmul_11') + self.br) + h_proposal = tf.tanh(tf.matmul(Ey_t_1, self.yh, name='matmul_12') + + tf.matmul(tf.multiply(r_t, h_t_1), self.uh, name='matmul_13') + tf.matmul(c_t, self.wh, name='matmul_14') + self.bh) + h_t = tf.multiply(1 - z_t, h_t_1) + tf.multiply(z_t, h_proposal) + + # WAP (12) + cover_t = cover_t_1 + attn_dist_t + + # (11) + logits_t = tf.matmul(Ey_t_1 + tf.matmul(h_t, self.ph, name='matmul_15') + tf.matmul(c_t, self.pc, name='matmul_16'), self.po, name='matmul_17') + self.pb + softMax_t = tf.nn.softmax(logits_t, -1) + max_inds_t = tf.argmax(softMax_t, -1, output_type=tf.int32) + + t_state = (h_t, cover_t, logits_t, max_inds_t, attn_dist_t) + return t_state + + if is_training: + elems = tf.transpose(iterator.target_input) + else: + assert tgt_sos_id is not None + tgt_maxLen = int(cfg.tgt_max_len * cfg.allow_grow_ratio) + tgt_maxLen = max(tgt_maxLen, cfg.tgt_max_len + 1) + + elems = tf.fill((tgt_maxLen, batch_size), tgt_sos_id) + + self.t_state = tf.scan(fn=forward_pass, elems=elems, initializer=t_0_state, name='scan_2') + logits = tf.transpose(self.t_state[2], (1, 0, 2)) + indes = tf.transpose(self.t_state[3], (1, 0)) + attn_dists = tf.reshape(tf.transpose(self.t_state[4], (1, 0, 2, 3, 4)), (batch_size, -1, h, w)) + return logits, indes, attn_dists + + + + + diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/WAP_Model.py b/WAP_Model.py new file mode 100644 index 0000000..cb4a9c6 --- /dev/null +++ b/WAP_Model.py @@ -0,0 +1,330 @@ +# -*- coding: utf-8 -*- +# Tensorflow implementation WAP_Model for paper " +# Watch, attend and parse: An end-to-end neural network based approach to handwritten mathematical expression recognition" +# at https://www.sciencedirect.com/science/article/pii/S0031320317302376 +# author yuyufeng e-mail: yufeng_yu@sina.com +import tensorflow as tf +from config import cfg +import vocab_utils +import utils +import traceback +import time +from read_tf_records import * +import numpy as np +import shutil + +from GRU_Att_Cov import GRU_Att_Cov +from FeatureCNN import FeatureCNN + + +def get_collections_from_scope(scope_name): + return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope_name) + + +def get_local_time(): + timeStamp = int(time.time()) + timeArray = time.localtime(timeStamp) + return time.strftime("%Y_%m_%d_%H_%M_%S", timeArray) + + +class WAP(): + def __init__(self, is_training=True, checkPoint_path=None): + self.graph = tf.Graph() + self.is_training = is_training + with self.graph.as_default(): + ano_data_set = os.path.join(cfg.data_set, cfg.ano_data_set) + vocab_file = os.path.join(ano_data_set, cfg.tgt_vocab_file) + vocab_size, vocab_file = vocab_utils.check_vocab(vocab_file, out_dir=cfg.out_dir, sos=cfg.sos, eos=cfg.eos, + unk=cfg.unk) + + self.tgt_vocab_table = 
vocab_utils.create_vocab_tables(vocab_file)
            self.reverse_tgt_vocab_table = vocab_utils.index_to_string_table_from_file(
                vocab_file, default_value=cfg.unk)

            self.tgt_sos_id = tf.cast(self.tgt_vocab_table.lookup(tf.constant(cfg.sos)), tf.int32)
            self.tgt_eos_id = tf.cast(self.tgt_vocab_table.lookup(tf.constant(cfg.eos)), tf.int32)

            if is_training:
                # train_src_dataset = tf.contrib.data.TextLineDataset(os.path.join(ano_data_set, cfg.train_src_dataset))
                # train_tgt_dataset = tf.contrib.data.TextLineDataset(os.path.join(ano_data_set, cfg.train_tgt_dataset))
                self.init_iter_train, self.iterator_train = get_iterator(cfg.train_tf_filename, self.tgt_vocab_table,
                                                                         self.tgt_sos_id, self.tgt_eos_id, augment=True)

                # vaild_src_dataset = tf.contrib.data.TextLineDataset(os.path.join(ano_data_set, cfg.vaild_src_dataset))
                # vaild_tgt_dataset = tf.contrib.data.TextLineDataset(os.path.join(ano_data_set, cfg.vaild_tgt_dataset))
                self.init_iter_vaild, self.iterator_vaild = get_iterator(cfg.vaild_tf_filename, self.tgt_vocab_table,
                                                                         self.tgt_sos_id, self.tgt_eos_id)

            else:
                self.source = tf.placeholder(tf.float32, (None, None), name='source')
                batch_source = tf.expand_dims(tf.expand_dims(self.source, axis=0), axis=-1)
                iterator_source = normalize_input_img(batch_source)
                # the image width as a length-1 batch (tf.constant cannot wrap a Tensor)
                self.source_sequence_length = tf.shape(iterator_source)[2:3]
                self.iterator = BatchedInput(source=iterator_source,
                                             target_input=None, target_output=None,
                                             source_sequence_length=self.source_sequence_length,
                                             target_sequence_length=None)

            self.featureCNN = FeatureCNN()
            self.gru_att_cov = GRU_Att_Cov(vocab_size)  # vocab size

            if is_training:
                if cfg.outer_batch_size:
                    outer_loss = 0
                    with tf.variable_scope('outer_batch_size') as scope:
                        for i in range(cfg.outer_batch_size):
                            if i > 0:
                                scope.reuse_variables()
                            self.cnn_out_train = self.featureCNN(self.iterator_train.source, True, False)
                            self.logits_train, _, self.attn_dists_train = self.gru_att_cov(self.cnn_out_train,
                                                                                           self.iterator_train, True,
                                                                                           self.tgt_sos_id)
                            outer_loss += self._loss(self.logits_train, self.iterator_train)

                    self.loss_train = outer_loss / cfg.outer_batch_size
                else:
                    self.cnn_out_train = self.featureCNN(self.iterator_train.source, True, False)
                    self.logits_train, _, self.attn_dists_train = self.gru_att_cov(self.cnn_out_train,
                                                                                   self.iterator_train, True,
                                                                                   self.tgt_sos_id)
                    self.loss_train = self._loss(self.logits_train, self.iterator_train)

                self.global_step = tf.Variable(0, name='global_step', trainable=False)
                self.learning_rate = tf.train.exponential_decay(cfg.startLr, self.global_step, cfg.decay_steps,
                                                                cfg.decay_rate)
                optimizer = tf.train.AdadeltaOptimizer(self.learning_rate)
                self.train_op = optimizer.minimize(self.loss_train, global_step=self.global_step)

                self.cnn_out_vaild = self.featureCNN(self.iterator_vaild.source, True)
                self.logits_vaild, _, _ = self.gru_att_cov(self.cnn_out_vaild, self.iterator_vaild, True, self.tgt_sos_id)
                self.loss_vaild = self._loss(self.logits_vaild, self.iterator_vaild)

                self.cnn_out_vaild_infer = self.featureCNN(self.iterator_vaild.source, False)
                _, self.infer_indes_vaild, self.infer_attn_dists_vaild = self.gru_att_cov(self.cnn_out_vaild_infer,
                                                                                          self.iterator_vaild, False,
                                                                                          self.tgt_sos_id)
                self.infer_lookUpTgt_vaild = self.reverse_tgt_vocab_table.lookup(tf.to_int64(self.infer_indes_vaild))

                self.accuracy_vaild = self._acc(self.infer_indes_vaild, self.iterator_vaild.target_output)
                self.train_lookUpTgt_vaild = 
self.reverse_tgt_vocab_table.lookup( + tf.to_int64(self.iterator_vaild.target_output)) + + self.train_summary, self.vaild_summary = self._summary() + else: + self.cnn_out = self.featureCNN(self.iterator.source, is_training) + _, self.infer_indes, self.infer_attn_dists = self.gru_att_cov(self.cnn_out, self.iterator, False, + self.tgt_sos_id) + self.infer_lookUpTgt = self.reverse_tgt_vocab_table.lookup(tf.to_int64(self.infer_indes)) + + self.init = [tf.global_variables_initializer(), tf.tables_initializer()] + self.saver = tf.train.Saver() + self.sess = tf.Session(config=tf.ConfigProto(log_device_placement=True)) + if not is_training: + self.sess.run(self.init) + self.saver.restore(self.sess, checkPoint_path) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.sess.close() + + def _loss(self, logits, iterator): + """Compute optimization loss.""" + target_output = iterator.target_output + + max_time = tf.shape(target_output)[1] + crossent = tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=target_output, logits=logits) + target_weights = tf.sequence_mask( + iterator.target_sequence_length, max_time, dtype=logits.dtype) + return tf.reduce_mean(tf.div(tf.reduce_sum( + crossent * target_weights, axis=1), tf.cast(iterator.target_sequence_length, tf.float32))) + + def _summary(self): + train_summary = [] + train_summary.append(tf.summary.scalar('train/learning_rate', self.learning_rate)) + train_summary.append(tf.summary.scalar('train/loss', self.loss_train)) + + vaild_summary = [] + vaild_summary.append(tf.summary.scalar('vaild/loss', self.loss_vaild)) + vaild_summary.append(tf.summary.scalar('vaild/acc', self.accuracy_vaild)) + + source_shape = tf.shape(self.iterator_train.source) + source_h = source_shape[1] + source_w = source_shape[2] + + infer_attn_dists_train_shape = tf.shape(self.attn_dists_train) + batch_size = infer_attn_dists_train_shape[0] + attn_times = infer_attn_dists_train_shape[1] + cnn_strike_h = infer_attn_dists_train_shape[2] + cnn_strike_w = infer_attn_dists_train_shape[3] + + infer_attn_dists_train_reshape = tf.reshape(self.attn_dists_train, (-1, cnn_strike_h, cnn_strike_w, 1)) + attn_dist_reSize = tf.image.resize_bicubic(infer_attn_dists_train_reshape, (source_h, source_w)) + attn_dist_reshape = tf.reshape(attn_dist_reSize, (batch_size, -1, source_w, 1)) + attn_dist_LitUp = attn_dist_reshape * 0.9 + 0.1 # 对比度和亮度调整 + attn_dist_LitUp = tf.clip_by_value(attn_dist_LitUp, 0, 1.) + + source_tile = tf.tile(self.iterator_train.source, [1, attn_times, 1, 1]) + source_mask_dist = tf.multiply(attn_dist_LitUp, source_tile) + source_mask_dist = tf.clip_by_value(source_mask_dist, 0, 1.) 
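        # Each per-step attention map is upsampled back to the input resolution,
        # rescaled onto a 0.1 brightness floor so unattended regions stay visible,
        # and multiplied into a tiled copy of the source image, so the image
        # summary below shows where the decoder attended at each output step.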
+ train_summary.append(tf.summary.image('train/source_mask_dist', source_mask_dist, 8)) + + return tf.summary.merge(train_summary), tf.summary.merge(vaild_summary) + + def _acc(self, infer_indes, target_output): + infer_indes_slice = tf.slice(infer_indes, tf.zeros((tf.rank(infer_indes),), tf.int32), + tf.shape(target_output)) + correct_prediction = tf.reduce_all(tf.equal(tf.to_int32(target_output), infer_indes_slice), axis=1) + return tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + + def train(self, train_writer_path, log_f, restore_folder=None, restore_name=None): + if not self.is_training: return + sess = self.sess + saver = self.saver + if train_writer_path: + train_writer = tf.summary.FileWriter(train_writer_path, sess.graph) + time.clock() + sess.run(self.init) + if restore_folder: + restore_path = os.path.join(restore_folder, restore_name) + saver.restore(sess, restore_path) + utils.print_out('load from %s' % restore_path, log_f) + # with self.graph.as_default(): + # gru_att_cov_variables = [] + # gru_att_cov_variables.extend(get_collections_from_scope('tgt_embedding')) + # gru_att_cov_variables.extend(get_collections_from_scope('gru')) + # gru_att_cov_variables.extend(get_collections_from_scope('att')) + # gru_att_cov_variables.extend(get_collections_from_scope('coverage')) + # gru_att_cov_variables.extend(get_collections_from_scope('project')) + # init_gru_att_cov_variables = tf.variables_initializer(var_list=gru_att_cov_variables) + # sess.run(init_gru_att_cov_variables) + # utils.print_out('random init gru_att_cov variables', log_f) + global_step = 0 + epoch = 0 + utils.print_hparams(cfg, f=log_f) + sess.run(self.init_iter_train) # 初始化训练输入 + utils.print_out('init_iter_train', log_f) + sess.run(self.init_iter_vaild) # 初始化验证输入 + utils.print_out('init_iter_vaild', log_f) + learning_rate = 0 + loss = 0 + try: + utils.print_out("start at %s" % get_local_time(), log_f) + while epoch < cfg.epochs: + epoch += 1 + i_steps = 0 + + while i_steps < cfg.each_steps: + try: + _, loss, global_step, learning_rate, summary = \ + sess.run([self.train_op, self.loss_train, self.global_step, + self.learning_rate, self.train_summary]) + i_steps += 1 + except tf.errors.OutOfRangeError: + sess.run(self.init_iter_train) + if global_step % cfg.print_frq == 0: + utils.print_out('epoch %d, step %d, gloSp %d, lr %.4f, loss %.4f' + % (epoch, i_steps, global_step, learning_rate, loss), log_f) + if global_step % cfg.summary_frq == 0: + summary = sess.run(self.train_summary) + if train_writer_path: train_writer.add_summary(summary, global_step=global_step) + + if epoch % cfg.val_frq == 0: + + val_count = 0 + val_loss = 0 + val_acc = 0 + val_edit_dist = 0 + true_sample_words = [''] + pred_sample_words = [''] + i_loss = 0 + i_acc = 0 + while val_count < cfg.val_steps: + try: + i_loss, i_acc, true_sample_words, pred_sample_words, summary = \ + sess.run([self.loss_vaild, self.accuracy_vaild, self.train_lookUpTgt_vaild, + self.infer_lookUpTgt_vaild, self.vaild_summary]) + val_count += 1 + except tf.errors.OutOfRangeError: + sess.run(self.init_iter_vaild) + + val_loss += i_loss + val_acc += i_acc + c_val_edit_dist = [] + for t, p in zip(true_sample_words, pred_sample_words): + edit_dist = utils.normal_leven(t, p) + c_val_edit_dist.append(edit_dist) + c_val_edit_dist = sum(c_val_edit_dist) / float(len(c_val_edit_dist)) + val_edit_dist += c_val_edit_dist + + val_acc /= val_count + val_loss /= val_count + val_edit_dist /= val_count + + timeStamp = int(time.time()) + timeArray = time.localtime(timeStamp) + 
                    styleTime = time.strftime("%Y_%m_%d_%H_%M_%S", timeArray)

                    utils.print_out('%s ### val loss %.4f, acc %.4f, edit_dist %.4f'
                                    % (styleTime, val_loss, val_acc, val_edit_dist), log_f)
                    if train_writer_path: train_writer.add_summary(summary, global_step=global_step)
                    test_show_size = min(cfg.test_show_size, len(true_sample_words))
                    for i in range(test_show_size):
                        str_tr = ''.join(true_sample_words[i])
                        str_pd = ''.join(pred_sample_words[i])
                        utils.print_out(" ## true: %s" % (str_tr), log_f)
                        utils.print_out("    pred: %s" % (str_pd), log_f)
                if epoch % cfg.save_frq == 0 and train_writer_path:
                    checkPoint_path = os.path.join(train_writer_path, "checkPoint.model")
                    saver.save(sess, checkPoint_path, global_step=global_step)
                    utils.print_out(
                        " global step %d, check point save to %s-%d" % (
                            global_step, checkPoint_path, global_step), log_f)

        except Exception as e:
            utils.print_out(
                "!!!! Interrupt ## end training, global step %d" % (
                    global_step), log_f)
            if len(e.args) > 0:
                utils.print_out("An error occurred. {}".format(e.args[-1]), log_f)
            traceback.print_exc()

        finally:
            if train_writer_path:
                checkPoint_path = os.path.join(train_writer_path, "end_checkPoint.model")
                saver.save(sess, checkPoint_path, global_step=global_step)
                utils.print_out(
                    " end training, global step %d, check point save to %s-%d" % (
                        global_step, checkPoint_path, global_step), log_f)
            utils.print_out("end at %s" % get_local_time(), log_f)
            return epoch

    def predict(self, img):
        if self.is_training: return
        assert np.ndim(img) == 2  # np.rank is gone from modern NumPy; expect a 2-D grayscale image
        tgt = self.sess.run([self.infer_lookUpTgt], feed_dict={self.source: img})
        return tgt


if __name__ == '__main__':
    import os

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    timeStamp = get_local_time()
    out_dir = os.path.join(cfg.out_dir, "%s" % timeStamp)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    log_file = os.path.join(out_dir, "log")
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)
    wap = WAP(True)
    start = time.time()
    epoch = wap.train(out_dir, log_f, cfg.load_preTrain_model_folder, cfg.load_preTrain_model_name)
    end = time.time()
    time_consum = (end - start) / 3600
    if cfg.debug and time_consum < 0.5:
        print('log folder %s removed' % out_dir)
        shutil.rmtree(out_dir)
    else:
        shutil.move(out_dir, os.path.join(cfg.out_dir, "%03d_%04d_%s" % (time_consum * 10, epoch, timeStamp)))
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..806b4d5
--- /dev/null
+++ b/config.py
@@ -0,0 +1,59 @@
import tensorflow as tf

cfg = tf.contrib.training.HParams(
    # cfg
    debug=True,
    bn=True,
    debug_dir='./out/fillin_answer_less4words_1_31_img/debug',
    # out_dir='./out/fillin_answer_less4words_1_31_img',
    # data_set='./data_set/fillin_answer_less4words_1_31_img',
    out_dir='./out/HANG_RECG3',
    data_set='./data_set/HANG_RECG3',
    # data_set='E:\\yu_collection\\hang_recg_gene_Anan\\Anan\\fillin_answer_less4words_1_31_img',
    # out_dir='E:\\yu_collection\\hang_recg_gene_Anan\\Anan\\fillin_answer_less4words_1_31_img',
    ano_data_set='ano',
    imgFolderName='jpg',
    tgt_vocab_file='recg.vocb',
    train_src_dataset='train.ind',
    train_tgt_dataset='train.recg',
    train_tf_filename='train.tfrecords',
    vaild_src_dataset='vaild.ind',
    vaild_tgt_dataset='vaild.recg',
    vaild_tf_filename='vaild.tfrecords',
    # load_preTrain_model_folder='./out/fillin_answer_less4words_1_31_img/2018_02_06_22_14_24',
    # load_preTrain_model_name='checkPoint.model-120000',
    load_preTrain_model_folder=None,
    load_preTrain_model_name=None,
    unk="<unk>",
    sos="<s>",
    eos="</s>",
    src_fixed_height=96,
    src_max_len=1920,
    tgt_max_len=50,
    allow_grow_ratio=1,
    cnn_strike=24,
    num_buckets=0,
    default_bucket_width=10,

    rnn_input_dimensions=32,
    rnn_hidden_size=128,
    attention_size=32,
    coverage_size=32,
    tgt_embedding_size=32,

    outer_batch_size=None,
    batch_size=1,
    keep_prob=0.8,
    # train param
    startLr=0.1,
    decay_steps=5000,
    decay_rate=0.96,
    epochs=800,
    each_steps=300,
    print_frq=20,
    summary_frq=50,
    val_frq=1,
    save_frq=50,
    val_steps=40,
    test_show_size=4,
)
diff --git a/create_tfrecords.py b/create_tfrecords.py
new file mode 100644
index 0000000..e73950e
--- /dev/null
+++ b/create_tfrecords.py
@@ -0,0 +1,163 @@
#coding=utf-8
__author__ = 'Administrator'
import os
from config import cfg
import tensorflow as tf
import sys
import numpy as np
from skimage import io


def read_txt(file_name):
    with open(file_name, 'r', encoding='utf-8') as f_i:
        return f_i.readlines()

def int64_feature(data):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[data]))

def bytes_feature(data):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data]))

def create_tfrecords(tf_filename, src_dataset, tgt_dataset):

    img_names = read_txt(os.path.join(cfg.data_set, cfg.ano_data_set, src_dataset))
    labels = read_txt(os.path.join(cfg.data_set, cfg.ano_data_set, tgt_dataset))


    def normalize_input_img(img):
        shape = tf.shape(img)

        def f1(shape, img):
            shape = tf.cast(shape, tf.float32)
            width = tf.cast(tf.multiply(tf.div(shape[1], shape[0]), cfg.src_fixed_height, "imgWidth"), tf.int32)

            return tf.image.resize_images(img, [cfg.src_fixed_height, width],
                                          method=tf.image.ResizeMethod.BICUBIC)

        img = tf.cond(tf.not_equal(shape[0], cfg.src_fixed_height), lambda: f1(shape, img), lambda: img)
        img = tf.cast(img, tf.float32) * (1. 
/ 255) - 0.5 + + return img + + def readSrcImg(img_folder): + src_holder = tf.placeholder(tf.string, ()) + filenames = tf.string_join([img_folder, '/', src_holder, ".jpg"], separator="") + img = tf.cast(tf.image.decode_jpeg(tf.read_file(filenames), channels=1), tf.float32) + norm_img = tf.reshape(normalize_input_img(img), (cfg.src_fixed_height, -1)) + return src_holder, norm_img + + img_folder = os.path.join(cfg.data_set, cfg.imgFolderName) + src_holder, norm_img = readSrcImg(img_folder) + with tf.Session() as sess: + vocb = read_txt(os.path.join(cfg.data_set, cfg.ano_data_set, cfg.tgt_vocab_file)) + vocb = [ivocb.strip() for ivocb in vocb] + with tf.python_io.TFRecordWriter( + os.path.join(cfg.data_set, cfg.ano_data_set, tf_filename)) as tfrecord_writer: + for i, img_name, label in zip(range(len(img_names)), img_names, labels): + try: + image_data = sess.run(norm_img, feed_dict={src_holder: img_name.strip()}) + label = label.strip() + if len(label) == 0: + recg_ind = [] + else: + recg_ind = [vocb.index(word) for word in label.split(' ')] + + label = bytes(label, encoding="utf8") + height, width = image_data.shape + image_data = image_data.tobytes() + example = tf.train.Example( + # 属性名称到取值的字典 + features=tf.train.Features(feature={"image/encoded": bytes_feature(image_data), + 'image/height': int64_feature(height), + 'image/width': int64_feature(width), + "label/value": bytes_feature(label), + "label/ind": tf.train.Feature( + int64_list=tf.train.Int64List(value=recg_ind))})) + + tfrecord_writer.write(example.SerializeToString()) + sys.stdout.write('%d of %d : %s' % (i, len(img_names), img_name)) + sys.stdout.flush() + except Exception as e: + print("Error: ", e) + + + print('\nFinished converting the dataset!') + +def mod_vocb_tfrecords(tf_filename, tf_filename_2, tgt_dataset): + def read_tfrecord(): + dataset = tf.contrib.data.TFRecordDataset(os.path.join(cfg.data_set, cfg.ano_data_set, tf_filename)) + def parser_tfrecord(record): + parsed = tf.parse_single_example(record, + features={ + 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), + 'image/height': tf.FixedLenFeature((), tf.int64, default_value=0), + 'image/width': tf.FixedLenFeature((), tf.int64, default_value=0), + 'label/value': tf.VarLenFeature(tf.string), + 'label/ind': tf.VarLenFeature(tf.int64), + }) + + img = parsed['image/encoded'] # 直接采用bytes编码 + height = parsed['image/height'] + width = parsed['image/width'] + + return img, height, width + dataset = dataset.map(parser_tfrecord, num_threads=4, output_buffer_size=6000) + dataset = dataset.batch(128) + batched_iter = dataset.make_initializable_iterator() + img, height, width = batched_iter.get_next() + return batched_iter.initializer, (img, height, width) + + batched_iter, data = read_tfrecord() + img_op, height_op, width_op = data + sess = tf.Session() + sess.run([tf.global_variables_initializer()]) + sess.run(batched_iter) + coord = tf.train.Coordinator() + threads = tf.train.start_queue_runners(sess=sess, coord=coord) + try: + vocb = read_txt(os.path.join(cfg.data_set, cfg.ano_data_set, cfg.tgt_vocab_file)) + vocb = [ivocb.strip() for ivocb in vocb] + labels = read_txt(os.path.join(cfg.data_set, cfg.ano_data_set, tgt_dataset)) + + with tf.python_io.TFRecordWriter( + os.path.join(cfg.data_set, cfg.ano_data_set, tf_filename_2)) as tfrecord_writer: + count = 0 + while not coord.should_stop(): + img, height, width = sess.run([img_op, height_op, width_op]) + for i_img, i_height, i_width in zip(img, height, width): + save_img 
=sess.run(tf.clip_by_value((tf.reshape(tf.decode_raw(tf.constant(i_img, tf.string), tf.float32), (i_height, i_width)) + 0.5), 0, 1.) * 255.) + io.imsave(os.path.join(cfg.debug_dir, '%d.jpg'%count), save_img) + label = labels[count].strip() + if len(label) == 0: + recg_ind = [] + else: + recg_ind = [vocb.index(word) for word in label.split(' ')] + try: + example = tf.train.Example( + # 属性名称到取值的字典 + features=tf.train.Features(feature={"image/encoded": bytes_feature(i_img), + 'image/height': int64_feature(height), + 'image/width': int64_feature(width), + "label/value": bytes_feature(label), + "label/ind": tf.train.Feature( + int64_list=tf.train.Int64List(value=recg_ind))})) + + tfrecord_writer.write(example.SerializeToString()) + sys.stdout.write('%d of %d' % (count, len(labels))) + sys.stdout.flush() + count += 1 + except Exception as e: + print("Error: ", e) + + except tf.errors.OutOfRangeError: + print('finished') + finally: + coord.request_stop() + coord.join(threads) + sess.close() + +if __name__ == '__main__': + create_tfrecords(cfg.train_tf_filename, cfg.train_src_dataset, cfg.train_tgt_dataset) + create_tfrecords(cfg.vaild_tf_filename, cfg.vaild_src_dataset, cfg.vaild_tgt_dataset) + # mod_vocb_tfrecords(cfg.train_tf_filename, cfg.train_tf_filename + 'mod', cfg.train_tgt_dataset) + # mod_vocb_tfrecords(cfg.vaild_tf_filename, cfg.vaild_tf_filename + 'mod', cfg.vaild_tgt_dataset) \ No newline at end of file diff --git a/iterator_utils.py b/iterator_utils.py new file mode 100644 index 0000000..3a40eb1 --- /dev/null +++ b/iterator_utils.py @@ -0,0 +1,162 @@ +# coding=utf-8 +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""For loading data into NMT models.""" +from __future__ import print_function + +import collections +from config import cfg + +import tensorflow as tf + +__all__ = ["BatchedInput", "get_iterator"] + + +# NOTE(ebrevdo): When we subclass this, instances' __dict__ becomes empty. +class BatchedInput(collections.namedtuple("BatchedInput", + ("source", + "target_input", + "target_output", + "source_sequence_length", + "target_sequence_length" + ))): + pass + +def normalize_input_img(img): + shape = tf.shape(img) + + def f1(shape, img): + shape = tf.cast(shape, tf.float32) + width = tf.cast(tf.multiply(tf.div(shape[1], shape[0]), cfg.src_fixed_height, "imgWidth"), tf.int32) + + return tf.image.resize_images(img, [cfg.src_fixed_height, width], + method=tf.image.ResizeMethod.BICUBIC) + + img = tf.cond(tf.not_equal(shape[0], cfg.src_fixed_height), lambda: f1(shape, img), lambda: img) + img = tf.cast(img, tf.float32) * (1. 
/ 255) - 0.5 + return img + +def readSrcImg(path, src): + filenames = tf.string_join([path, src, ".jpg"], separator="") + img = tf.cast(tf.image.decode_jpeg(tf.read_file(filenames), channels=1), tf.float32) + return normalize_input_img(img) + +def get_iterator(src_dataset, + tgt_dataset, + tgt_vocab_table, + tgt_sos_id, + tgt_eos_id, + num_threads=4, + output_buffer_size=120000): + src_tgt_dataset = tf.contrib.data.Dataset.zip((src_dataset, tgt_dataset)) + src_tgt_dataset = src_tgt_dataset.shuffle(output_buffer_size) + + path = tf.string_join([cfg.data_set, "/", cfg.imgFolderPath, "/"], separator="") + src_tgt_dataset = src_tgt_dataset.map( + lambda src, tgt: (readSrcImg(path, src), tgt), + num_threads=num_threads,output_buffer_size=output_buffer_size) + + # tf.string_split: Split elements of source based on delimiter into a SparseTensor. + src_tgt_dataset = src_tgt_dataset.map( + lambda src, tgt: (src, tf.string_split([tgt]).values), + num_threads=num_threads,output_buffer_size=output_buffer_size) + + # Filter zero length input sequences. + src_tgt_dataset = src_tgt_dataset.filter( + lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0)) + + # Convert the word strings to ids. Word strings that are not in the + # vocab get the lookup table's default_value integer. + src_tgt_dataset = src_tgt_dataset.map( + lambda src, tgt: (src, + tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)), + num_threads=num_threads,output_buffer_size=output_buffer_size) + + if cfg.src_max_len: + src_tgt_dataset = src_tgt_dataset.map( + lambda src, tgt: (src[:,:cfg.src_max_len, :], tgt), + num_threads=num_threads,output_buffer_size=output_buffer_size) + if cfg.tgt_max_len: + src_tgt_dataset = src_tgt_dataset.map( + lambda src, tgt: (src, tgt[:cfg.tgt_max_len]), + num_threads=num_threads,output_buffer_size=output_buffer_size) + # Create a tgt_input prefixed with and a tgt_output suffixed with . + src_tgt_dataset = src_tgt_dataset.map( + lambda src, tgt: (src, + tf.concat(([tgt_sos_id], tgt), 0), + tf.concat((tgt, [tgt_eos_id]), 0)), + num_threads=num_threads,output_buffer_size=output_buffer_size) + + # Add in the word counts. Subtract one from the target to avoid counting + # the target_input tag (resp. target_output tag). + src_tgt_dataset = src_tgt_dataset.map( + lambda src, tgt_in, tgt_out: ( + src, tgt_in, tgt_out, tf.shape(src)[1], tf.size(tgt_in)), + # src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)), + num_threads=num_threads,output_buffer_size=output_buffer_size) + # Bucket by source sequence length (buckets for lengths 0-9, 10-19, ...) + def batching_func(x): + return x.padded_batch( + cfg.batch_size, + # The first three entries are the source and target line rows; + # these have unknown-length vectors. The last two entries are + # the source and target row sizes; these are scalars. + padded_shapes=(tf.TensorShape([cfg.src_fixed_height, None, 1]), # src + tf.TensorShape([None]), # tgt_input + tf.TensorShape([None]), # tgt_output + tf.TensorShape([]), # src_len + tf.TensorShape([])), # tgt_len + # Pad the source and target sequences with eos tokens. + # (Though notice we don't generally need to do this since + # later on we will be masking out calculations past the true sequence. + padding_values=(0., # src + tgt_eos_id, # tgt_input + tgt_eos_id, # tgt_output + 0, + 0)) # tgt_len -- unused + if cfg.num_buckets > 1: + + def key_func(unused_1, unused_2, unused_3, src_len, tgt_len): + # Calculate bucket_width by maximum source sequence length. 
+ # Pairs with length [0, bucket_width) go to bucket 0, length + # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with length + # over ((num_bucket-1) * bucket_width) words all go into the last bucket. + if cfg.src_max_len: + bucket_width = (cfg.src_max_len + cfg.num_buckets - 1) // cfg.num_buckets + else: + bucket_width = 10 + + # Bucket sentence pairs by the length of their source sentence and target + # sentence. + bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width) + return tf.to_int64(tf.minimum(cfg.num_buckets, bucket_id)) + def reduce_func(unused_key, windowed_data): + return batching_func(windowed_data) + batched_dataset = src_tgt_dataset.group_by_window( + key_func=key_func, reduce_func=reduce_func, window_size=cfg.batch_size) + else: + batched_dataset = batching_func(src_tgt_dataset) + + batched_iter = batched_dataset.make_initializable_iterator() + (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len, tgt_seq_len) = ( + batched_iter.get_next()) + batchedInput = BatchedInput( + source=src_ids, + target_input=tgt_input_ids, + target_output=tgt_output_ids, + source_sequence_length=src_seq_len, + target_sequence_length=tgt_seq_len) + return batched_iter.initializer, batchedInput diff --git a/read_tf_records.py b/read_tf_records.py new file mode 100644 index 0000000..78d935e --- /dev/null +++ b/read_tf_records.py @@ -0,0 +1,233 @@ +#coding=utf-8 + +import tensorflow as tf +import os +import collections +import numpy as np +import Augment +from config import cfg +import vocab_utils + + +class BatchedInput(collections.namedtuple("BatchedInput", + ("source", + "target_input", + "target_output", + "source_sequence_length", + "target_sequence_length" + ))): + pass + + +def normalize_input_img(img): + shape = tf.shape(img) + + def f1(shape, img): + shape = tf.cast(shape, tf.float32) + width = tf.cast(tf.multiply(tf.div(shape[1], shape[0]), cfg.src_fixed_height, "imgWidth"), tf.int32) + + return tf.image.resize_images(img, [cfg.src_fixed_height, width], + method=tf.image.ResizeMethod.BICUBIC) + + img = tf.cond(tf.not_equal(shape[0], cfg.src_fixed_height), lambda: f1(shape, img), lambda: img) + img = tf.cast(img, tf.float32) * (1. / 255) + + return img + +def get_iterator(tf_filename, tgt_vocab_table, + tgt_sos_id, + tgt_eos_id, repeat=None, + num_threads=4, + output_buffer_size=120000, augment=False): + dataset = tf.contrib.data.TFRecordDataset(os.path.join(cfg.data_set, cfg.ano_data_set, tf_filename)) + + def parser_tfrecord(record): + parsed = tf.parse_single_example(record, + features={ + 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), + 'image/height': tf.FixedLenFeature((), tf.int64, default_value=0), + 'image/width': tf.FixedLenFeature((), tf.int64, default_value=0), + 'label/value': tf.VarLenFeature(tf.string), + }) + + # img = tf.cast(tf.image.decode_jpeg(parsed['image/encoded'], channels=1), tf.float32) # 保存的时候先按jpeg编码 + img = tf.decode_raw(parsed['image/encoded'], tf.float32) #直接采用bytes编码 + height = tf.cast(parsed['image/height'], tf.int32) + width = tf.cast(parsed['image/width'], tf.int32) + img = tf.reshape(img, (height, width, 1)) + img = tf.clip_by_value((img + 0.5), 0, 1.)* 255. 
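        # create_tfrecords.py stored float pixels shifted into [-0.5, 0.5], so the
        # line above undoes the shift and rescales to [0, 255]; normalize_input_img
        # then resizes to cfg.src_fixed_height and maps the pixels back into [0, 1].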
        img = normalize_input_img(img)
        if augment: img = Augment.augment(img)
        label = tf.sparse_tensor_to_dense(parsed['label/value'], default_value='')

        return img, label

    dataset = dataset.map(parser_tfrecord, num_threads=num_threads, output_buffer_size=output_buffer_size)
    dataset = dataset.map(
        lambda src, tgt: (src, tf.string_split(tgt).values),
        num_threads=num_threads, output_buffer_size=output_buffer_size)

    dataset = dataset.filter(
        lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))
    dataset = dataset.map(
        lambda src, tgt: (src, tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
        num_threads=num_threads, output_buffer_size=output_buffer_size)
    if cfg.src_max_len:
        # dataset = dataset.map(
        #     lambda src, tgt: (src[:, :cfg.src_max_len, :], tgt),
        #     num_threads=num_threads, output_buffer_size=output_buffer_size)
        dataset = dataset.filter(
            lambda src, tgt: tf.shape(src)[1] <= cfg.src_max_len)
    if cfg.tgt_max_len:
        dataset = dataset.filter(
            lambda src, tgt: tf.size(tgt) <= cfg.tgt_max_len)
        # dataset = dataset.map(
        #     lambda src, tgt: (src, tgt[:cfg.tgt_max_len]),
        #     num_threads=num_threads, output_buffer_size=output_buffer_size)

    # Substitute some easily-confused characters
    # def repalce_some_label(labels):
    #     def replace_int_label(a, rep, case, *o_case):
    #         condition = tf.equal(a, case)
    #         for i_case in o_case:
    #             condition = tf.logical_or(condition, tf.equal(a, i_case))
    #         case_true = tf.multiply(tf.ones_like(a, tf.int32), rep)
    #         case_false = a
    #         a_m = tf.where(condition, case_true, case_false)
    #         return a_m
    #
    #     labels = replace_int_label(labels, 47, 21, 61)  # map C, c to the parenthesis
    #     labels = replace_int_label(labels, 6, 29, 72)  # map O, o to 0
    #     labels = replace_int_label(labels, 69, 26)  # map L to l
    #     labels = replace_int_label(labels, 79, 34)  # map V to v
    #     labels = replace_int_label(labels, 80, 35, 54)  # map X and the multiplication sign to x
    #     labels = replace_int_label(labels, 47, 40)  # map the angle symbol to less-than
    #     labels = replace_int_label(labels, 76, 32)  # map S to s
    #     return labels
    #
    # dataset = dataset.map(lambda src, tgt: (src, repalce_some_label(tgt)))

    dataset = dataset.map(
        lambda src, tgt: (src, tf.concat(([tgt_sos_id], tgt), 0), tf.concat((tgt, [tgt_eos_id]), 0)),
        num_threads=num_threads, output_buffer_size=output_buffer_size)
    dataset = dataset.map(
        lambda src, tgt_in, tgt_out: (
            src, tgt_in, tgt_out, tf.shape(src)[1], tf.size(tgt_in)),
        num_threads=num_threads, output_buffer_size=output_buffer_size)

    def batching_func(x):
        return x.padded_batch(
            cfg.batch_size,
            padded_shapes=(tf.TensorShape([cfg.src_fixed_height, None, 1]),  # src
                           tf.TensorShape([None]),  # tgt_input
                           tf.TensorShape([None]),  # tgt_output
                           tf.TensorShape([]),  # src_len
                           tf.TensorShape([])),  # tgt_len
            padding_values=(0.,  # src
                            tgt_eos_id,  # tgt_input
                            tgt_eos_id,  # tgt_output
                            0,
                            0))  # tgt_len -- unused

    if cfg.num_buckets > 1:
        def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
            if cfg.src_max_len:
                bucket_width = (cfg.src_max_len + cfg.num_buckets - 1) // cfg.num_buckets
            else:
                bucket_width = cfg.default_bucket_width
            bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width)
            return tf.to_int64(tf.minimum(cfg.num_buckets, bucket_id))

        def reduce_func(unused_key, windowed_data):
            return batching_func(windowed_data)

        batched_dataset = dataset.group_by_window(
            key_func=key_func, reduce_func=reduce_func, window_size=cfg.batch_size)
    else:
        batched_dataset = batching_func(dataset)
    batched_dataset = batched_dataset.shuffle(buffer_size=6000)
    if repeat:
        batched_dataset = batched_dataset.repeat(repeat)
    batched_iter = batched_dataset.make_initializable_iterator()
    (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len, tgt_seq_len) = (
        batched_iter.get_next())
    batchedInput = BatchedInput(
        source=src_ids,
        target_input=tgt_input_ids,
        target_output=tgt_output_ids,
        source_sequence_length=src_seq_len,
        target_sequence_length=tgt_seq_len)
    return batched_iter.initializer, batchedInput


def dense2sparse(dense):
    indices, values = zip(*[([i, j], val)
                            for i, row in enumerate(dense) for j, val in enumerate(row)])
    max_len = max([len(row) for row in dense])
    shape = [len(dense), max_len]
    sparse = (np.array(indices), np.array(values), np.array(shape))
    return sparse


def dense2sparse_second(sequences):
    indices = []
    values = []
    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)

    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=np.int32)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)
    sparse = (indices, values, shape)
    return sparse


def main(_):
    ano_data_set = os.path.join(cfg.data_set, cfg.ano_data_set)
    vocab_file = os.path.join(ano_data_set, cfg.tgt_vocab_file)

    with tf.Graph().as_default():
        vocab_size, vocab_file = vocab_utils.check_vocab(vocab_file, out_dir=cfg.out_dir, sos=cfg.sos, eos=cfg.eos,
                                                         unk=cfg.unk)

        tgt_vocab_table = vocab_utils.create_vocab_tables(vocab_file)
        reverse_tgt_vocab_table = vocab_utils.index_to_string_table_from_file(
            vocab_file, default_value=cfg.unk)

        tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(cfg.sos)), tf.int32)
        tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(cfg.eos)), tf.int32)
        iter, batch_input = get_iterator(cfg.vaild_tf_filename, tgt_vocab_table, tgt_sos_id, tgt_eos_id)
        lookUpTgt = reverse_tgt_vocab_table.lookup(tf.to_int64(batch_input.target_output))
        sess = tf.Session()
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
        sess.run(iter)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        step = 0
        try:
            while True:
                try:
                    while not coord.should_stop():
                        src, tgt_output, src_seq_len, tgt_seq_len = \
                            sess.run([batch_input.source, lookUpTgt, batch_input.source_sequence_length,
                                      batch_input.target_sequence_length])
                        if np.isnan(np.max(src)) or np.isnan(np.min(src)):
                            print('got a NaN')
                            exit(1)
                        if np.any(np.less(src, 0.)):
                            print('got a negative value')
                            exit(1)
                        print('run one')
                        step += 1
                except tf.errors.OutOfRangeError:
                    print('check finished')
                    exit(1)
                    sess.run(iter)
        except KeyboardInterrupt:
            print('interrupt')
        finally:
            coord.request_stop()
            coord.join(threads)
            sess.close()

if __name__ == '__main__':
    tf.app.run()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..f1198aa
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,65 @@
import sys
import time
def print_time(start_time, f=None):
    """Take a start time, print elapsed duration, and return a new time."""
    s = "time %ds, %s." 
% ((time.time() - start_time), time.ctime()) + print(s) + if f: + f.write(s.encode("utf-8")) + f.write(b"\n") + sys.stdout.flush() + return time.time() + + +def print_out(s, f=None, new_line=True): + """Similar to print but with support to flush and output to a file.""" + if isinstance(s, bytes): + s = s.decode("utf-8") + + if f: + f.write(s.encode("utf-8")) + if new_line: + f.write(b"\n") + + # stdout + # print(s.encode("utf-8"), end="", file=sys.stdout) + print(s.encode("utf-8")) + # if new_line: + # sys.stdout.write("\n") + sys.stdout.flush() + + + +def print_hparams(hparams, skip_patterns=None, f=None): + """Print hparams, can skip keys based on pattern.""" + values = hparams.values() + for key in sorted(values.keys()): + if not skip_patterns or all( + [skip_pattern not in key for skip_pattern in skip_patterns]): + print_out(" %s=%s" % (key, str(values[key])), f) + + +def normal_leven(str1, str2): + len_str1 = len(str1) + 1 + len_str2 = len(str2) + 1 + # create matrix + matrix = [0 for n in range(len_str1 * len_str2)] + # init x axis + for i in range(len_str1): + matrix[i] = i + # init y axis + for j in range(0, len(matrix), len_str1): + if j % len_str1 == 0: + matrix[j] = j // len_str1 + + for i in range(1, len_str1): + for j in range(1, len_str2): + if str1[i - 1] == str2[j - 1]: + cost = 0 + else: + cost = 1 + matrix[j * len_str1 + i] = min(matrix[(j - 1) * len_str1 + i] + 1, + matrix[j * len_str1 + (i - 1)] + 1, + matrix[(j - 1) * len_str1 + (i - 1)] + cost) + + return matrix[-1] diff --git a/vocab_utils.py b/vocab_utils.py new file mode 100644 index 0000000..65a7c48 --- /dev/null +++ b/vocab_utils.py @@ -0,0 +1,80 @@ +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================

"""Utility to handle vocabularies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import codecs
import os
import tensorflow as tf

from tensorflow.python.ops import lookup_ops

import utils


UNK = "<unk>"
SOS = "<s>"
EOS = "</s>"
UNK_ID = 0


def check_vocab(vocab_file, out_dir, sos=None, eos=None, unk=None):
    """Check if vocab_file doesn't exist, create from corpus_file."""
    if tf.gfile.Exists(vocab_file):
        utils.print_out("# Vocab file %s exists" % vocab_file)
        vocab = []
        with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f:
            vocab_size = 0
            for word in f:
                vocab_size += 1
                vocab.append(word.strip())

        # Verify if the vocab starts with unk, sos, eos
        # If not, prepend those tokens & generate a new vocab file
        if not unk: unk = UNK
        if not sos: sos = SOS
        if not eos: eos = EOS
        assert len(vocab) >= 3
        if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos:
            utils.print_out("The first 3 vocab words [%s, %s, %s]"
                            " are not [%s, %s, %s]" %
                            (vocab[0], vocab[1], vocab[2], unk, sos, eos))
            vocab = [unk, sos, eos] + vocab
            vocab_size += 3
            new_vocab_file = os.path.join(out_dir, os.path.basename(vocab_file))
            with codecs.getwriter("utf-8")(tf.gfile.GFile(new_vocab_file, "wb")) as f:
                for word in vocab:
                    f.write("%s\n" % word)
            vocab_file = new_vocab_file
    else:
        raise ValueError("%s vocab_file does not exist." % vocab_file)

    vocab_size = len(vocab)
    return vocab_size, vocab_file


def create_vocab_tables(tgt_vocab_file):
    """Creates vocab tables for tgt_vocab_file."""
    tgt_vocab_table = lookup_ops.index_table_from_file(
        tgt_vocab_file, default_value=UNK_ID)
    return tgt_vocab_table

def index_to_string_table_from_file(vocabulary_file, default_value):
    return lookup_ops.index_to_string_table_from_file(
        vocabulary_file, default_value=default_value)
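Usage note (not part of the patch): a minimal inference sketch against the WAP class above. The checkpoint path and image file below are hypothetical placeholders; predict() expects a 2-D grayscale array with raw 0-255 pixel values, which the inference graph batches and normalizes itself via normalize_input_img.

# inference_demo.py -- hypothetical helper, not included in this patch
import numpy as np
from skimage import io

from WAP_Model import WAP

# Assumed path to a checkpoint written by WAP_Model.py's training loop.
ckpt = './out/HANG_RECG3/2019_04_19_21_46_14/checkPoint.model-12000'
wap = WAP(is_training=False, checkPoint_path=ckpt)

# 2-D grayscale line image; as_gray yields [0, 1] floats, so scale to [0, 255].
img = io.imread('sample_line.jpg', as_gray=True).astype(np.float32) * 255.

# predict() returns [array of shape (1, max_decode_len)] holding the
# reverse-vocab lookup of the greedy decode. The decoder always runs a fixed
# number of steps, so trim the output at the first </s> token.
tokens = wap.predict(img)[0]
print([t.decode('utf-8') for t in tokens[0]])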