forked from pytorch/examples

Commit 7c57e52 (1 parent: ac5b745)

Add a time sequence prediction example (pytorch#118)

Showing 3 changed files with 106 additions and 0 deletions.
README.md
@@ -0,0 +1,13 @@
# Time Sequence Prediction

This is a toy example for beginners to start with. It is helpful for learning both PyTorch and time sequence prediction. Two LSTMCell units are used in this example to learn some sine wave signals starting at different phases. After learning the sine waves, the network tries to predict the signal values in the future. The result is shown in the picture below.

## Usage

```
python generate_sine_wave.py
python train.py
```

## Result

The initial signal and the predicted results are shown in the image. We first give the network some initial signal values (solid lines); it then predicts the values that follow (dashed lines). As the plot shows, the network learns to generate new sine waves.

![image](https://cloud.githubusercontent.com/assets/1419566/24184438/e24f5280-0f08-11e7-8f8b-4d972b527a81.png)
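For intuition, each training sequence is just a sine wave sampled once per time step with a random phase offset. A minimal sketch of one such signal (the constants match generate_sine_wave.py below; the offset of 30 is an arbitrary illustrative value):

```
import numpy as np

T, L = 20, 1000               # period scale and sequence length
t = np.arange(L)
wave = np.sin((t + 30) / T)   # one sample per step, shifted by 30 steps
print(wave[:5])
```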
generate_sine_wave.py
@@ -0,0 +1,12 @@
import numpy as np
import torch

T = 20    # controls the period of the sine waves
L = 1000  # length of each sequence
N = 100   # number of sequences

np.random.seed(2)
# each row is 0..L-1 shifted by a random phase offset drawn from [-4T, 4T)
x = np.empty((N, L), 'int64')
x[:] = np.array(range(L)) + np.random.randint(-4 * T, 4 * T, N).reshape(N, 1)
# sample the sine waves and save them as an (N, L) float64 array
data = np.sin(x / 1.0 / T).astype('float64')
torch.save(data, open('traindata.pt', 'wb'))
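A quick sanity check of the generated file (a minimal sketch, assuming the script above has been run; the shape and dtype follow from N, L and the `.astype('float64')` call):

```
import torch

data = torch.load(open('traindata.pt', 'rb'))  # the saved object is a NumPy array
print(data.shape)  # (100, 1000): N sequences of length L
print(data.dtype)  # float64
```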
train.py
@@ -0,0 +1,81 @@
from __future__ import print_function
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
import matplotlib
matplotlib.use('Agg')  # render plots to files; no display needed
import matplotlib.pyplot as plt


class Sequence(nn.Module):
    def __init__(self):
        super(Sequence, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 1)

    def forward(self, input, future=0):
        outputs = []
        # hidden and cell states for both cells, one row per sequence in the batch
        h_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
        c_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
        h_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)
        c_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)

        # teacher-forced pass: feed the given input one time step at a time
        for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
            outputs += [c_t2]
        # generative pass: feed the last prediction back in to predict the future
        for i in range(future):
            h_t, c_t = self.lstm1(c_t2, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
            outputs += [c_t2]
        outputs = torch.stack(outputs, 1).squeeze(2)
        return outputs
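A quick shape check of `forward` (an illustrative sketch; the batch size, length, and future horizon here are made up):

```
seq = Sequence().double()
batch = Variable(torch.zeros(4, 10).double())  # 4 sequences, 10 time steps
out = seq(batch, future=5)
print(out.size())  # (4, 15): one prediction per input step, plus 5 generated steps
```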
if __name__ == '__main__':
    # set random seed to 0
    np.random.seed(0)
    torch.manual_seed(0)
    # load data and make the training set: the input is each sequence minus its
    # last value, the target is the same sequence shifted one step ahead;
    # the first 3 sequences are held out for prediction
    data = torch.load(open('traindata.pt', 'rb'))
    input = Variable(torch.from_numpy(data[3:, :-1]), requires_grad=False)
    target = Variable(torch.from_numpy(data[3:, 1:]), requires_grad=False)
    # build the model
    seq = Sequence()
    seq.double()
    criterion = nn.MSELoss()
    # use LBFGS as the optimizer since we can load the whole data to train
    optimizer = optim.LBFGS(seq.parameters())
    # begin to train
    for i in range(15):
        print('STEP: ', i)

        # LBFGS re-evaluates the objective several times per step,
        # so it requires a closure that recomputes the loss
        def closure():
            optimizer.zero_grad()
            out = seq(input)
            loss = criterion(out, target)
            print('loss:', loss.data.numpy()[0])
            loss.backward()
            return loss

        optimizer.step(closure)

        # begin to predict: feed 3 sequences and generate 1000 future steps
        future = 1000
        pred = seq(input[:3], future=future)
        y = pred.data.numpy()

        # draw the result
        plt.figure(figsize=(30, 10))
        plt.title('Predict future values for time sequences\n(Dashlines are predicted values)', fontsize=30)
        plt.xlabel('x', fontsize=20)
        plt.ylabel('y', fontsize=20)
        plt.xticks(fontsize=20)
        plt.yticks(fontsize=20)

        def draw(yi, color):
            # solid line over the input range, dashed line over the predicted range
            plt.plot(np.arange(input.size(1)), yi[:input.size(1)], color, linewidth=2.0)
            plt.plot(np.arange(input.size(1), input.size(1) + future), yi[input.size(1):], color + ':', linewidth=2.0)

        draw(y[0], 'r')
        draw(y[1], 'g')
        draw(y[2], 'b')
        plt.savefig('predict%d.pdf' % i)
        plt.close()
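For reference, `optim.LBFGS` is unusual among PyTorch optimizers in that `step` takes a closure that re-evaluates the loss, as seen above. A minimal sketch of the pattern in isolation, reusing the imports from train.py (the one-parameter quadratic objective is made up):

```
w = Variable(torch.zeros(1).double(), requires_grad=True)
opt = optim.LBFGS([w])

def closure():
    opt.zero_grad()
    loss = ((w - 3.0) ** 2).sum()  # toy objective with its minimum at w = 3
    loss.backward()
    return loss

for _ in range(5):
    opt.step(closure)  # LBFGS may call closure several times per step
print(w.data[0])       # approaches 3.0
```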