Commit: Merge remote-tracking branch 'origin/master'
Showing 9 changed files with 973 additions and 0 deletions.
@@ -0,0 +1,69 @@
*.npy
runs/

# Created by https://www.gitignore.io/api/python,ipythonnotebook

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/


### IPythonNotebook ###
# Temporary data
.ipynb_checkpoints/
@@ -0,0 +1,30 @@
# Source code with the blog post at http://monik.in/a-noobs-guide-to-implementing-rnn-lstm-using-tensorflow/
import numpy as np
import random
from random import shuffle
import os
import pdb

def load_data():
    d_prefix = '/home/junjuew/deep_learning/data'
    fns = [
        'char_rnn_data_mat_10k.pkl',
        'char_rnn_label_mat_10k.pkl',
        'char_rnn_data_mat_1k_val.pkl',
        'char_rnn_label_mat_1k_val.pkl'
    ]
    data = []
    for idx, fn in enumerate(fns):
        data.append(np.load(os.path.join(d_prefix, fn)))

    trd = (data[0], data[1])
    ted = (data[2], data[3])
    # Report the split sizes before returning them to the caller.
    print('training #: {} testing #: {}'.format(len(trd[0]), len(ted[0])))
    return trd, ted

def get_batch(data, s_idx, e_idx):
    return (data[0][s_idx:e_idx, :, :], data[1][s_idx:e_idx, :])

def get_batch_num(data, batch_size):
    return int(data[0].shape[0] / float(batch_size))
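For orientation, here is a minimal usage sketch of the three helpers above in a plain batching loop. The batch size, epoch count, and the commented-out training step are placeholders for illustration and are not part of this commit:

```python
# Hypothetical usage of load_data / get_batch_num / get_batch from above.
# batch_size, the epoch count, and train_step are placeholders.
batch_size = 64

trd, ted = load_data()
n_batches = get_batch_num(trd, batch_size)

for epoch in range(10):
    for b in range(n_batches):
        batch_x, batch_y = get_batch(trd, b * batch_size, (b + 1) * batch_size)
        # train_step(batch_x, batch_y)  # placeholder for an actual parameter update
```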
@@ -0,0 +1,77 @@
# Overview
This is a modified version of the code for building CNN networks for text classification tasks.
To run it, modify `load_custom_data` in `data_helpers.py` and its invocation in `train.py` to point to the correct input and output files, then use

    ./train.py

# Original README

**[This code belongs to the "Implementing a CNN for Text Classification in Tensorflow" blog post.](http://www.wildml.com/2015/12/implementing-a-cnn-for-text-classification-in-tensorflow/)**

It is a slightly simplified implementation of Kim's [Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1408.5882) paper in TensorFlow.

## Requirements

- Python 3
- TensorFlow > 0.8
- Numpy

## Training

Print parameters:

```bash
./train.py --help
```

```
optional arguments:
  -h, --help            show this help message and exit
  --embedding_dim EMBEDDING_DIM
                        Dimensionality of character embedding (default: 128)
  --filter_sizes FILTER_SIZES
                        Comma-separated filter sizes (default: '3,4,5')
  --num_filters NUM_FILTERS
                        Number of filters per filter size (default: 128)
  --l2_reg_lambda L2_REG_LAMBDA
                        L2 regularizaion lambda (default: 0.0)
  --dropout_keep_prob DROPOUT_KEEP_PROB
                        Dropout keep probability (default: 0.5)
  --batch_size BATCH_SIZE
                        Batch Size (default: 64)
  --num_epochs NUM_EPOCHS
                        Number of training epochs (default: 100)
  --evaluate_every EVALUATE_EVERY
                        Evaluate model on dev set after this many steps
                        (default: 100)
  --checkpoint_every CHECKPOINT_EVERY
                        Save model after this many steps (default: 100)
  --allow_soft_placement ALLOW_SOFT_PLACEMENT
                        Allow device soft device placement
  --noallow_soft_placement
  --log_device_placement LOG_DEVICE_PLACEMENT
                        Log placement of ops on devices
  --nolog_device_placement
```

Train:

```bash
./train.py
```

## Evaluating

```bash
./eval.py --eval_train --checkpoint_dir="./runs/1459637919/checkpoints/"
```

Replace the checkpoint dir with the output from the training. To use your own data, change the `eval.py` script to load your data.


## References

- [Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1408.5882)
- [A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1510.03820)
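As a rough illustration of the Overview instructions above, a `train.py` call site might be adapted along the following lines. The file-name patterns are placeholders, not files shipped with this commit; see `load_custom_data` in `data_helpers.py` below for the expected argument format:

```python
# Hypothetical call site in train.py: pass two format strings, one for the
# query (input) file and one for the label file. load_custom_data fills the
# '{}' slot with 'query' and 'label' respectively. File names are placeholders.
import data_helpers

x_text, y = data_helpers.load_custom_data(
    ['my_{}_10k.npy', 'my_{}_10k.npy'])
```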
@@ -0,0 +1,106 @@
import numpy as np
import re
import itertools
from collections import Counter
import pdb
import os

def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()

def load_custom_data(pt):
    d_prefix = '/home/junjuew/deep_learning/cnn/cnn-text-classification-tf/data/updated'
    fns = [
        'query',
        'label'
    ]
    data = []
    for idx, fn in enumerate(fns):
        f_p = os.path.join(d_prefix, pt[idx].format(fn))
        print('loading file at {}'.format(f_p))
        raw_data = np.load(f_p)
        data.append(raw_data)
    # Shuffle examples and labels with the same permutation.
    shuffle_indices = np.random.permutation(np.arange(len(data[0])))
    data[0][:] = data[0][shuffle_indices]
    data[1][:] = data[1][shuffle_indices]
    return data[0], data[1]


def load_data_and_labels(positive_data_file, negative_data_file):
    """
    Loads MR polarity data from files, splits the data into words and generates labels.
    Returns split sentences and labels.
    """
    # Load data from files
    positive_examples = list(open(positive_data_file, "r").readlines())
    positive_examples = [s.strip() for s in positive_examples]
    negative_examples = list(open(negative_data_file, "r").readlines())
    negative_examples = [s.strip() for s in negative_examples]
    # Split by words
    x_text = positive_examples + negative_examples
    x_text = [clean_str(sent) for sent in x_text]
    # Generate labels
    positive_labels = [[0, 1] for _ in positive_examples]
    negative_labels = [[1, 0] for _ in negative_examples]
    y = np.concatenate([positive_labels, negative_labels], 0)
    return [x_text, y]


def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """
    Generates a batch iterator for a dataset.
    """
    data = np.array(data)
    data_size = len(data)
    num_batches_per_epoch = int(len(data)/batch_size) + 1
    print('data_size: {}, num_batches_per_epoch: {}, epochs: {}'.format(data.shape, num_batches_per_epoch, num_epochs))
    for epoch in range(num_epochs):
        # Shuffle the data at each epoch
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)
            yield shuffled_data[start_index:end_index]

def gen_batch(data, batch_size, shuffle=True):
    """
    Generates a batch iterator for a single pass over a dataset.
    """
    data = np.array(data)
    data_size = len(data)
    num_batches_per_epoch = int(len(data)/batch_size) + 1
    # print('data_size: {}, num_batches_per_epoch: {}, '.format(data.shape, num_batches_per_epoch))
    # Shuffle the data
    if shuffle:
        shuffle_indices = np.random.permutation(np.arange(data_size))
        shuffled_data = data[shuffle_indices]
    else:
        shuffled_data = data
    for batch_num in range(num_batches_per_epoch):
        start_index = batch_num * batch_size
        end_index = min((batch_num + 1) * batch_size, data_size)
        if start_index == end_index:
            continue
        yield shuffled_data[start_index:end_index]
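A small, hypothetical sketch of how these helpers compose: `load_data_and_labels` produces cleaned sentences and one-hot labels, and `batch_iter` streams shuffled batches over several epochs. The `.pos`/`.neg` paths and the zip-based pairing are illustrative only; the actual training loop lives in `train.py`:

```python
# Hypothetical composition of the helpers above. The .pos/.neg paths are
# placeholders; pairing x and y with zip() mirrors how a training loop
# would consume batch_iter, but the model update itself is omitted here.
x_text, y = load_data_and_labels('rt-polarity.pos', 'rt-polarity.neg')

batches = batch_iter(list(zip(x_text, y)), batch_size=64, num_epochs=2)
for batch in batches:
    if len(batch) == 0:        # batch_iter can yield an empty final slice
        continue
    x_batch, y_batch = zip(*batch)
    # ... feed x_batch / y_batch to a model here ...
```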
@@ -0,0 +1,95 @@
#! /usr/bin/env python

import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helpers
from text_cnn import TextCNN
from tensorflow.contrib import learn
import csv

# Parameters
# ==================================================

# Data Parameters
tf.flags.DEFINE_string("positive_data_file", "./data/rt-polaritydata/rt-polarity.pos", "Data source for the positive data.")
tf.flags.DEFINE_string("negative_data_file", "./data/rt-polaritydata/rt-polarity.neg", "Data source for the negative data.")

# Eval Parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")
tf.flags.DEFINE_boolean("eval_train", False, "Evaluate on all training data")

# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")


FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")

# CHANGE THIS: Load data. Load your own data here
if FLAGS.eval_train:
    x_raw, y_test = data_helpers.load_data_and_labels(FLAGS.positive_data_file, FLAGS.negative_data_file)
    y_test = np.argmax(y_test, axis=1)
else:
    x_raw = ["a masterpiece four years in the making", "everything is off."]
    y_test = [1, 0]

# Map data into vocabulary
vocab_path = os.path.join(FLAGS.checkpoint_dir, "..", "vocab")
vocab_processor = learn.preprocessing.VocabularyProcessor.restore(vocab_path)
x_test = np.array(list(vocab_processor.transform(x_raw)))

print("\nEvaluating...\n")

# Evaluation
# ==================================================
checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
graph = tf.Graph()
with graph.as_default():
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # Load the saved meta graph and restore variables
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)

        # Get the placeholders from the graph by name
        input_x = graph.get_operation_by_name("input_x").outputs[0]
        # input_y = graph.get_operation_by_name("input_y").outputs[0]
        dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]

        # Tensors we want to evaluate
        predictions = graph.get_operation_by_name("output/predictions").outputs[0]

        # Generate batches for one epoch
        batches = data_helpers.batch_iter(list(x_test), FLAGS.batch_size, 1, shuffle=False)

        # Collect the predictions here
        all_predictions = []

        for x_test_batch in batches:
            batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0})
            all_predictions = np.concatenate([all_predictions, batch_predictions])

# Print accuracy if y_test is defined
if y_test is not None:
    correct_predictions = float(sum(all_predictions == y_test))
    print("Total number of test examples: {}".format(len(y_test)))
    print("Accuracy: {:g}".format(correct_predictions/float(len(y_test))))

# Save the evaluation to a csv
predictions_human_readable = np.column_stack((np.array(x_raw), all_predictions))
out_path = os.path.join(FLAGS.checkpoint_dir, "..", "prediction.csv")
print("Saving evaluation to {0}".format(out_path))
with open(out_path, 'w') as f:
    csv.writer(f).writerows(predictions_human_readable)
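For completeness, a hypothetical snippet for reading back the `prediction.csv` written by the evaluation script above. The run directory is a placeholder taken from the README example; each row holds the raw sentence and the predicted class index:

```python
# Hypothetical post-processing: read the prediction.csv written by eval.py.
# The run path is a placeholder; substitute your own training run directory.
import csv

with open('./runs/1459637919/prediction.csv') as f:
    for sentence, predicted_class in csv.reader(f):
        print('{!r} -> class {}'.format(sentence, predicted_class))
```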