Skip to content

Commit

Permalink
Merge pull request tensorflow#1432 from alexgorban/master
Browse files Browse the repository at this point in the history
Open source release of Attention OCR
  • Loading branch information
martinwicke authored May 2, 2017
2 parents 3a3c5b9 + 9beaea4 commit 4cc1fa0
Show file tree
Hide file tree
Showing 24 changed files with 3,193 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ running TensorFlow 0.12 or earlier, please
## Models
- [adversarial_text](adversarial_text): semi-supervised sequence learning with
adversarial training.
- [attention_ocr](attention_ocr): a model for real-world image text extraction.
- [autoencoder](autoencoder): various autoencoders.
- [cognitive_mapping_and_planning](cognitive_mapping_and_planning): implementation of a spatial memory based mapping and planning architecture for visual navigation.
- [compression](compression): compressing and decompressing images using a pre-trained Residual GRU network.
Expand Down
75 changes: 75 additions & 0 deletions attention_ocr/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
## Attention-based Extraction of Structured Information from Street View Imagery

*A TensorFlow model for real-world image text extraction problems.*

This folder contains the code needed to train a new Attention OCR model on the
[FSNS dataset][FSNS] dataset to transcribe street names in France. You can
also use it to train it on your own data.

More details can be found in our paper:

["Attention-based Extraction of Structured Information from Street View
Imagery"](https://arxiv.org/abs/1704.03549)

## Contacts

Authors:
Zbigniew Wojna <[email protected]>,
Alexander Gorban <[email protected]>

Pull requests:
[alexgorban](https://github.com/alexgorban)

## Requirements

1. Installed TensorFlow library ([instructions][TF]).
2. At least 158Gb of free disk space to download FSNS dataset:

```
aria2c -c -j 20 -i ../street/python/fsns_urls.txt
```

3. 16Gb of RAM or more, 32Gb is recommended.
4. The train.py works with in both modes CPU and GPU, using GPU is preferable.
The GPU mode was tested with Titan X and GTX980.

[TF]: https://www.tensorflow.org/install/
[FSNS]: https://github.com/tensorflow/models/tree/master/street

## How to use this code

To run all unit tests:

```
python -m unittest discover -p '*_test.py'
```

To train from scratch:

```
python train.py
```

To train a model using a pre-trained inception weights as initialization:
```
wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
tar xf inception_v3_2016_08_28.tar.gz
python train.py --checkpoint_inception=inception_v3.ckpt
```

To fine tune the Attention OCR model using a checkpoint:

```
wget http://download.tensorflow.org/models/attention_ocr_2017_05_01.tar.gz
tar xf attention_ocr_2017_05_01.tar.gz
python train.py --checkpoint=model.ckpt-232572
```

## Disclaimer

This code is a modified version of the internal model we used for our paper.
Currently it reaches 82.71% full sequence accuracy after 215k steps of training.
The main difference between this version and the version used in the paper - for
the paper we used a distributed training with 50 GPU (K80) workers (asynchronous
updates), the provided checkpoint was created using this code after ~60 hours of
training on a single GPU (Titan X).
9 changes: 9 additions & 0 deletions attention_ocr/python/all_jobs.screenrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# A GPU/screen config to run all jobs for training and evaluation in parallel.
# Execute:
# source /path/to/your/virtualenv/bin/activate
# screen -R TF -c all_jobs.screenrc

screen -t train 0 python train.py --train_log_dir=workdir/train
screen -t eval_train 1 python eval.py --split_name=train --train_log_dir=workdir/train --eval_log_dir=workdir/eval_train
screen -t eval_test 2 python eval.py --split_name=test --train_log_dir=workdir/train --eval_log_dir=workdir/eval_test
screen -t tensorboard 3 tensorboard --logdir=workdir
149 changes: 149 additions & 0 deletions attention_ocr/python/common_flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Define flags are common for both train.py and eval.py scripts."""
import sys

from tensorflow.python.platform import flags
import logging

import datasets
import model

FLAGS = flags.FLAGS

logging.basicConfig(
level=logging.DEBUG,
stream=sys.stderr,
format='%(levelname)s '
'%(asctime)s.%(msecs)06d: '
'%(filename)s: '
'%(lineno)d '
'%(message)s',
datefmt='%Y-%m-%d %H:%M:%S')


def define():
"""Define common flags."""
# yapf: disable
flags.DEFINE_integer('batch_size', 32,
'Batch size.')

flags.DEFINE_integer('crop_width', None,
'Width of the central crop for images.')

flags.DEFINE_integer('crop_height', None,
'Height of the central crop for images.')

flags.DEFINE_string('train_log_dir', '/tmp/attention_ocr/train',
'Directory where to write event logs.')

flags.DEFINE_string('dataset_name', 'fsns',
'Name of the dataset. Supported: fsns')

flags.DEFINE_string('split_name', 'train',
'Dataset split name to run evaluation for: test,train.')

flags.DEFINE_string('dataset_dir', None,
'Dataset root folder.')

flags.DEFINE_string('checkpoint', '',
'Path for checkpoint to restore weights from.')

flags.DEFINE_string('master',
'',
'BNS name of the TensorFlow master to use.')

# Model hyper parameters
flags.DEFINE_float('learning_rate', 0.004,
'learning rate')

flags.DEFINE_string('optimizer', 'momentum',
'the optimizer to use')

flags.DEFINE_string('momentum', 0.9,
'momentum value for the momentum optimizer if used')

flags.DEFINE_bool('use_augment_input', True,
'If True will use image augmentation')

# Method hyper parameters
# conv_tower_fn
flags.DEFINE_string('final_endpoint', 'Mixed_5d',
'Endpoint to cut inception tower')

# sequence_logit_fn
flags.DEFINE_bool('use_attention', True,
'If True will use the attention mechanism')

flags.DEFINE_bool('use_autoregression', True,
'If True will use autoregression (a feedback link)')

flags.DEFINE_integer('num_lstm_units', 256,
'number of LSTM units for sequence LSTM')

flags.DEFINE_float('weight_decay', 0.00004,
'weight decay for char prediction FC layers')

flags.DEFINE_float('lstm_state_clip_value', 10.0,
'cell state is clipped by this value prior to the cell'
' output activation')

# 'sequence_loss_fn'
flags.DEFINE_float('label_smoothing', 0.1,
'weight for label smoothing')

flags.DEFINE_bool('ignore_nulls', True,
'ignore null characters for computing the loss')

flags.DEFINE_bool('average_across_timesteps', False,
'divide the returned cost by the total label weight')
# yapf: enable


def get_crop_size():
if FLAGS.crop_width and FLAGS.crop_height:
return (FLAGS.crop_width, FLAGS.crop_height)
else:
return None


def create_dataset(split_name):
ds_module = getattr(datasets, FLAGS.dataset_name)
return ds_module.get_split(split_name, dataset_dir=FLAGS.dataset_dir)


def create_mparams():
return {
'conv_tower_fn':
model.ConvTowerParams(final_endpoint=FLAGS.final_endpoint),
'sequence_logit_fn':
model.SequenceLogitsParams(
use_attention=FLAGS.use_attention,
use_autoregression=FLAGS.use_autoregression,
num_lstm_units=FLAGS.num_lstm_units,
weight_decay=FLAGS.weight_decay,
lstm_state_clip_value=FLAGS.lstm_state_clip_value),
'sequence_loss_fn':
model.SequenceLossParams(
label_smoothing=FLAGS.label_smoothing,
ignore_nulls=FLAGS.ignore_nulls,
average_across_timesteps=FLAGS.average_across_timesteps)
}


def create_model(*args, **kwargs):
ocr_model = model.Model(mparams=create_mparams(), *args, **kwargs)
return ocr_model
Loading

0 comments on commit 4cc1fa0

Please sign in to comment.