Commit 3c957af

Merge pull request PaddlePaddle#14080 from tensor-tang/refine/jit/crf2
Refine/jit/crf decoding
2 parents: b3b3292 + 64d5b43

5 files changed: +310 -219 lines

paddle/fluid/operators/CMakeLists.txt (+1)

@@ -301,6 +301,7 @@ op_library(flatten_op DEPS reshape_op)
 op_library(sequence_pad_op DEPS sequence_padding)
 op_library(unstack_op DEPS stack_op)
 op_library(fake_quantize_op DEPS memory)
+op_library(crf_decoding_op DEPS jit_kernel)
 op_library(fusion_lstm_op DEPS jit_kernel)
 if (WITH_GPU)
 op_library(conv_op DEPS vol2col depthwise_conv im2col)

paddle/fluid/operators/crf_decoding_op.h (+5 -218)

@@ -16,6 +16,7 @@ limitations under the License. */
 #include <limits>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/jit_kernel.h"
 #include "paddle/fluid/operators/math/math_function.h"

 namespace paddle {
@@ -69,9 +70,6 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
 auto emission_dims = emission_weights.dims();
 const size_t seq_len = emission_dims[0];
 const size_t tag_num = emission_dims[1];
-
-const size_t state_trans_base_idx = 2;
-
 const T* x = emission_weights.data<T>();
 const T* w = transition_weights.data<T>();
 int64_t* path = decoded_path->data<int64_t>();
@@ -84,221 +82,10 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
 Tensor track;
 int* track_value =
 track.mutable_data<int>(emission_dims, platform::CPUPlace());
-
-#ifdef __AVX__
-// It use the AVX or AVX512 instruction to deal the data as the vector of 8 or
-// 16 elements per iteration. Then it can implement the parallel processing.
-// Only optimize for float type.
-#ifdef __AVX512F__
-size_t step_size = 16;
-#else
-size_t step_size = 8;
-#endif
-if (std::is_same<T, float>::value && (tag_num >= step_size)) {
-size_t steps = tag_num / step_size;
-size_t remain = tag_num % step_size;
-int last_offset = static_cast<int>(remain) - static_cast<int>(step_size);
-
-// Setup the alpha initial value.
-size_t i_offset = 0;
-for (size_t i = 0; i <= steps; ++i) {
-#ifdef __AVX512F__
-// Declare the variable for the content of weights, input and alpha
-// values.
-__m512 w_content, x_content, alpha_content;
-
-// Load the relevant data into the variables from un-aligned address.
-w_content = _mm512_loadu_ps((const float*)(w + i_offset));
-x_content = _mm512_loadu_ps((const float*)(x + i_offset));
-alpha_content = _mm512_add_ps(w_content, x_content);
-
-// Save the alpha value.
-_mm512_storeu_ps(reinterpret_cast<float*>(alpha_value + i_offset),
-alpha_content);
-#else
-// Declare the variable for the content of weights, input and alpha
-// values.
-__m256 w_content, x_content, alpha_content;
-
-// Load the relevant data into the variables from un-aligned address.
-w_content = _mm256_loadu_ps((const float*)(w + i_offset));
-x_content = _mm256_loadu_ps((const float*)(x + i_offset));
-alpha_content = _mm256_add_ps(w_content, x_content);
-
-// Save the alpha value.
-_mm256_storeu_ps(reinterpret_cast<float*>(alpha_value + i_offset),
-alpha_content);
-#endif
-i_offset += step_size;
-if (i == steps - 1) {
-if (remain > 0) {
-i_offset += last_offset;
-} else {
-break;
-}
-}
-}
-
-// Use the column-major strategy to get the location of maximum score.
-size_t seq_offset = 0;
-for (size_t k = 1; k < seq_len; ++k) {
-size_t j_offset = 0;
-for (size_t j = 0; j <= steps; ++j) {
-#ifdef __AVX512F__
-// Initialize the variables of maximum score and location.
-__m512 max_score = _mm512_set1_ps(-std::numeric_limits<T>::max());
-__m512i max_j = _mm512_setzero_si512();
-#else
-// Initialize the variables of maximum score and location.
-__m256 max_score = _mm256_set1_ps(-std::numeric_limits<T>::max());
-__m256i max_j = _mm256_set1_epi32(0);
-#endif
-// Calculate the offset of transition_weights.
-size_t trans_offset = state_trans_base_idx * tag_num + j_offset;
-for (size_t i = 0; i < tag_num; ++i) {
-#ifdef __AVX512F__
-// Initalize the content of alpha variable with related offset.
-__m512 alpha_content =
-_mm512_set1_ps(*(const float*)(alpha_value + seq_offset + i));
-// Obtain the content of weights from un-aligned address.
-__m512 w_content =
-_mm512_loadu_ps((const float*)(w + trans_offset));
-
-__m512 score_v = _mm512_add_ps(alpha_content, w_content);
-
-__mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS);
-
-// According to the mask value, it update the index of the max_score
-// location.
-max_j = _mm512_mask_set1_epi32(max_j, mask, i);
-
-// Update the max_score value.
-max_score = _mm512_max_ps(max_score, score_v);
-#else
-// Initalize the content of alpha variable with related offset.
-__m256 alpha_content = _mm256_broadcast_ss(
-(const float*)(alpha_value + seq_offset + i));
-// Obtain the content of weights from un-aligned address.
-__m256 w_content =
-_mm256_loadu_ps((const float*)(w + trans_offset));
-__m256 score_v = _mm256_add_ps(alpha_content, w_content);
-
-__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS);
-
-#ifdef __AVX2__
-// According to the mask value, it update the index of the max_score
-// location.
-max_j = _mm256_or_si256(
-_mm256_andnot_si256((__m256i)mask, max_j),
-_mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i)));
-#else
-__m128i lo_max_j = _mm256_extractf128_si256(max_j, 0);
-__m128i hi_max_j = _mm256_extractf128_si256(max_j, 1);
-__m128i lo_mask = _mm256_extractf128_si256((__m256i)mask, 0);
-__m128i hi_mask = _mm256_extractf128_si256((__m256i)mask, 1);
-
-lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j);
-hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j);
-lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i));
-hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i));
-
-lo_max_j = _mm_or_si128(lo_mask, lo_max_j);
-hi_max_j = _mm_or_si128(hi_mask, hi_max_j);
-
-// According to the mask value, it update the index of the max_score
-// location.
-max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0);
-max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1);
-#endif
-
-// Update the max_score value.
-max_score = _mm256_max_ps(max_score, score_v);
-#endif
-trans_offset += tag_num;
-}
-
-#ifdef __AVX512F__
-// Update the alpha and track values.
-__m512 x_content = _mm512_loadu_ps(
-(const float*)(x + seq_offset + tag_num + j_offset));
-max_score = _mm512_add_ps(max_score, x_content);
-_mm512_storeu_ps(reinterpret_cast<float*>(alpha_value + seq_offset +
-tag_num + j_offset),
-max_score);
-_mm512_storeu_si512(
-reinterpret_cast<__m512i*>(track_value + seq_offset + tag_num +
-j_offset),
-max_j);
-#else
-// Update the alpha and track values.
-__m256 x_content = _mm256_loadu_ps(
-(const float*)(x + seq_offset + tag_num + j_offset));
-max_score = _mm256_add_ps(max_score, x_content);
-_mm256_storeu_ps(reinterpret_cast<float*>(alpha_value + seq_offset +
-tag_num + j_offset),
-max_score);
-_mm256_storeu_si256(
-reinterpret_cast<__m256i*>(track_value + seq_offset + tag_num +
-j_offset),
-max_j);
-#endif
-
-// Calculate the offset of next step
-j_offset += step_size;
-if (j == steps - 1) {
-if (remain > 0) {
-j_offset += last_offset;
-} else {
-break;
-}
-}
-}
-
-seq_offset += tag_num;
-}
-} else {
-for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
-
-for (size_t k = 1; k < seq_len; ++k) {
-for (size_t i = 0; i < tag_num; ++i) {
-T max_score = -std::numeric_limits<T>::max();
-int max_j = 0;
-for (size_t j = 0; j < tag_num; ++j) {
-T score = alpha_value[(k - 1) * tag_num + j] +
-w[(j + state_trans_base_idx) * tag_num + i];
-if (score > max_score) {
-max_score = score;
-max_j = j;
-}
-}
-
-alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i];
-track_value[k * tag_num + i] = max_j;
-}
-}
-}
-#else
-for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
-
-for (size_t k = 1; k < seq_len; ++k) {
-for (size_t i = 0; i < tag_num; ++i) {
-T max_score = -std::numeric_limits<T>::max();
-int max_j = 0;
-for (size_t j = 0; j < tag_num; ++j) {
-T score = alpha_value[(k - 1) * tag_num + j] +
-w[(j + state_trans_base_idx) * tag_num + i];
-if (score > max_score) {
-max_score = score;
-max_j = j;
-}
-}
-
-alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i];
-track_value[k * tag_num + i] = max_j;
-}
-}
-
-#endif
+const auto& ker = math::jitkernel::KernelPool::Instance()
+.template Get<math::jitkernel::CRFDecodeKernel<T>>(
+static_cast<int>(tag_num));
+ker->Compute(static_cast<int>(seq_len), x, w, alpha_value, track_value);
 T max_score = -std::numeric_limits<T>::max();
 int max_i = 0;
 for (size_t i = 0; i < tag_num; ++i) {
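The hunk above drops the hand-written AVX/AVX512 intrinsics together with their scalar fallback and delegates the Viterbi forward pass to the jitkernel library: the operator now asks the KernelPool for a CRFDecodeKernel<T> specialized for tag_num and calls its Compute method. For reference, the recurrence that kernel has to reproduce is exactly the scalar fallback deleted above; the sketch below restates it as a standalone function (the name CRFDecodeRef and the free-function packaging are illustrative, not part of this patch):

#include <limits>

// Reference Viterbi forward pass over the operator's data layout:
// x            - seq_len x tag_num emission scores
// w            - transition weights; state-to-state transitions start at row
//                state_trans_base_idx (= 2), row 0 holds the start transition
// alpha, track - seq_len x tag_num score and back-pointer buffers
template <typename T>
void CRFDecodeRef(int seq_len, int tag_num, const T* x, const T* w, T* alpha,
                  int* track) {
  const int state_trans_base_idx = 2;
  // First step: start transition plus first emission.
  for (int i = 0; i < tag_num; ++i) alpha[i] = w[i] + x[i];
  for (int k = 1; k < seq_len; ++k) {
    for (int i = 0; i < tag_num; ++i) {
      T max_score = -std::numeric_limits<T>::max();
      int max_j = 0;
      // Best previous tag j for current tag i.
      for (int j = 0; j < tag_num; ++j) {
        T score = alpha[(k - 1) * tag_num + j] +
                  w[(j + state_trans_base_idx) * tag_num + i];
        if (score > max_score) {
          max_score = score;
          max_j = j;
        }
      }
      alpha[k * tag_num + i] = max_score + x[k * tag_num + i];
      track[k * tag_num + i] = max_j;
    }
  }
}

The removed AVX paths compute the same recurrence for 8 (AVX) or 16 (AVX512) destination tags at a time; that vectorized work now lives behind the CRFDecodeKernel interface, with the jit_kernel library gaining a jit_kernel_crf_decode.cc source in the next file.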

paddle/fluid/operators/math/CMakeLists.txt (+1 -1)

@@ -76,6 +76,6 @@ endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
 cc_library(jit_kernel
-SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc
+SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc
 DEPS cpu_info cblas)
 cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)

paddle/fluid/operators/math/jit_kernel.h (+7)

@@ -151,6 +151,13 @@ class GRUKernel : public Kernel {
 virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0;
 };

+template <typename T>
+class CRFDecodeKernel : public Kernel {
+ public:
+ virtual void Compute(const int seq_len, const T *x, const T *w, T *alpha,
+                      int *track) const = 0;
+};
+
 } // namespace jitkernel
 } // namespace math
 } // namespace operators
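CRFDecodeKernel follows the same pattern as the existing jitkernel interfaces such as GRUKernel above: a pure-virtual Compute on an abstract Kernel subclass, looked up through the KernelPool singleton. A minimal caller sketch, mirroring how the updated crf_decoding_op.h fetches and runs the kernel; the free function RunCRFDecode and its raw-pointer arguments are illustrative, the real operator takes them from its input tensors:

#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jit = paddle::operators::math::jitkernel;

// Illustrative caller: fetch a float CRF decode kernel specialized for
// tag_num from the KernelPool and run it over one sequence.
void RunCRFDecode(int seq_len, int tag_num, const float* x, const float* w,
                  float* alpha, int* track) {
  const auto& ker =
      jit::KernelPool::Instance().Get<jit::CRFDecodeKernel<float>>(tag_num);
  ker->Compute(seq_len, x, w, alpha, track);
}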
