// lltm.cpp
// Forked from pytorch/extension-cpp.
#include <torch/torch.h>
#include <vector>
// Derivative of the logistic sigmoid, evaluated at the pre-activation z:
//   s'(z) = (1 - s(z)) * s(z)
// The sigmoid is computed once and reused for both factors.
at::Tensor d_sigmoid(at::Tensor z) {
  const auto sig = z.sigmoid();
  const auto one_minus_sig = 1 - sig;
  return one_minus_sig * sig;
}
// Derivative of tanh, evaluated at the pre-activation z:
//   tanh'(z) = 1 - tanh^2(z)
at::Tensor d_tanh(at::Tensor z) {
  const auto tanh_z = at::tanh(z);
  const auto tanh_sq = tanh_z.pow(2);
  return 1 - tanh_sq;
}
// Derivative of ELU, evaluated at the pre-activation z:
//   elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0 }
//
// z:     pre-activation values.
// alpha: scale of the exponential branch (defaults to 1.0).
//
// The exponential term is *selected* with at::where rather than multiplied by
// the 0/1 mask: for large z, exp(z) overflows to inf, and `0 * inf` would turn
// the whole gradient entry into NaN even though that branch is inactive.
// Note: comparisons with NaN are false, so a NaN entry in z yields 0 here.
at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
  auto e = z.exp();
  auto mask = (alpha * (e - 1)) < 0;
  auto exp_term = at::where(mask, alpha * e, at::zeros_like(e));
  return (z > 0).type_as(z) + exp_term;
}
// Forward pass of the LLTM cell.
//
// Steps:
//   1. Concatenate old_h and input along dim 1.
//   2. Compute the pre-activation gates: bias + X * weights^T.
//   3. Split them into three chunks along dim 1: input gate and output gate
//      (both sigmoid) and the candidate cell (ELU with alpha = 1.0).
//   4. new_cell = old_cell + candidate_cell * input_gate
//      new_h    = tanh(new_cell) * output_gate
//
// Returns, in order: new_h, new_cell, input_gate, output_gate,
// candidate_cell, the concatenated [old_h, input] tensor, and the
// pre-activation gates. The last five are the values lltm_backward takes.
std::vector<at::Tensor> lltm_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell) {
  const auto joined = at::cat({old_h, input}, /*dim=*/1);
  const auto pre_activations =
      at::addmm(bias, joined, weights.transpose(0, 1));
  const auto parts = pre_activations.chunk(3, /*dim=*/1);

  const auto in_gate = parts[0].sigmoid();
  const auto out_gate = parts[1].sigmoid();
  const auto candidate = at::elu(parts[2], /*alpha=*/1.0);

  const auto next_cell = old_cell + candidate * in_gate;
  const auto next_h = next_cell.tanh() * out_gate;

  std::vector<at::Tensor> outputs;
  outputs.reserve(7);
  outputs.push_back(next_h);
  outputs.push_back(next_cell);
  outputs.push_back(in_gate);
  outputs.push_back(out_gate);
  outputs.push_back(candidate);
  outputs.push_back(joined);
  outputs.push_back(pre_activations);
  return outputs;
}
// Backward pass of the LLTM cell.
//
// grad_h / grad_cell are the incoming gradients for new_h / new_cell; the
// remaining tensors are the values produced by lltm_forward plus the weights.
//
// Returns, in order: d_old_h, d_input, d_weights, d_bias, d_old_cell.
std::vector<at::Tensor> lltm_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights) {
  // new_h = tanh(new_cell) * output_gate
  const auto grad_tanh_cell = output_gate * grad_h;
  auto grad_out_gate = new_cell.tanh() * grad_h;

  // new_cell receives gradient both through new_h and directly.
  const auto grad_new_cell = d_tanh(new_cell) * grad_tanh_cell + grad_cell;

  // new_cell = old_cell + candidate_cell * input_gate
  const auto grad_old_cell = grad_new_cell;
  auto grad_in_gate = candidate_cell * grad_new_cell;
  auto grad_candidate = input_gate * grad_new_cell;

  // Chain through the activations, using the pre-activation gate values.
  const auto pre = gate_weights.chunk(3, /*dim=*/1);
  grad_in_gate.mul_(d_sigmoid(pre[0]));
  grad_out_gate.mul_(d_sigmoid(pre[1]));
  grad_candidate.mul_(d_elu(pre[2]));

  // Reassemble in the same order the forward pass chunked the gates.
  const auto grad_gates =
      at::cat({grad_in_gate, grad_out_gate, grad_candidate}, /*dim=*/1);

  // Gradients of the affine map bias + X * weights^T.
  const auto grad_weights = grad_gates.t().mm(X);
  const auto grad_bias = grad_gates.sum(/*dim=*/0, /*keepdim=*/true);
  const auto grad_X = grad_gates.mm(weights);

  // X was cat({old_h, input}, dim 1): the first state_size columns belong to
  // old_h, the remainder to input.
  const auto state_size = grad_h.size(1);
  const auto grad_old_h = grad_X.slice(/*dim=*/1, 0, state_size);
  const auto grad_input = grad_X.slice(/*dim=*/1, state_size);

  return {grad_old_h, grad_input, grad_weights, grad_bias, grad_old_cell};
}
// Python bindings: registers lltm_forward and lltm_backward on the extension
// module under the names "forward" and "backward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "forward",
      &lltm_forward,
      "LLTM forward");
  m.def(
      "backward",
      &lltm_backward,
      "LLTM backward");
}