Skip to content

Commit

Permalink
update output shape of RNN ops according to ONNX spec (onnx#923)
Browse files Browse the repository at this point in the history
* update output shape of RNN ops according to ONNX spec

* formatting

* correct test data
  • Loading branch information
liqunfu authored and gramalingam committed May 17, 2018
1 parent a8b3316 commit 321d874
Show file tree
Hide file tree
Showing 17 changed files with 65 additions and 58 deletions.
14 changes: 7 additions & 7 deletions onnx/backend/test/case/node/gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ def __init__(self, **params):
for i in required_inputs:
assert i in params, "Missing Required Input: {0}".format(i)

num_directions = params[W].shape[0]
self.num_directions = params[W].shape[0]

if(num_directions == 1):
if(self.num_directions == 1):
for k in params.keys():
params[k] = np.squeeze(params[k], axis=0)

hidden_size = params[R].shape[-1]
batch_size = params[X].shape[0]
self.hidden_size = params[R].shape[-1]
self.batch_size = params[X].shape[0]

b = params[B] if B in params else np.zeros(2 * number_of_gates * hidden_size)
h_0 = params[H_0] if H_0 in params else np.zeros((batch_size, hidden_size))
b = params[B] if B in params else np.zeros(2 * number_of_gates * self.hidden_size)
h_0 = params[H_0] if H_0 in params else np.zeros((self.batch_size, self.hidden_size))
lbr = params[LBR] if LBR in params else 0

self.X = params[X]
Expand Down Expand Up @@ -65,7 +65,7 @@ def step(self):
h_linear = self.g(np.dot(self.X, np.transpose(w_h)) + r * (np.dot(self.H_0, r_h) + r_bh) + w_bh)
h = h_linear if self.LBR else h_default
H = (1 - z) * h + z * self.H_0
return H
return np.reshape(H, (self.num_directions, self.batch_size, self.hidden_size))


class GRU(Base):
Expand Down
18 changes: 9 additions & 9 deletions onnx/backend/test/case/node/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ def __init__(self, **params):
for i in required_inputs:
assert i in params, "Missing Required Input: {0}".format(i)

num_directions = params[W].shape[0]
self.num_directions = params[W].shape[0]

if(num_directions == 1):
if(self.num_directions == 1):
for k in params.keys():
params[k] = np.squeeze(params[k], axis=0)

hidden_size = params[R].shape[-1]
batch_size = params[X].shape[0]
self.hidden_size = params[R].shape[-1]
self.batch_size = params[X].shape[0]

b = params[B] if B in params else np.zeros(2 * number_of_gates * hidden_size)
p = params[P] if P in params else np.zeros(number_of_peepholes * hidden_size)
h_0 = params[H_0] if H_0 in params else np.zeros((batch_size, hidden_size))
c_0 = params[C_0] if C_0 in params else np.zeros((batch_size, hidden_size))
b = params[B] if B in params else np.zeros(2 * number_of_gates * self.hidden_size)
p = params[P] if P in params else np.zeros(number_of_peepholes * self.hidden_size)
h_0 = params[H_0] if H_0 in params else np.zeros((self.batch_size, self.hidden_size))
c_0 = params[C_0] if C_0 in params else np.zeros((self.batch_size, self.hidden_size))

self.X = params[X]
self.W = params[W]
Expand Down Expand Up @@ -72,7 +72,7 @@ def step(self):
C = f * self.C_0 + i * c
o = self.f(np.dot(self.X, np.transpose(w_o)) + np.dot(self.H_0, r_o) + w_bo + r_bo + p_o * C)
H = o * self.h(C)
return H
return np.reshape(H, (self.num_directions, self.batch_size, self.hidden_size))


class LSTM(Base):
Expand Down
14 changes: 7 additions & 7 deletions onnx/backend/test/case/node/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ def __init__(self, **params):
for i in required_inputs:
assert i in params, "Missing Required Input: {0}".format(i)

num_directions = params[W].shape[0]
self.num_directions = params[W].shape[0]

if(num_directions == 1):
if(self.num_directions == 1):
for k in params.keys():
params[k] = np.squeeze(params[k], axis=0)

hidden_size = params[R].shape[-1]
batch_size = params[X].shape[0]
self.hidden_size = params[R].shape[-1]
self.batch_size = params[X].shape[0]

b = params[B] if B in params else np.zeros(2 * hidden_size)
h_0 = params[H_0] if H_0 in params else np.zeros((batch_size, hidden_size))
b = params[B] if B in params else np.zeros(2 * self.hidden_size)
h_0 = params[H_0] if H_0 in params else np.zeros((self.batch_size, self.hidden_size))

self.X = params[X]
self.W = params[W]
Expand All @@ -50,7 +50,7 @@ def step(self):
[w_b, r_b] = np.split(self.B, 2)

H = self.f(np.dot(self.X, np.transpose(self.W)) + np.dot(self.H_0, self.R) + w_b + r_b)
return H
return np.reshape(H, (self.num_directions, self.batch_size, self.hidden_size))


class RNN(Base):
Expand Down
9 changes: 5 additions & 4 deletions onnx/backend/test/data/node/test_gru_defaults/model.onnx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
 backend-test:�
 backend-test:�
%
X
W
Expand All @@ -18,8 +18,9 @@



b
Y

b
Y



B
Original file line number Diff line number Diff line change
@@ -1 +1 @@
BYJ<��=��=��=��=��=yYM>yYM>yYM>yYM>yYM>�L>�L>�L>�L>�L>
BYJ<��=��=��=��=��=yYM>yYM>yYM>yYM>yYM>�L>�L>�L>�L>�L>
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
 backend-test:�
 backend-test:�
(
X
W
Expand All @@ -23,8 +23,9 @@
B


b
Y

b
Y



B
Original file line number Diff line number Diff line change
@@ -1 +1 @@
BYJ$yYM>yYM>yYM>�>�>�>)G�=)G�=)G�=
BYJ$yYM>yYM>yYM>�>�>�>)G�=)G�=)G�=
9 changes: 5 additions & 4 deletions onnx/backend/test/data/node/test_lstm_defaults/model.onnx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
 backend-test:�
 backend-test:�
