forked from fbadine/LSTNet

A Tensorflow / Keras implementation of "Modeling Long- and Short-Term Temporal Patterns with Deep Neural Networks" paper


Sainpse/LSTNet

 
 


LSTNet

This repository is a TensorFlow / Keras implementation of the paper "Modeling Long- and Short-Term Temporal Patterns with Deep Neural Networks" (https://arxiv.org/pdf/1703.07015.pdf).

This implementation was inspired by the PyTorch implementation at https://github.com/laiguokun/LSTNet

Installation

Clone the prerequisite util repository:

git clone https://github.com/fbadine/util.git

Clone this repository:

git clone https://github.com/fbadine/LSTNet.git
cd LSTNet
mkdir log/ save/ data/

Download the datasets from https://github.com/laiguokun/multivariate-time-series-data and copy the text files into LSTNet/data/
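Once the text files are in data/, they can be inspected with NumPy. The sketch below assumes each file is comma-separated with no header, one row per time step and one column per series (the layout read by the PyTorch reference implementation). It also shows what the three --normalize modes plausibly do; the max-absolute-value scaling is an assumption borrowed from that reference implementation, not taken from this repository's code, and load_series / normalise are hypothetical names.

```python
import numpy as np


def load_series(path):
    """Load a data file as a (time steps, series) float array.

    Assumption: comma-separated values with no header row.
    """
    return np.loadtxt(path, delimiter=",", ndmin=2)


def normalise(data, mode=2):
    """Hypothetical sketch of the --normalize modes using max-abs scaling.

    Returns the scaled data and the per-series scale used to undo it.
    """
    if mode == 0:
        # 0: no normalisation
        return data.copy(), np.ones(data.shape[1])
    if mode == 1:
        # 1: one scale shared by all series
        scale = np.full(data.shape[1], np.max(np.abs(data)))
    elif mode == 2:
        # 2: one scale per series (column)
        scale = np.max(np.abs(data), axis=0)
    else:
        raise ValueError("mode must be 0, 1 or 2")
    scale = np.where(scale == 0, 1.0, scale)  # leave all-zero series unscaled
    return data / scale, scale
```

For example, normalise(load_series("data/electricity.txt"), mode=2) would return the scaled array together with one scale factor per series, which is what is needed to map predictions back to the original units.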

Usage

Training

Four sample scripts are provided to train, validate and test the model on each of the datasets:

  • electricity.sh
  • exchange_rate.sh
  • solar.sh
  • traffic.sh
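Training, validation and testing use separate portions of each dataset, sized by --trainpercent and --validpercent (the remainder is presumably the test portion). A minimal sketch of such a split, assuming the portions are taken in chronological order without shuffling, as is usual for time series; split_by_time is a hypothetical name, not a function from this repository.

```python
import numpy as np


def split_by_time(data, trainpercent=0.6, validpercent=0.2):
    """Split a (time steps, series) array into consecutive train/valid/test blocks.

    Assumption: no shuffling; the test block is whatever remains after
    the training and validation fractions.
    """
    if trainpercent + validpercent > 1:
        raise ValueError("trainpercent + validpercent must not exceed 1")
    n = len(data)
    train_end = int(n * trainpercent)
    valid_end = int(n * (trainpercent + validpercent))
    return data[:train_end], data[train_end:valid_end], data[valid_end:]
```

Keeping the blocks in time order means the model is always evaluated on data that comes after the data it was trained on.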

Predict

To predict and plot a time series with a previously trained model, run main.py as follows (example for the electricity dataset):

python3.6 main.py --data="data/electricity.txt" --no-train --load="save/electricity/electricity" --predict=all --plot --series-to-plot=0 
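The --series-to-plot=0 value above uses the series,start,end format listed under Running Options, where start and end may be omitted. The helper below is a hypothetical sketch of how such a value could be parsed (it is not taken from main.py): omitted or empty fields become None, meaning "from the start" or "to the end" of the time series.

```python
def parse_series_spec(spec):
    """Parse 'series[,start[,end]]' into (series, start, end).

    Hypothetical helper: omitted or empty start/end fields are returned as None.
    """
    parts = [p.strip() for p in spec.split(",")]
    if not 1 <= len(parts) <= 3 or parts[0] == "":
        raise ValueError(f"expected series,start,end but got {spec!r}")
    parts += [""] * (3 - len(parts))  # pad the missing fields
    series = int(parts[0])
    start = int(parts[1]) if parts[1] else None
    end = int(parts[2]) if parts[2] else None
    return series, start, end


print(parse_series_spec("0"))          # (0, None, None)
print(parse_series_spec("3,100,500"))  # (3, 100, 500)
```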

Running Options

main.py accepts the following parameters:

--data (required, no default): Full path of the data file
--normalize (default: 2): Type of data normalisation:
  - 0: No normalisation
  - 1: Normalise all time series together
  - 2: Normalise each time series separately
--trainpercent (default: 0.6): Fraction of the data used for training
--validpercent (default: 0.2): Fraction of the data used for validation
--window (default: 24 * 7): Number of time steps in each input X
--horizon (default: 12): How far ahead the predicted value Y lies: Y is horizon time steps after the last value of X
--CNNFilters (default: 100): Number of output filters in the CNN layer. A value of 0 removes this layer
--CNNKernel (default: 6): CNN filter size; each filter has shape (CNNKernel, number of time series). A value of 0 removes this layer
--GRUUnits (default: 100): Number of hidden units in the GRU layer
--SkipGRUUnits (default: 5): Number of hidden units in the SkipGRU layer
--skip (default: 24): Number of time slots to skip in the SkipGRU layer. A value of 0 removes this layer
--dropout (default: 0.2): Dropout rate
--highway (default: 24): Number of time slots used by the linear (autoregressive, AR) layer
--initializer (default: glorot_uniform): Weight initialiser
--loss (default: mean_absolute_error): Loss function to optimise
--optimizer (default: Adam): Optimiser. Accepted values:
  - SGD
  - RMSprop
  - Adam
--lr (default: 0.001): Learning rate
--batchsize (default: 128): Training batch size
--epochs (default: 100): Number of training epochs
--tensorboard (default: None): Folder in which to write the TensorBoard files. If None, TensorBoard is not used
--no-train: Do not train the model
--no-validation: Do not validate the model
--test: Evaluate the model on the test data
--load (default: None): Path and base name of a pre-trained model to load:
  - model from filename.json
  - weights from filename.h5
--save (default: None): Path and base name under which to save the model:
  - model in filename.json
  - weights in filename.h5
  The same base name is used to save results and history:
  - results in filename.txt
  - history in filename_history.csv (only if --savehistory is passed)
--no-saveresults: Do not save results
--savehistory: Save the training / validation history to the file described under --save
--predict (default: None): Predict time series using the trained model. Accepted values:
  - trainingdata: predict the training data only
  - validationdata: predict the validation data only
  - testingdata: predict the testing data only
  - all: all of the above
  - None: none of the above
--plot: Generate plots
--series-to-plot (default: 0): Series to plot, in the format series,start,end:
  - series: index of the series to plot
  - start: start time slot (default: start of the time series)
  - end: end time slot (default: end of the time series)
--autocorrelation (default: None): Autocorrelation plotting, in the format series,start,end:
  - series: number of randomly chosen time series to plot the autocorrelation for
  - start: start time slot (default: start of the time series)
  - end: end time slot (default: end of the time series)
--save-plot (default: None): Path and base name for the saved plot images:
  - autocorrelation in filename_autocorrelation.png
  - training history in filename_training.png
  - prediction in filename_prediction.png
--no-log: Do not create log files; error and critical messages are still displayed
--logfilename (default: log/lstnet): Full path of the log file
--debuglevel (default: 20): Logging level (20 is the numeric value of logging.INFO in Python's logging module)
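To make --window and --horizon concrete: each input X holds window consecutive time steps, and its target Y is the value horizon steps after the last step of X. A minimal NumPy sketch of building such (X, Y) pairs, assuming the target is a single time step of all series (as in the paper); make_samples is a hypothetical name, not a function from this repository.

```python
import numpy as np


def make_samples(data, window=24 * 7, horizon=12):
    """Build (X, Y) pairs from a (time steps, series) array.

    X[i] covers rows i .. i + window - 1; Y[i] is row i + window - 1 + horizon,
    i.e. horizon steps after the last row of X[i].
    """
    n_steps = data.shape[0]
    n_samples = n_steps - window - horizon + 1
    if n_samples <= 0:
        raise ValueError("series too short for this window and horizon")
    X = np.stack([data[i:i + window] for i in range(n_samples)])
    Y = data[window - 1 + horizon:]
    return X, Y
```

With the defaults (window = 168, horizon = 12), the first target is row 179 (0-based) of the data.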

Results

The following results were obtained:

Dataset         Window     Horizon    Correlation   RSE
Solar           28 hours   2 hours    0.9548        0.3060
Traffic         7 days     12 hours   0.8932        0.4089
Electricity     7 days     24 hours   0.8856        0.3746
Exchange Rate   168 days   12 days    0.9731        0.1540
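The RSE and Correlation columns correspond to the evaluation metrics defined in the paper: the root relative squared error (lower is better) and the empirical correlation coefficient averaged over series (higher is better). A NumPy sketch following those definitions for (time, series) arrays; the function names are illustrative, and the repository's own implementation may differ in details, for example in how constant series (which make the correlation denominator zero here) are handled.

```python
import numpy as np


def rse(y_true, y_pred):
    """Root relative squared error over all entries of (time, series) arrays."""
    num = np.sqrt(np.sum((y_true - y_pred) ** 2))
    den = np.sqrt(np.sum((y_true - y_true.mean()) ** 2))
    return num / den


def corr(y_true, y_pred):
    """Pearson correlation computed per series (column), then averaged."""
    yt = y_true - y_true.mean(axis=0)
    yp = y_pred - y_pred.mean(axis=0)
    per_series = np.sum(yt * yp, axis=0) / np.sqrt(
        np.sum(yt ** 2, axis=0) * np.sum(yp ** 2, axis=0)
    )
    return per_series.mean()
```

An RSE of 1.0 is what always predicting the overall mean would score, so the values in the table are relative to that baseline.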

Dataset

As described in the paper, the experiments use four publicly available datasets, downloadable from https://github.com/laiguokun/multivariate-time-series-data:

  • Traffic: A collection of 48 months (2015-2016) hourly data from the California Department of Transportation
  • Solar Energy: The solar power production records in 2006, sampled every 10 minutes from 137 PV plants in the state of Alabama
  • Electricity: Electricity consumption of 321 clients, recorded every 15 minutes from 2012 to 2014 and converted to hourly consumption as described in the paper
  • Exchange Rate: A collection of daily average rates of 8 currencies from 1990 to 2016
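The window and horizon durations in the results table appear to be counts of time steps expressed in each dataset's sampling interval. Assuming every run used a window of 24 * 7 = 168 steps, that the Electricity files are at hourly resolution, and that the horizons were 12 steps (Solar, Traffic, Exchange Rate) and 24 steps (Electricity), the sketch below converts step counts to durations and reproduces the table's figures. These step counts are inferred, not stated in this README.

```python
from datetime import timedelta

# Assumed sampling interval of each published dataset.
SAMPLING_INTERVAL = {
    "Solar": timedelta(minutes=10),
    "Traffic": timedelta(hours=1),
    "Electricity": timedelta(hours=1),  # assumption: hourly aggregation of 15-minute data
    "Exchange Rate": timedelta(days=1),
}


def steps_to_duration(dataset, steps):
    """Convert a number of time steps into wall-clock time for a dataset."""
    return steps * SAMPLING_INTERVAL[dataset]


WINDOW = 24 * 7  # 168 steps, the --window default
for name, horizon in [("Solar", 12), ("Traffic", 12), ("Electricity", 24), ("Exchange Rate", 12)]:
    print(name, steps_to_duration(name, WINDOW), steps_to_duration(name, horizon))
```

For Solar, for instance, 168 steps of 10 minutes give the 28-hour window and 12 steps give the 2-hour horizon.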

Environment

Primary environment

The results were obtained on a system with the following versions:

  • Python 3.6.8
  • Tensorflow 1.11.0
  • Keras 2.1.6-tf

TensorFlow 2.0 Ready

The model has also been tested on the TensorFlow 2.0 alpha release:

  • Python 3.6.7
  • Tensorflow 2.0.0-alpha0
  • Keras 2.2.4-tf
