forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgpu_rnn.py
125 lines (110 loc) · 4.71 KB
/
gpu_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np
from jaxlib import xla_client
try:
  from .cuda import _rnn as _rnn
except ImportError:
  # The CUDA RNN extension is optional; without it the lowering is unusable.
  _rnn = None

# Only the import itself is guarded above. Registration runs outside the
# `try`, so a failure while registering surfaces instead of being swallowed.
# A swallowed failure would leave some targets registered with `_rnn` None.
if _rnn is not None:
  for _name, _value in _rnn.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform='CUDA')
  # Only defined when the extension loaded; callers must check for it.
  compute_rnn_workspace_reserve_space_sizes = _rnn.compute_rnn_workspace_reserve_space_sizes
def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *,
                       input_size: int, hidden_size: int, num_layers: int,
                       dropout: float, bidirectional: bool):
  """Lowers the cuDNN RNN forward pass to a `cudnn_rnn` custom call.

  Args:
    ctx: lowering context. `ctx.avals_in[0]` supplies the batch size
      (dimension 0) and maximum sequence length (dimension 1);
      `ctx.avals_out` supplies the shapes of all five results and the dtype
      of the first one.
    input: input sequence operand.
    h_0: initial hidden-state operand; its type is reused for result 1.
    c_0: initial cell-state operand; its type is reused for result 2.
    weights: RNN weights operand.
    seq_lengths: sequence-lengths operand.
    input_size: RNN configuration, forwarded to `_rnn.build_rnn_descriptor`.
    hidden_size: RNN configuration, forwarded to the descriptor builder.
    num_layers: RNN configuration, forwarded to the descriptor builder.
    dropout: dropout setting forwarded to the descriptor builder (presumably
      a probability -- confirm against the cuDNN extension).
    bidirectional: RNN configuration, forwarded to the descriptor builder.

  Returns:
    A list of the five elements of the custom call's result tuple, in order:
    output, final hidden state (typed like `h_0`), final cell state (typed
    like `c_0`), workspace and reserve space (both float32 tensors).

  Raises:
    ValueError: if the output dtype is not float32, float64, complex64 or
      complex128.
  """
  out_dtype = ctx.avals_out[0].dtype
  if out_dtype == np.float32:
    out_type = ir.F32Type.get()
  elif out_dtype == np.float64:
    out_type = ir.F64Type.get()
  elif out_dtype == np.complex64:
    out_type = ir.ComplexType.get(ir.F32Type.get())
  elif out_dtype == np.complex128:
    out_type = ir.ComplexType.get(ir.F64Type.get())
  else:
    raise ValueError(f'Unknown output type {out_dtype}')

  output_type = ir.RankedTensorType.get(ctx.avals_out[0].shape, out_type)
  batch_size = ctx.avals_in[0].shape[0]
  max_seq_length = ctx.avals_in[0].shape[1]

  # Results 3 and 4 are scratch buffers whose sizes were fixed during
  # abstract evaluation; they are always emitted as float32 tensors.
  workspace_shape = ctx.avals_out[3].shape
  reserve_space_shape = ctx.avals_out[4].shape
  workspace_type = ir.RankedTensorType.get(workspace_shape, ir.F32Type.get())
  reserve_space_type = ir.RankedTensorType.get(reserve_space_shape,
                                               ir.F32Type.get())

  # The RNN configuration travels to the kernel as an opaque backend config.
  opaque = _rnn.build_rnn_descriptor(input_size, hidden_size, num_layers,
                                     batch_size, max_seq_length, dropout,
                                     bidirectional, workspace_shape[0],
                                     reserve_space_shape[0])

  i32_type = ir.IntegerType.get_signless(32)
  out = hlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              output_type, h_0.type, c_0.type, workspace_type,
              reserve_space_type
          ])
      ],
      [input, h_0, c_0, weights, seq_lengths],
      call_target_name=ir.StringAttr.get('cudnn_rnn'),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
  )
  # Unpack the single tuple result into its five elements.
  return [
      hlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(5)
  ]
