Skip to content

Commit

Permalink
fix gru, rnn, lstm test cases to match the specification and add some cases (onnx#920)
Browse files Browse the repository at this point in the history

* fix lstm test cases to match the specification and add precomputed case

* add rnn and update test cases

* add and update gru test cases

* update doc

* remove lstm precomputed test case

* update test cases

* update doc and fix bug

* fix mypy
  • Loading branch information
fumihwh authored and linkerzhang committed May 24, 2018
1 parent 4898c9e commit d2a46da
Show file tree
Hide file tree
Showing 17 changed files with 302 additions and 149 deletions.
100 changes: 81 additions & 19 deletions docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2517,9 +2517,8 @@ W = weight_scale * np.ones((1, number_of_gates * hidden_size, input_size)).astyp
R = weight_scale * np.ones((1, number_of_gates * hidden_size, hidden_size)).astype(np.float32)

gru = GRU_Helper(X=input, W=W, R=R)
output = gru.step().astype(np.float32)

expect(node, inputs=[input, W, R], outputs=[output], name='test_gru_defaults')
_, Y_h = gru.step()
expect(node, inputs=[input, W, R], outputs=[Y_h.astype(np.float32)], name='test_gru_defaults')
```

</details>
Expand Down Expand Up @@ -2553,9 +2552,42 @@ R_B = np.zeros((1, number_of_gates * hidden_size)).astype(np.float32)
B = np.concatenate((W_B, R_B), axis=1)

gru = GRU_Helper(X=input, W=W, R=R, B=B)
output = gru.step().astype(np.float32)
_, Y_h = gru.step()
expect(node, inputs=[input, W, R, B], outputs=[Y_h.astype(np.float32)], name='test_gru_with_initial_bias')
```

</details>


<details>
<summary>seq_length</summary>

```python
input = np.array([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
[[10., 11., 12.], [13., 14., 15.], [16., 17., 18.]]]).astype(np.float32)

input_size = 3
hidden_size = 5
number_of_gates = 3

expect(node, inputs=[input, W, R, B], outputs=[output], name='test_gru_with_initial_bias')
node = onnx.helper.make_node(
'GRU',
inputs=['X', 'W', 'R', 'B'],
outputs=['', 'Y'],
hidden_size=hidden_size
)

W = np.random.randn(1, number_of_gates * hidden_size, input_size).astype(np.float32)
R = np.random.randn(1, number_of_gates * hidden_size, hidden_size).astype(np.float32)

# Adding custom bias
W_B = np.random.randn(1, number_of_gates * hidden_size).astype(np.float32)
R_B = np.random.randn(1, number_of_gates * hidden_size).astype(np.float32)
B = np.concatenate((W_B, R_B), axis=1)

gru = GRU_Helper(X=input, W=W, R=R, B=B)
_, Y_h = gru.step()
expect(node, inputs=[input, W, R, B], outputs=[Y_h.astype(np.float32)], name='test_gru_seq_length')
```

</details>
Expand Down Expand Up @@ -3619,9 +3651,8 @@ W = weight_scale * np.ones((1, number_of_gates * hidden_size, input_size)).astyp
R = weight_scale * np.ones((1, number_of_gates * hidden_size, hidden_size)).astype(np.float32)

lstm = LSTM_Helper(X=input, W=W, R=R)
output = lstm.step()

expect(node, inputs=[input, W, R], outputs=[output], name='test_lstm_defaults')
_, Y_h = lstm.step()
expect(node, inputs=[input, W, R], outputs=[Y_h.astype(np.float32)], name='test_lstm_defaults')
```

</details>
Expand Down Expand Up @@ -3655,9 +3686,8 @@ R_B = np.zeros((1, number_of_gates * hidden_size)).astype(np.float32)
B = np.concatenate((W_B, R_B), 1)

lstm = LSTM_Helper(X=input, W=W, R=R, B=B)
output = lstm.step()

expect(node, inputs=[input, W, R, B], outputs=[output], name='test_lstm_with_initial_bias')
_, Y_h = lstm.step()
expect(node, inputs=[input, W, R, B], outputs=[Y_h.astype(np.float32)], name='test_lstm_with_initial_bias')
```

</details>
Expand Down Expand Up @@ -3692,9 +3722,9 @@ init_c = np.zeros((1, input.shape[1], hidden_size)).astype(np.float32)
P = weight_scale * np.ones((1, number_of_peepholes * hidden_size)).astype(np.float32)

lstm = LSTM_Helper(X=input, W=W, R=R, B=B, P=P, initial_c=init_c, initial_h=init_h)
output = lstm.step()

expect(node, inputs=[input, W, R, B, seq_lens, init_h, init_c, P], outputs=[output], name='test_lstm_with_peepholes')
_, Y_h = lstm.step()
expect(node, inputs=[input, W, R, B, seq_lens, init_h, init_c, P], outputs=[Y_h.astype(np.float32)],
name='test_lstm_with_peepholes')
```

</details>
Expand Down Expand Up @@ -5642,9 +5672,8 @@ W = weight_scale * np.ones((1, hidden_size, input_size)).astype(np.float32)
R = weight_scale * np.ones((1, hidden_size, hidden_size)).astype(np.float32)

rnn = RNN_Helper(X=input, W=W, R=R)
output = rnn.step().astype(np.float32)

expect(node, inputs=[input, W, R], outputs=[output], name='test_simple_rnn_defaults')
_, Y_h = rnn.step()
expect(node, inputs=[input, W, R], outputs=[Y_h.astype(np.float32)], name='test_simple_rnn_defaults')
```

</details>
Expand Down Expand Up @@ -5677,9 +5706,42 @@ R_B = np.zeros((1, hidden_size)).astype(np.float32)
B = np.concatenate((W_B, R_B), axis=1)

rnn = RNN_Helper(X=input, W=W, R=R, B=B)
output = rnn.step().astype(np.float32)
_, Y_h = rnn.step()
expect(node, inputs=[input, W, R, B], outputs=[Y_h.astype(np.float32)],
name='test_simple_rnn_with_initial_bias')
```

</details>


<details>
<summary>seq_length</summary>

```python
input = np.array([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
[[10., 11., 12.], [13., 14., 15.], [16., 17., 18.]]]).astype(np.float32)

input_size = 3
hidden_size = 5

expect(node, inputs=[input, W, R, B], outputs=[output], name='test_simple_rnn_with_initial_bias')
node = onnx.helper.make_node(
'RNN',
inputs=['X', 'W', 'R', 'B'],
outputs=['', 'Y'],
hidden_size=hidden_size
)

W = np.random.randn(1, hidden_size, input_size).astype(np.float32)
R = np.random.randn(1, hidden_size, hidden_size).astype(np.float32)

# Adding custom bias
W_B = np.random.randn(1, hidden_size).astype(np.float32)
R_B = np.random.randn(1, hidden_size).astype(np.float32)
B = np.concatenate((W_B, R_B), axis=1)

rnn = RNN_Helper(X=input, W=W, R=R, B=B)
_, Y_h = rnn.step()
expect(node, inputs=[input, W, R, B], outputs=[Y_h.astype(np.float32)], name='test_rnn_seq_length')
```

</details>
Expand Down
150 changes: 95 additions & 55 deletions onnx/backend/test/case/node/gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import unicode_literals

import numpy as np # type: ignore
from typing import Any
from typing import Any, Tuple

import onnx
from ..base import Base
Expand All @@ -13,7 +13,7 @@

class GRU_Helper():
def __init__(self, **params): # type: (*Any) -> None
#GRU Input Names
# GRU Input Names
X = str('X')
W = str('W')
R = str('R')
Expand All @@ -28,15 +28,16 @@ def __init__(self, **params): # type: (*Any) -> None

self.num_directions = params[W].shape[0]

if(self.num_directions == 1):
if self.num_directions == 1:
for k in params.keys():
params[k] = np.squeeze(params[k], axis=0)
if k != X:
params[k] = np.squeeze(params[k], axis=0)

self.hidden_size = params[R].shape[-1]
self.batch_size = params[X].shape[0]
hidden_size = params[R].shape[-1]
batch_size = params[X].shape[1]

b = params[B] if B in params else np.zeros(2 * number_of_gates * self.hidden_size)
h_0 = params[H_0] if H_0 in params else np.zeros((self.batch_size, self.hidden_size))
b = params[B] if B in params else np.zeros(2 * number_of_gates * hidden_size)
h_0 = params[H_0] if H_0 in params else np.zeros((batch_size, hidden_size))
lbr = params[LBR] if LBR in params else 0

self.X = params[X]
Expand All @@ -55,72 +56,111 @@ def f(self, x): # type: (np.ndarray) -> np.ndarray
def g(self, x): # type: (np.ndarray) -> np.ndarray
return np.tanh(x)

def step(self): # type: () -> np.ndarray
def step(self): # type: () -> Tuple[np.ndarray, np.ndarray]
h_list = []
[w_z, w_r, w_h] = np.split(self.W, 3)
[r_z, r_r, r_h] = np.split(self.R, 3)
[w_bz, w_br, w_bh, r_bz, r_br, r_bh] = np.split(self.B, 6)

z = self.f(np.dot(self.X, np.transpose(w_z)) + np.dot(self.H_0, r_z) + w_bz + r_bz)
r = self.f(np.dot(self.X, np.transpose(w_r)) + np.dot(self.H_0, r_r) + w_br + r_br)
h_default = self.g(np.dot(self.X, np.transpose(w_h)) + np.dot(r * self.H_0, r_h) + w_bh + r_bh)
h_linear = self.g(np.dot(self.X, np.transpose(w_h)) + r * (np.dot(self.H_0, r_h) + r_bh) + w_bh)
h = h_linear if self.LBR else h_default
H = (1 - z) * h + z * self.H_0
return np.reshape(H, (self.num_directions, self.batch_size, self.hidden_size))
gates_w = np.transpose(np.concatenate((w_z, w_r)))
gates_r = np.transpose(np.concatenate((r_z, r_r)))
gates_b = np.add(np.concatenate((w_bz, w_br)), np.concatenate((r_bz, r_br)))

H_t = self.H_0
for x in np.split(self.X, self.X.shape[0], axis=0):
gates = np.dot(x, gates_w) + np.dot(H_t, gates_r) + gates_b
z, r = np.split(gates, 2, -1)
z = self.f(z)
r = self.f(r)
h_default = self.g(np.dot(x, np.transpose(w_h)) + np.dot(r * H_t, r_h) + w_bh + r_bh)
h_linear = self.g(np.dot(x, np.transpose(w_h)) + r * (np.dot(H_t, r_h) + r_bh) + w_bh)
h = h_linear if self.LBR else h_default
H = (1 - z) * h + z * H_t
h_list.append(H)
H_t = H
concatenated = np.concatenate(h_list)
if self.num_directions == 1:
output = np.expand_dims(concatenated, 1)
return output, h_list[-1]


class GRU(Base):

@staticmethod
def export_defaults(): # type: () -> None
input = np.array([[[1., 2.], [3., 4.], [5., 6.]]]).astype(np.float32)

input_size = 2
hidden_size = 5
weight_scale = 0.1
number_of_gates = 3
input = np.array([[[1., 2.], [3., 4.], [5., 6.]]]).astype(np.float32)

node = onnx.helper.make_node(
'GRU',
inputs=['X', 'W', 'R'],
outputs=['', 'Y'],
hidden_size=hidden_size
)
input_size = 2
hidden_size = 5
weight_scale = 0.1
number_of_gates = 3

W = weight_scale * np.ones((1, number_of_gates * hidden_size, input_size)).astype(np.float32)
R = weight_scale * np.ones((1, number_of_gates * hidden_size, hidden_size)).astype(np.float32)
node = onnx.helper.make_node(
'GRU',
inputs=['X', 'W', 'R'],
outputs=['', 'Y'],
hidden_size=hidden_size
)

gru = GRU_Helper(X=input, W=W, R=R)
output = gru.step().astype(np.float32)
W = weight_scale * np.ones((1, number_of_gates * hidden_size, input_size)).astype(np.float32)
R = weight_scale * np.ones((1, number_of_gates * hidden_size, hidden_size)).astype(np.float32)

expect(node, inputs=[input, W, R], outputs=[output], name='test_gru_defaults')
gru = GRU_Helper(X=input, W=W, R=R)
_, Y_h = gru.step()
expect(node, inputs=[input, W, R], outputs=[Y_h.astype(np.float32)], name='test_gru_defaults')

@staticmethod
def export_initial_bias(): # type: () -> None
input = np.array([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]).astype(np.float32)
input = np.array([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]).astype(np.float32)

input_size = 3
hidden_size = 3
weight_scale = 0.1
custom_bias = 0.1
number_of_gates = 3

node = onnx.helper.make_node(
'GRU',
inputs=['X', 'W', 'R', 'B'],
outputs=['', 'Y'],
hidden_size=hidden_size
)

input_size = 3
hidden_size = 3
weight_scale = 0.1
custom_bias = 0.1
number_of_gates = 3
W = weight_scale * np.ones((1, number_of_gates * hidden_size, input_size)).astype(np.float32)
R = weight_scale * np.ones((1, number_of_gates * hidden_size, hidden_size)).astype(np.float32)

node = onnx.helper.make_node(
'GRU',
inputs=['X', 'W', 'R', 'B'],
outputs=['', 'Y'],
hidden_size=hidden_size
)
# Adding custom bias
W_B = custom_bias * np.ones((1, number_of_gates * hidden_size)).astype(np.float32)
R_B = np.zeros((1, number_of_gates * hidden_size)).astype(np.float32)
B = np.concatenate((W_B, R_B), axis=1)

gru = GRU_Helper(X=input, W=W, R=R, B=B)
_, Y_h = gru.step()
expect(node, inputs=[input, W, R, B], outputs=[Y_h.astype(np.float32)], name='test_gru_with_initial_bias')

@staticmethod
def export_seq_length(): # type: () -> None
input = np.array([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
[[10., 11., 12.], [13., 14., 15.], [16., 17., 18.]]]).astype(np.float32)

input_size = 3
hidden_size = 5
number_of_gates = 3

W = weight_scale * np.ones((1, number_of_gates * hidden_size, input_size)).astype(np.float32)
R = weight_scale * np.ones((1, number_of_gates * hidden_size, hidden_size)).astype(np.float32)
node = onnx.helper.make_node(
'GRU',
inputs=['X', 'W', 'R', 'B'],
outputs=['', 'Y'],
hidden_size=hidden_size
)

# Adding custom bias
W_B = custom_bias * np.ones((1, number_of_gates * hidden_size)).astype(np.float32)
R_B = np.zeros((1, number_of_gates * hidden_size)).astype(np.float32)
B = np.concatenate((W_B, R_B), axis=1)
W = np.random.randn(1, number_of_gates * hidden_size, input_size).astype(np.float32)
R = np.random.randn(1, number_of_gates * hidden_size, hidden_size).astype(np.float32)

gru = GRU_Helper(X=input, W=W, R=R, B=B)
output = gru.step().astype(np.float32)
# Adding custom bias
W_B = np.random.randn(1, number_of_gates * hidden_size).astype(np.float32)
R_B = np.random.randn(1, number_of_gates * hidden_size).astype(np.float32)
B = np.concatenate((W_B, R_B), axis=1)

expect(node, inputs=[input, W, R, B], outputs=[output], name='test_gru_with_initial_bias')
gru = GRU_Helper(X=input, W=W, R=R, B=B)
_, Y_h = gru.step()
expect(node, inputs=[input, W, R, B], outputs=[Y_h.astype(np.float32)], name='test_gru_seq_length')
Loading

0 comments on commit d2a46da

Please sign in to comment.