Commit 6d5020f: add embedding

Lilian Weng committed Sep 14, 2017
1 parent f7820a4 commit 6d5020f
Showing 2 changed files with 187 additions and 56 deletions.
59 changes: 49 additions & 10 deletions main.py
@@ -1,4 +1,5 @@
import os
+import pandas as pd
import pprint

import tensorflow as tf
@@ -8,6 +9,7 @@
from model import LstmRNN

flags = tf.app.flags
flags.DEFINE_integer("stock_count", 50, "Stock count [50]")
flags.DEFINE_integer("input_size", 1, "Input size [1]")
flags.DEFINE_integer("num_steps", 30, "Num of steps [30]")
flags.DEFINE_integer("num_layers", 1, "Num of layer [1]")
@@ -18,52 +20,89 @@
flags.DEFINE_float("learning_rate_decay", 0.99, "Decay rate of learning rate. [0.99]")
flags.DEFINE_integer("init_epoch", 5, "Num. of epoches considered as early stage. [5]")
flags.DEFINE_integer("max_epoch", 500, "Total training epoches. [500]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_integer("embed_size", None, "If provided, use embedding vector of this size. [None]")
flags.DEFINE_string("checkpoint_dir", "checkpoints", "Directory name to save the checkpoints [checkpoints]")
flags.DEFINE_string("plot_dir", "images", "Directory name to save plots [images]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
FLAGS = flags.FLAGS

pp = pprint.PrettyPrinter()

if not os.path.exists("logs"):
os.mkdir("logs")


def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)


+def load_sp500(input_size, num_steps, k=None, target_symbol=None):
+    # Load the metadata of the S&P 500 stocks.
+    info = pd.read_csv("data/constituents-financials.csv")
+    info = info.rename(columns={col: col.lower().replace(' ', '_') for col in info.columns})
+    info['file_exists'] = info['symbol'].map(lambda x: os.path.exists("data/{}.csv".format(x)))
+    print info['file_exists'].value_counts().to_dict()

+    info = info[info['file_exists'] == True].reset_index(drop=True)

+    info = info.sort_values('market_cap', ascending=False)

+    if k is not None:
+        info = info.head(k)

+    if target_symbol is not None:
+        assert target_symbol in info['symbol'].values
+        info = info[info['symbol'] == target_symbol]

+    # Generate the metadata file for the TensorBoard embedding projector.
+    info[['symbol', 'sector']].to_csv(os.path.join("logs", "metadata.tsv"), sep='\t', index=False)

+    return [
+        StockDataSet(row['symbol'],
+                     input_size=input_size,
+                     num_steps=num_steps,
+                     test_ratio=0.1)
+        for _, row in info.iterrows()]


def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)

+    if not os.path.exists(FLAGS.plot_dir):
+        os.makedirs(FLAGS.plot_dir)

    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        rnn_model = LstmRNN(
            sess,
+            FLAGS.stock_count,
            lstm_size=FLAGS.lstm_size,
            num_layers=FLAGS.num_layers,
            num_steps=FLAGS.num_steps,
            input_size=FLAGS.input_size,
            keep_prob=FLAGS.keep_prob,
-            checkpoint_dir=FLAGS.checkpoint_dir
+            embed_size=FLAGS.embed_size,
+            checkpoint_dir=FLAGS.checkpoint_dir,
        )

        show_all_variables()

-        stock_data = StockDataSet(
-            "GOOG",
-            input_size=FLAGS.input_size,
-            num_steps=FLAGS.num_steps,
-            test_ratio=0.1,
-            close_price_only=True
+        stock_data_list = load_sp500(
+            FLAGS.input_size,
+            FLAGS.num_steps,
+            k=FLAGS.stock_count,
+            target_symbol=FLAGS.stock_symbol,
        )
-        print stock_data.info()

        if FLAGS.train:
-            rnn_model.train(stock_data, FLAGS)
+            rnn_model.train(stock_data_list, FLAGS)
        else:
            if not rnn_model.load()[0]:
                raise Exception("[!] Train a model first, then run test mode")
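Taken together, the change turns main.py from a single-stock (GOOG) demo into a driver that trains one shared model across many stocks. A typical training invocation (flag names as defined in the diff above; the values are illustrative, not from the commit) would be:

    python main.py --stock_count=100 --embed_size=8 --max_epoch=50 --train=True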
184 changes: 138 additions & 46 deletions model.py
@@ -1,38 +1,56 @@
import os
import random
import re
import time
import numpy as np
import tensorflow as tf

import matplotlib
import matplotlib.pyplot as plt

+from tensorflow.contrib.tensorboard.plugins import projector

matplotlib.rcParams.update({'font.size': 18})


class LstmRNN(object):
-    def __init__(self, sess,
+    def __init__(self, sess, stock_count,
                 lstm_size=128,
                 num_layers=1,
                 num_steps=30,
                 input_size=1,
                 keep_prob=0.8,
-                 checkpoint_dir="checkpoints"):
+                 embed_size=None,
+                 checkpoint_dir="checkpoints",
+                 plot_dir="images"):
"""
Construct a RNN model using LSTM cell.
Args:
sess:
stock_count:
lstm_size:
num_layers
num_steps:
input_size:
keep_prob:
embed_size
checkpoint_dir
"""
        self.sess = sess
+        self.stock_count = stock_count

        self.lstm_size = lstm_size
        self.num_layers = num_layers
        self.num_steps = num_steps
        self.input_size = input_size
        self.keep_prob = keep_prob

+        self.use_embed = (embed_size is not None) and (embed_size > 0)
+        self.embed_size = embed_size or 0

        self.checkpoint_dir = checkpoint_dir
+        self.plot_dir = plot_dir

        self.build_graph()

@@ -44,9 +62,11 @@ def build_graph(self):
        - learning_rate:
        """
        # inputs.shape = (number of examples, number of steps, dimension of each input).
+        self.learning_rate = tf.placeholder(tf.float32, None, name="learning_rate")
+        self.symbols = tf.placeholder(tf.int32, [None,], name='stock_labels')  # each stock symbol is mapped to an integer label

        self.inputs = tf.placeholder(tf.float32, [None, self.num_steps, self.input_size], name="inputs")
        self.targets = tf.placeholder(tf.float32, [None, self.input_size], name="targets")
-        self.learning_rate = tf.placeholder(tf.float32, None, name="learning_rate")

        def _create_one_cell():
            lstm_cell = tf.contrib.rnn.LSTMCell(self.lstm_size, state_is_tuple=True)
@@ -65,18 +85,28 @@ def _create_one_cell():
        # After transpose, val.get_shape() = (num_steps, batch_size, lstm_size)
        val = tf.transpose(val, [1, 0, 2])

with tf.name_scope("output_layer"):
# last.get_shape() = (batch_size, lstm_size)
last = tf.gather(val, int(val.get_shape()[0]) - 1, name="lstm_output")
# last.get_shape() = (batch_size, lstm_size)
last = tf.gather(val, int(val.get_shape()[0]) - 1, name="lstm_output")

ws = tf.Variable(tf.truncated_normal([self.lstm_size, self.input_size]), name="w")
bias = tf.Variable(tf.constant(0.1, shape=[self.input_size]), name="b")
self.pred = tf.matmul(last, ws) + bias
if self.embed_size > 0:
self.embed_matrix = tf.Variable(
tf.random_uniform([self.stock_count, self.embed_size], -1.0, 1.0),
name="embed_matrix"
)
sym_embeds = tf.nn.embedding_lookup(self.embed_matrix, self.symbols)

# After concat, last.get_shape() = (batch_size, lstm_size + embed_size)
last = tf.concat([last, sym_embeds], axis=1, name="lstm_output_with_embed")

self.last_sum = tf.summary.histogram("lstm_output", last)
self.w_sum = tf.summary.histogram("w", ws)
self.b_sum = tf.summary.histogram("b", bias)
self.pred_summ = tf.summary.histogram("pred", self.pred)
ws = tf.Variable(tf.truncated_normal([
self.lstm_size + self.embed_size, self.input_size]), name="w")
bias = tf.Variable(tf.constant(0.1, shape=[self.input_size]), name="b")
self.pred = tf.matmul(self.last, ws) + bias

self.last_sum = tf.summary.histogram("lstm_output", last)
self.w_sum = tf.summary.histogram("w", ws)
self.b_sum = tf.summary.histogram("b", bias)
self.pred_summ = tf.summary.histogram("pred", self.pred)

        # self.loss = -tf.reduce_sum(targets * tf.log(tf.clip_by_value(prediction, 1e-10, 1.0)))
        self.loss = tf.reduce_mean(tf.square(self.pred - self.targets), name="loss_mse")
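To see what the embedding branch in build_graph does to the input of the output layer, here is a standalone shape check (an editor's sketch, not part of the commit; the sizes are illustrative):

    import tensorflow as tf

    batch_size, lstm_size, embed_size, stock_count = 4, 128, 8, 50

    last = tf.zeros([batch_size, lstm_size])  # stand-in for the final LSTM output
    symbols = tf.constant([0, 3, 7, 3])       # one integer stock label per example
    embed_matrix = tf.Variable(tf.random_uniform([stock_count, embed_size], -1.0, 1.0))
    sym_embeds = tf.nn.embedding_lookup(embed_matrix, symbols)  # shape (4, 8)
    merged = tf.concat([last, sym_embeds], axis=1)              # shape (4, 136)
    print merged.get_shape()  # (4, 136) = (batch_size, lstm_size + embed_size)

The regression weights ws are sized lstm_size + embed_size accordingly, so every stock shares the LSTM and output layer while keeping its own learned offset in embedding space.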
@@ -87,62 +117,106 @@ def _create_one_cell():
        self.t_vars = tf.trainable_variables()
        self.saver = tf.train.Saver()

-    def train(self, dataset, config):
+    def train(self, dataset_list, config):
        """
        Args:
-            dataset (StockDataSet)
+            dataset_list (list of StockDataSet)
            config (tf.app.flags.FLAGS)
        """

+        assert len(dataset_list) > 0
        self.merged_sum = tf.summary.merge_all()

        # Set up the logs folder.
        self.writer = tf.summary.FileWriter(os.path.join("./logs", self.model_name))
        self.writer.add_graph(self.sess.graph)

-        num_batches = int(len(dataset.train_X)) // config.batch_size
+        if self.use_embed:
+            # Set up embedding visualization.
+            # Format: tensorflow/tensorboard/plugins/projector/projector_config.proto
+            projector_config = projector.ProjectorConfig()

+            # You can add multiple embeddings. Here we add only one.
+            added_embed = projector_config.embeddings.add()
+            added_embed.tensor_name = self.embed_matrix.name
+            # Link this tensor to its metadata file (e.g. labels).
+            added_embed.metadata_path = os.path.join("logs", "metadata.tsv")

+            # The next line writes a projector_config.pbtxt in the LOG_DIR.
+            # TensorBoard will read this file during startup.
+            projector.visualize_embeddings(self.writer, projector_config)
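            # Note (editor's addition, not in the commit): logs/metadata.tsv is the
            # file written by load_sp500() in main.py; its (symbol, sector) rows must
            # line up one-to-one with the rows of embed_matrix for the projector to
            # label the embedding points correctly.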

+        tf.global_variables_initializer().run()

+        # Merged test data of different stocks.
+        merged_test_X = []
+        merged_test_y = []
+        merged_test_labels = []

+        for label_, d_ in enumerate(dataset_list):
+            merged_test_X += list(d_.test_X)
+            merged_test_y += list(d_.test_y)
+            merged_test_labels += [label_] * len(d_.test_X)

+        test_data_feed = {
+            self.learning_rate: 0.0,
+            self.inputs: merged_test_X,
+            self.targets: merged_test_y,
+            self.symbols: merged_test_labels,
+        }

+        global_step = 1

+        num_batches = sum(len(d_.train_X) // config.batch_size for d_ in dataset_list)
+        random.seed(time.time())

+        # Select samples for plotting.
+        sample_labels = range(min(4, len(dataset_list)))
+        sample_indices = {}
+        for l in sample_labels:
+            sym = dataset_list[l].stock_sym
+            target_indices = np.array([
+                i for i, sym_label in enumerate(merged_test_labels)
+                if sym_label == l])
+            sample_indices[sym] = target_indices
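        # Illustration (editor's addition, not in the commit): with two stocks and
        # three test points each, merged_test_labels is [0, 0, 0, 1, 1, 1], so the
        # indices collected for label 1 are np.array([3, 4, 5]).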

        for epoch in xrange(config.max_epoch):
            epoch_step = 1

            learning_rate = config.init_learning_rate * (
                config.learning_rate_decay ** max(float(epoch + 1 - config.init_epoch), 0.0)
            )
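            # Illustration (editor's addition, not in the commit): with
            # learning_rate_decay=0.99 and init_epoch=5 (the defaults above),
            # epochs 0-4 use init_learning_rate unchanged; epoch 5 multiplies it
            # by 0.99, epoch 6 by 0.99**2, and so on, so the rate only decays
            # after the early stage.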

-            tf.global_variables_initializer().run()

-            test_data_feed = {
-                self.inputs: dataset.test_X,
-                self.targets: dataset.test_y,
-                self.learning_rate: 0.0
-            }

-            for batch_X, batch_y in dataset.generate_one_epoch(config.batch_size):
-                train_data_feed = {
-                    self.inputs: batch_X,
-                    self.targets: batch_y,
-                    self.learning_rate: learning_rate,
-                }
-                train_loss, _ = self.sess.run([self.loss, self.optim], train_data_feed)
-                global_step += 1
-                epoch_step += 1
+            for label_, d_ in enumerate(dataset_list):
+                for batch_X, batch_y in d_.generate_one_epoch(config.batch_size):
+                    batch_labels = np.array([label_] * len(batch_X))
+                    train_data_feed = {
+                        self.learning_rate: learning_rate,
+                        self.inputs: batch_X,
+                        self.targets: batch_y,
+                        self.symbols: batch_labels,
+                    }
+                    train_loss, _, train_merged_sum = self.sess.run(
+                        [self.loss, self.optim, self.merged_sum], train_data_feed)
+                    self.writer.add_summary(train_merged_sum, global_step=global_step)
+                    global_step += 1
+                    epoch_step += 1

            if np.mod(epoch, 20) == 0:
-                test_loss, _pred, _merged_sum = self.sess.run(
-                    [self.loss, self.pred, self.merged_sum], test_data_feed)
-                assert len(_pred) == len(dataset.test_y)
-                print "Epoch %d [%d/%d][learning rate: %f]: %.6f" % (
+                test_loss, test_pred = self.sess.run([self.loss, self.pred], test_data_feed)
+                assert len(test_pred) == len(merged_test_y)

+                print "Epoch %d [%d/%d] [learning rate: %f]: %.6f" % (
                    epoch, epoch_step, num_batches, learning_rate, test_loss)
-                self.writer.add_summary(_merged_sum, global_step=global_step)

                if np.mod(global_step, 500) == 2:
                    self.save(global_step)

+                # Plot samples
+                for sample_sym, indices in sample_indices.iteritems():
+                    image_path = os.path.join(self.plot_dir, "{}_test_{:02d}_{:04d}.png".format(
+                        sample_sym, epoch, epoch_step))
+                    sample_preds = test_pred[indices]
+                    sample_truth = np.array(merged_test_y)[indices]
+                    self.plot_samples(sample_preds, sample_truth, image_path, stock_sym=sample_sym)

print "Final Results:"
final_pred, final_loss = self.sess.run([self.pred, self.loss], test_data_feed)
print final_pred, final_loss

return final_pred

    @property
@@ -175,5 +249,23 @@ def load(self):
print(" [*] Failed to find a checkpoint")
return False, 0

-    def plot_samples(self):
-        pass
+    def plot_samples(self, preds, targets, figname, stock_sym=None):
+        def _flatten(seq):
+            return [x for y in seq for x in y]

+        truths = _flatten(targets)
+        preds = _flatten(preds)
+        days = range(len(truths))

+        plt.figure(figsize=(8, 6))
+        plt.plot(days, truths, label='truth')
+        plt.plot(days, preds, label='pred')
+        plt.legend()
+        plt.xlabel("day")
+        plt.ylabel("normalized price")
+        plt.grid(ls='--')

+        if stock_sym:
+            plt.title(stock_sym + " | %d days in test" % len(truths))

+        plt.savefig(figname, format='png', bbox_inches='tight', transparent=True)
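One practical caveat for the plotting added above: model.py imports matplotlib.pyplot at module load, which can fail on a headless training box with no display. A small guard (an editor's sketch, not in the commit) is to pin a non-interactive backend before pyplot is imported:

    import matplotlib
    matplotlib.use('Agg')  # must run before `import matplotlib.pyplot`
    import matplotlib.pyplot as plt

With that in place, plot_samples() can also end with plt.close() right after plt.savefig(...), so figures do not accumulate over hundreds of epochs.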
