Skip to content

Commit

Permalink
Add a time sequence prediction example (pytorch#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
fuzihaofzh authored and soumith committed Apr 5, 2017
1 parent ac5b745 commit 7c57e52
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 0 deletions.
13 changes: 13 additions & 0 deletions time_sequence_prediction/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Time Sequence Prediction
This is a toy example for beginners to start with. It is helpful for learning both pytorch and time sequence prediction. Two LSTMCell units are used in this example to learn some sine wave signals starting at different phases. After learning the sine waves, the network tries to predict the signal values in the future. The results is shown in the picture below.

## Usage

```
python generate_sine_wave.py
python train.py
```

## Result
The initial signal and the predicted results are shown in the image. We first give some initial signals (full line). The network will subsequently give some predicted results (dash line). It can be concluded that the network can generate new sine waves.
![image](https://cloud.githubusercontent.com/assets/1419566/24184438/e24f5280-0f08-11e7-8f8b-4d972b527a81.png)
12 changes: 12 additions & 0 deletions time_sequence_prediction/generate_sine_wave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import math
import numpy as np
import torch
T = 20
L = 1000
N = 100
np.random.seed(2)
x = np.empty((N, L), 'int64')
x[:] = np.array(range(L)) + np.random.randint(-4*T, 4*T, N).reshape(N, 1)
data = np.sin(x / 1.0 / T).astype('float64')
torch.save(data, open('traindata.pt', 'wb'))

81 changes: 81 additions & 0 deletions time_sequence_prediction/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import print_function
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

class Sequence(nn.Module):
def __init__(self):
super(Sequence, self).__init__()
self.lstm1 = nn.LSTMCell(1, 51)
self.lstm2 = nn.LSTMCell(51, 1)

def forward(self, input, future = 0):
outputs = []
h_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
c_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
h_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)
c_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)

for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
h_t, c_t = self.lstm1(input_t, (h_t, c_t))
h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
outputs += [c_t2]
for i in range(future):# if we should predict the future
h_t, c_t = self.lstm1(c_t2, (h_t, c_t))
h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
outputs += [c_t2]
outputs = torch.stack(outputs, 1).squeeze(2)
return outputs



if __name__ == '__main__':
# set ramdom seed to 0
np.random.seed(0)
torch.manual_seed(0)
# load data and make training set
data = torch.load(open('traindata.pt'))
input = Variable(torch.from_numpy(data[3:, :-1]), requires_grad=False)
target = Variable(torch.from_numpy(data[3:, 1:]), requires_grad=False)
# build the model
seq = Sequence()
seq.double()
criterion = nn.MSELoss()
# use LBFGS as optimizer since we can load the whole data to train
optimizer = optim.LBFGS(seq.parameters())
#begin to train
for i in range(15):
print('STEP: ', i)
def closure():
optimizer.zero_grad()
out = seq(input)
loss = criterion(out, target)
print('loss:', loss.data.numpy()[0])
loss.backward()
return loss
optimizer.step(closure)
# begin to predict
future = 1000
pred = seq(input[:3], future = future)
y = pred.data.numpy()
# draw the result
plt.figure(figsize=(30,10))
plt.title('Predict future values for time sequences\n(Dashlines are predicted values)', fontsize=30)
plt.xlabel('x', fontsize=20)
plt.ylabel('y', fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
def draw(yi, color):
plt.plot(np.arange(input.size(1)), yi[:input.size(1)], color, linewidth = 2.0)
plt.plot(np.arange(input.size(1), input.size(1) + future), yi[input.size(1):], color + ':', linewidth = 2.0)
draw(y[0], 'r')
draw(y[1], 'g')
draw(y[2], 'b')
plt.savefig('predict%d.pdf'%i)
plt.close()

0 comments on commit 7c57e52

Please sign in to comment.