&
X
W
Expand All @@ -18,8 +18,9 @@



b
Y
 
b
Y
 


B
Original file line number Diff line number Diff line change
@@ -1 +1 @@
 BYJH����a�?����a�?����a�?�a{)\c�?�a{)\c�?�a{)\c�?�������?�������?�������?
 BYJH����a�?����a�?����a�?�a{)\c�?�a{)\c�?�a{)\c�?�������?�������?�������?
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
 backend-test:�
 backend-test:�
)
X
W
Expand All @@ -23,8 +23,9 @@
B


 b
Y
 
 b
Y
 


B
Original file line number Diff line number Diff line change
@@ -1 +1 @@
 BYJ`��\c�?��\c�?��\c�?��\c�?Zp��,�?Zp��,�?Zp��,�?Zp��,�?Z?y��Y�?Z?y��Y�?Z?y��Y�?Z?y��Y�?
 BYJ`��\c�?��\c�?��\c�?��\c�?Zp��,�?Zp��,�?Zp��,�?Zp��,�?Z?y��Y�?Z?y��Y�?Z?y��Y�?Z?y��Y�?
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
 backend-test:�
 backend-test:�
Q
X
W
Expand Down Expand Up @@ -45,8 +45,9 @@ Q
P


 b
Y

 b
Y



B
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
BYJ �> �> �>.?.?.?
BYJ �> �> �>.?.?.?
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
 backend-test:�
 backend-test:�
%
X
W
Expand All @@ -18,8 +18,9 @@



b
Y

b
Y



B
Original file line number Diff line number Diff line change
@@ -1 +1 @@
BYJ0�&�>�&�>�&�>�&�>ٷ?ٷ?ٷ?ٷ?��L?��L?��L?��L?
BYJ0�&�>�&�>�&�>�&�>ٷ?ٷ?ٷ?ٷ?��L?��L?��L?��L?
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
 backend-test:�
 backend-test:�
(
X
W
Expand All @@ -24,8 +24,9 @@



b
Y

b
Y



B
Original file line number Diff line number Diff line change
@@ -1 +1 @@
BYJ<ٷ?ٷ?ٷ?ٷ?ٷ?x�k?x�k?x�k?x�k?x�k?��|?��|?��|?��|?��|?
BYJ<ٷ?ٷ?ٷ?ٷ?ٷ?x�k?x�k?x�k?x�k?x�k?��|?��|?��|?��|?��|?

0 comments on commit 321d874

Please sign in to comment.