@@ -16,6 +16,7 @@ limitations under the License. */
16
16
#include < limits>
17
17
#include " paddle/fluid/framework/eigen.h"
18
18
#include " paddle/fluid/framework/op_registry.h"
19
+ #include " paddle/fluid/operators/math/jit_kernel.h"
19
20
#include " paddle/fluid/operators/math/math_function.h"
20
21
21
22
namespace paddle {
@@ -69,9 +70,6 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
69
70
auto emission_dims = emission_weights.dims ();
70
71
const size_t seq_len = emission_dims[0 ];
71
72
const size_t tag_num = emission_dims[1 ];
72
-
73
- const size_t state_trans_base_idx = 2 ;
74
-
75
73
const T* x = emission_weights.data <T>();
76
74
const T* w = transition_weights.data <T>();
77
75
int64_t * path = decoded_path->data <int64_t >();
@@ -84,221 +82,10 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
84
82
Tensor track;
85
83
int * track_value =
86
84
track.mutable_data <int >(emission_dims, platform::CPUPlace ());
87
-
88
- #ifdef __AVX__
89
- // It use the AVX or AVX512 instruction to deal the data as the vector of 8 or
90
- // 16 elements per iteration. Then it can implement the parallel processing.
91
- // Only optimize for float type.
92
- #ifdef __AVX512F__
93
- size_t step_size = 16 ;
94
- #else
95
- size_t step_size = 8 ;
96
- #endif
97
- if (std::is_same<T, float >::value && (tag_num >= step_size)) {
98
- size_t steps = tag_num / step_size;
99
- size_t remain = tag_num % step_size;
100
- int last_offset = static_cast <int >(remain) - static_cast <int >(step_size);
101
-
102
- // Setup the alpha initial value.
103
- size_t i_offset = 0 ;
104
- for (size_t i = 0 ; i <= steps; ++i) {
105
- #ifdef __AVX512F__
106
- // Declare the variable for the content of weights, input and alpha
107
- // values.
108
- __m512 w_content, x_content, alpha_content;
109
-
110
- // Load the relevant data into the variables from un-aligned address.
111
- w_content = _mm512_loadu_ps ((const float *)(w + i_offset));
112
- x_content = _mm512_loadu_ps ((const float *)(x + i_offset));
113
- alpha_content = _mm512_add_ps (w_content, x_content);
114
-
115
- // Save the alpha value.
116
- _mm512_storeu_ps (reinterpret_cast <float *>(alpha_value + i_offset),
117
- alpha_content);
118
- #else
119
- // Declare the variable for the content of weights, input and alpha
120
- // values.
121
- __m256 w_content, x_content, alpha_content;
122
-
123
- // Load the relevant data into the variables from un-aligned address.
124
- w_content = _mm256_loadu_ps ((const float *)(w + i_offset));
125
- x_content = _mm256_loadu_ps ((const float *)(x + i_offset));
126
- alpha_content = _mm256_add_ps (w_content, x_content);
127
-
128
- // Save the alpha value.
129
- _mm256_storeu_ps (reinterpret_cast <float *>(alpha_value + i_offset),
130
- alpha_content);
131
- #endif
132
- i_offset += step_size;
133
- if (i == steps - 1 ) {
134
- if (remain > 0 ) {
135
- i_offset += last_offset;
136
- } else {
137
- break ;
138
- }
139
- }
140
- }
141
-
142
- // Use the column-major strategy to get the location of maximum score.
143
- size_t seq_offset = 0 ;
144
- for (size_t k = 1 ; k < seq_len; ++k) {
145
- size_t j_offset = 0 ;
146
- for (size_t j = 0 ; j <= steps; ++j) {
147
- #ifdef __AVX512F__
148
- // Initialize the variables of maximum score and location.
149
- __m512 max_score = _mm512_set1_ps (-std::numeric_limits<T>::max ());
150
- __m512i max_j = _mm512_setzero_si512 ();
151
- #else
152
- // Initialize the variables of maximum score and location.
153
- __m256 max_score = _mm256_set1_ps (-std::numeric_limits<T>::max ());
154
- __m256i max_j = _mm256_set1_epi32 (0 );
155
- #endif
156
- // Calculate the offset of transition_weights.
157
- size_t trans_offset = state_trans_base_idx * tag_num + j_offset;
158
- for (size_t i = 0 ; i < tag_num; ++i) {
159
- #ifdef __AVX512F__
160
- // Initalize the content of alpha variable with related offset.
161
- __m512 alpha_content =
162
- _mm512_set1_ps (*(const float *)(alpha_value + seq_offset + i));
163
- // Obtain the content of weights from un-aligned address.
164
- __m512 w_content =
165
- _mm512_loadu_ps ((const float *)(w + trans_offset));
166
-
167
- __m512 score_v = _mm512_add_ps (alpha_content, w_content);
168
-
169
- __mmask16 mask = _mm512_cmp_ps_mask (score_v, max_score, _CMP_GT_OS);
170
-
171
- // According to the mask value, it update the index of the max_score
172
- // location.
173
- max_j = _mm512_mask_set1_epi32 (max_j, mask, i);
174
-
175
- // Update the max_score value.
176
- max_score = _mm512_max_ps (max_score, score_v);
177
- #else
178
- // Initalize the content of alpha variable with related offset.
179
- __m256 alpha_content = _mm256_broadcast_ss (
180
- (const float *)(alpha_value + seq_offset + i));
181
- // Obtain the content of weights from un-aligned address.
182
- __m256 w_content =
183
- _mm256_loadu_ps ((const float *)(w + trans_offset));
184
- __m256 score_v = _mm256_add_ps (alpha_content, w_content);
185
-
186
- __m256 mask = _mm256_cmp_ps (score_v, max_score, _CMP_GT_OS);
187
-
188
- #ifdef __AVX2__
189
- // According to the mask value, it update the index of the max_score
190
- // location.
191
- max_j = _mm256_or_si256 (
192
- _mm256_andnot_si256 ((__m256i)mask, max_j),
193
- _mm256_and_si256 ((__m256i)mask, _mm256_set1_epi32 (i)));
194
- #else
195
- __m128i lo_max_j = _mm256_extractf128_si256 (max_j, 0 );
196
- __m128i hi_max_j = _mm256_extractf128_si256 (max_j, 1 );
197
- __m128i lo_mask = _mm256_extractf128_si256 ((__m256i)mask, 0 );
198
- __m128i hi_mask = _mm256_extractf128_si256 ((__m256i)mask, 1 );
199
-
200
- lo_max_j = _mm_andnot_si128 (lo_mask, lo_max_j);
201
- hi_max_j = _mm_andnot_si128 (hi_mask, hi_max_j);
202
- lo_mask = _mm_and_si128 (lo_mask, _mm_set1_epi32 (i));
203
- hi_mask = _mm_and_si128 (hi_mask, _mm_set1_epi32 (i));
204
-
205
- lo_max_j = _mm_or_si128 (lo_mask, lo_max_j);
206
- hi_max_j = _mm_or_si128 (hi_mask, hi_max_j);
207
-
208
- // According to the mask value, it update the index of the max_score
209
- // location.
210
- max_j = _mm256_insertf128_si256 (max_j, lo_max_j, 0 );
211
- max_j = _mm256_insertf128_si256 (max_j, hi_max_j, 1 );
212
- #endif
213
-
214
- // Update the max_score value.
215
- max_score = _mm256_max_ps (max_score, score_v);
216
- #endif
217
- trans_offset += tag_num;
218
- }
219
-
220
- #ifdef __AVX512F__
221
- // Update the alpha and track values.
222
- __m512 x_content = _mm512_loadu_ps (
223
- (const float *)(x + seq_offset + tag_num + j_offset));
224
- max_score = _mm512_add_ps (max_score, x_content);
225
- _mm512_storeu_ps (reinterpret_cast <float *>(alpha_value + seq_offset +
226
- tag_num + j_offset),
227
- max_score);
228
- _mm512_storeu_si512 (
229
- reinterpret_cast <__m512i*>(track_value + seq_offset + tag_num +
230
- j_offset),
231
- max_j);
232
- #else
233
- // Update the alpha and track values.
234
- __m256 x_content = _mm256_loadu_ps (
235
- (const float *)(x + seq_offset + tag_num + j_offset));
236
- max_score = _mm256_add_ps (max_score, x_content);
237
- _mm256_storeu_ps (reinterpret_cast <float *>(alpha_value + seq_offset +
238
- tag_num + j_offset),
239
- max_score);
240
- _mm256_storeu_si256 (
241
- reinterpret_cast <__m256i*>(track_value + seq_offset + tag_num +
242
- j_offset),
243
- max_j);
244
- #endif
245
-
246
- // Calculate the offset of next step
247
- j_offset += step_size;
248
- if (j == steps - 1 ) {
249
- if (remain > 0 ) {
250
- j_offset += last_offset;
251
- } else {
252
- break ;
253
- }
254
- }
255
- }
256
-
257
- seq_offset += tag_num;
258
- }
259
- } else {
260
- for (size_t i = 0 ; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
261
-
262
- for (size_t k = 1 ; k < seq_len; ++k) {
263
- for (size_t i = 0 ; i < tag_num; ++i) {
264
- T max_score = -std::numeric_limits<T>::max ();
265
- int max_j = 0 ;
266
- for (size_t j = 0 ; j < tag_num; ++j) {
267
- T score = alpha_value[(k - 1 ) * tag_num + j] +
268
- w[(j + state_trans_base_idx) * tag_num + i];
269
- if (score > max_score) {
270
- max_score = score;
271
- max_j = j;
272
- }
273
- }
274
-
275
- alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i];
276
- track_value[k * tag_num + i] = max_j;
277
- }
278
- }
279
- }
280
- #else
281
- for (size_t i = 0 ; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
282
-
283
- for (size_t k = 1 ; k < seq_len; ++k) {
284
- for (size_t i = 0 ; i < tag_num; ++i) {
285
- T max_score = -std::numeric_limits<T>::max ();
286
- int max_j = 0 ;
287
- for (size_t j = 0 ; j < tag_num; ++j) {
288
- T score = alpha_value[(k - 1 ) * tag_num + j] +
289
- w[(j + state_trans_base_idx) * tag_num + i];
290
- if (score > max_score) {
291
- max_score = score;
292
- max_j = j;
293
- }
294
- }
295
-
296
- alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i];
297
- track_value[k * tag_num + i] = max_j;
298
- }
299
- }
300
-
301
- #endif
85
+ const auto & ker = math::jitkernel::KernelPool::Instance ()
86
+ .template Get <math::jitkernel::CRFDecodeKernel<T>>(
87
+ static_cast <int >(tag_num));
88
+ ker->Compute (static_cast <int >(seq_len), x, w, alpha_value, track_value);
302
89
T max_score = -std::numeric_limits<T>::max ();
303
90
int max_i = 0 ;
304
91
for (size_t i = 0 ; i < tag_num; ++i) {
0 commit comments