def _hlo_zeros_f32(shape):
  """Builds a `stablehlo.constant` float32 tensor of zeros.

  Like the other builders in this file, this relies on the ambient MLIR
  context, location and insertion point.

  Args:
    shape: sequence of ints giving the tensor dimensions (may be empty).

  Returns:
    The result value of the constant op, of type `tensor<shape x f32>`.
  """
  f32 = ir.F32Type.get()
  # A splat attribute stores the single zero once. Building it from
  # `np.zeros(shape)` would first allocate a host array the size of the
  # whole tensor, e.g. the full RNN weight gradient.
  zeros = ir.DenseElementsAttr.get_splat(
      ir.RankedTensorType.get(shape, f32), ir.FloatAttr.get(f32, 0.0))
  return hlo.ConstantOp(zeros).result
def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, workspace,
                           reserve_space, seq_lengths, *, input_size: int,
                           hidden_size: int, num_layers: int, dropout: float,
                           bidirectional: bool):
  """Lowers the cuDNN RNN backward pass to a `cudnn_rnn_bwd` custom call.

  Args:
    ctx: lowering context. `ctx.avals_in[3]` (`x`) supplies the batch size
      (dimension 0) and maximum sequence length (dimension 1);
      `ctx.avals_in[8]` / `ctx.avals_in[9]` supply the workspace and reserve
      space sizes; `ctx.avals_out[3]` supplies the weight-gradient shape.
    dy, dhn, dcn: incoming gradient operands (presumably the cotangents of the
      forward output, final hidden state and final cell state -- confirm).
    x, h0, c0, w, y: forward-pass operands/results, passed to the kernel.
    workspace, reserve_space: scratch buffers produced by the forward pass.
    seq_lengths: sequence-lengths operand.
    input_size: RNN configuration, forwarded to `_rnn.build_rnn_descriptor`.
    hidden_size: RNN configuration, forwarded to the descriptor builder.
    num_layers: RNN configuration, forwarded to the descriptor builder.
    dropout: dropout setting forwarded to the descriptor builder (presumably
      a probability -- confirm against the cuDNN extension).
    bidirectional: RNN configuration, forwarded to the descriptor builder.

  Returns:
    A list of the four elements of the custom call's result tuple, typed
    like `x`, `h0`, `c0` and `w` respectively.
  """
  # avals_in follows the positional operands after `ctx`: index 3 is `x`,
  # 8 is `workspace`, 9 is `reserve_space`.
  batch_size = ctx.avals_in[3].shape[0]
  max_seq_length = ctx.avals_in[3].shape[1]
  workspace_shape = ctx.avals_in[8].shape
  reserve_space_shape = ctx.avals_in[9].shape
  opaque = _rnn.build_rnn_descriptor(input_size, hidden_size, num_layers,
                                     batch_size, max_seq_length, dropout,
                                     bidirectional, workspace_shape[0],
                                     reserve_space_shape[0])

  i32_type = ir.IntegerType.get_signless(32)
  # A zero-filled float32 buffer for the weight gradient. Result 3 aliases it
  # (operand 10 below), presumably because the kernel accumulates into dw
  # rather than overwriting it -- confirm against the cudnn_rnn_bwd kernel.
  zeroed_dw = _hlo_zeros_f32(ctx.avals_out[3].shape)
  out = hlo.CustomCallOp(
      [ir.TupleType.get_tuple([x.type, h0.type, c0.type, w.type])], [
          dy, dhn, dcn, x, h0, c0, w, y, workspace, reserve_space, zeroed_dw,
          seq_lengths
      ],
      call_target_name=ir.StringAttr.get('cudnn_rnn_bwd'),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      output_operand_aliases=ir.ArrayAttr.get([
          # Tuple element 3 (dw) shares storage with operand 10 (zeroed_dw).
          hlo.OutputOperandAlias.get(
              output_tuple_indices=[3],
              operand_index=10,
              operand_tuple_indices=[])
      ]))
  # Unpack the single tuple result into its four gradients.
  return [
      hlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(4)
  ]