# 11_1_rnn_basics.py
import torch
import torch.nn as nn
# One hot encoding for each char in 'hello'
h = [1, 0, 0, 0]
e = [0, 1, 0, 0]
l = [0, 0, 1, 0]
o = [0, 0, 0, 1]
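
# (Sketch, not in the original) The same one-hot rows can be read out of an
# identity matrix; the char-to-index mapping here is illustrative.
char2idx = {'h': 0, 'e': 1, 'l': 2, 'o': 3}
one_hot_lookup = torch.eye(4)  # row i is the one-hot vector for index i
assert one_hot_lookup[char2idx['h']].tolist() == h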
# One cell RNN input_dim (4) -> output_dim (2). sequence: 5
cell = nn.RNN(input_size=4, hidden_size=2, batch_first=True)
# (num_layers * num_directions, batch, hidden_size)
hidden = torch.randn(1, 1, 2)
# Propagate input through RNN
# Input: (batch, seq_len, input_size) when batch_first=True
inputs = torch.tensor([[h, e, l, l, o]], dtype=torch.float)
print("input size", inputs.size())  # input size torch.Size([1, 5, 4])
for one in inputs[0]:
    one = one.view(1, 1, -1)
    # Input: (batch, seq_len, input_size) when batch_first=True
    out, hidden = cell(one, hidden)
    print(out.size())  # torch.Size([1, 1, 2])
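
# (Sketch, not in the original) For a one-layer, unidirectional nn.RNN the
# per-step output is exactly the per-step hidden state, so after the last
# step the two tensors hold the same values.
print(torch.allclose(out.view(-1), hidden.view(-1)))  # True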
# We can also process the whole sequence at once
# Propagate input through RNN
# Input: (batch, seq_len, input_size) when batch_first=True
out, hidden = cell(inputs, hidden)
print("out size", out.size())  # out size torch.Size([1, 5, 2])
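
# (Sketch, not in the original) If no hidden state is passed, nn.RNN starts
# from zeros, which is the usual way to begin a fresh sequence.
out_fresh, hidden_fresh = cell(inputs)
print("out size", out_fresh.size())  # out size torch.Size([1, 5, 2])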
# One cell RNN input_dim (4) -> output_dim (2). sequence: 5, batch: 3
# A batch of 3 sequences: 'hello', 'eolll', 'lleel'
# shape = (3, 5, 4)
inputs = torch.tensor([[h, e, l, l, o],
                       [e, o, l, l, l],
                       [l, l, e, e, l]], dtype=torch.float)
print("input size", inputs.size()) # input size torch.Size([3, 5, 4])
# Propagate input through RNN
# Input: (batch, seq_len, input_size) when batch_first=True
# The hidden state must match the new batch size:
# (num_layers * num_directions, batch, hidden_size) = (1, 3, 2)
hidden = torch.randn(1, 3, 2)
out, hidden = cell(inputs, hidden)
print("out size", out.size()) # out size torch.Size([3, 5, 2])
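
# (Sketch, not in the original) The per-step loop above can also be written
# with nn.RNNCell, which consumes (batch, input_size) and returns the next
# hidden state of shape (batch, hidden_size).
cell_step = nn.RNNCell(input_size=4, hidden_size=2)
h_t = torch.zeros(3, 2)                    # (batch, hidden_size)
for t in range(inputs.size(1)):            # iterate over seq_len = 5
    h_t = cell_step(inputs[:, t, :], h_t)  # one recurrence step per char
print("RNNCell hidden size", h_t.size())   # RNNCell hidden size torch.Size([3, 2])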