forked from pytorch/FBGEMM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
embedding_forward_split_cpu.cpp
586 lines (547 loc) · 20.9 KB
/
embedding_forward_split_cpu.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "codegen/embedding_forward_split_cpu.h"
#include "fbgemm/FbgemmEmbedding.h"
#include "fbgemm/Types.h"
#include "fbgemm/Utils.h"
#include "fbgemm_gpu/cpu_utils.h"
#include "fbgemm_gpu/embedding_common.h"
#ifdef FBCODE_CAFFE2
#include <libdivide.h>
#include "folly/container/F14Map.h"
#else
#include <omp.h>
#endif
#include <ATen/AccumulateType.h>
#include <vector>
using Tensor = at::Tensor;
// Forward kernel of the CPU split (table-batched) embedding-bag lookup.
//
// For each table t in [0, T) and each sample b in [0, B), pools the rows of
// table t selected by indices[offsets[t * B + b] : offsets[t * B + b + 1]]
// and writes the D_t-wide result into
// output[b][D_offsets[t] : D_offsets[t + 1]].
//
//   weights          - all tables packed together; table t starts at element
//                      weights_offsets[t]. Must be contiguous.
//   D_offsets        - T + 1 int32 column offsets into `output`.
//   total_D          - unused by this kernel.
//   hash_size_cumsum - cumulative row counts; the row count of table t is the
//                      first non-zero difference after entry t.
//   indices/offsets  - CSR description of the bags; offsets has T * B + 1
//                      entries.
//   pooling_mode     - PoolingMode value. MEAN divides by the bag length, but
//                      only when indice_weights is undefined.
//   indice_weights   - optional per-index weights (may be undefined).
//   output           - preallocated result, written in place; its row stride
//                      is output.size(1).
//
// When weights_t is float/Half/uint8 and both output_t and ind_weights_t are
// float, pooling is delegated to FBGEMM's generated SpMDM kernel; otherwise a
// scalar reference loop is used. Out-of-range indices are reported through
// fbgemm_gpu::report_embedding_error.
template <typename weights_t, typename ind_weights_t, typename output_t>
void split_embedding_forward_cpu_kernel(
    Tensor weights,
    Tensor weights_offsets,
    Tensor D_offsets,
    int64_t total_D,
    Tensor hash_size_cumsum,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    Tensor indice_weights,
    Tensor output) {
  int64_t T = D_offsets.numel() - 1;
  TORCH_CHECK(T > 0);
  // offsets = [T x B + 1]
  int64_t B = (offsets.size(0) - 1) / T;
  TORCH_CHECK(B >= 0);
  TORCH_CHECK(weights.is_contiguous());
  indices = indices.contiguous();
  offsets = offsets.contiguous();
  if (indice_weights.defined()) {
    indice_weights = indice_weights.contiguous();
  }
  const auto D_offsets_data = D_offsets.accessor<int, 1>();
  const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
  const auto indices_data = indices.data_ptr<int64_t>();
  const auto offsets_data = offsets.data_ptr<int64_t>();
  const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
  // Upper bound for the hash-size look-ahead below, so that it can never read
  // past the end of hash_size_cumsum.
  const int64_t num_hash_size_cumsum = hash_size_cumsum_data.size(0);
  const auto weights_data = weights.data_ptr<weights_t>();
  const bool has_indice_weights = indice_weights.defined();
  // If indice_weights is not defined this pointer is never dereferenced; the
  // nullptr branch only keeps the expression well-formed.
  const auto indice_weights_data = has_indice_weights
      ? indice_weights.data_ptr<ind_weights_t>()
      : nullptr;
  auto output_data = output.data_ptr<output_t>();
  auto output_stride = output.size(1);
  const bool is_mean_pooling =
      static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN;
  using acc_t = at::acc_type<output_t, true>;
  constexpr bool use_fbgemm = (std::is_same<weights_t, float>::value ||
                               std::is_same<weights_t, at::Half>::value ||
                               std::is_same<weights_t, uint8_t>::value) &&
      std::is_same<output_t, float>::value &&
      std::is_same<ind_weights_t, float>::value;

  at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
    // Per-chunk scratch accumulator for the reference path. A std::vector
    // replaces the former variable-length array, which is a compiler
    // extension (not standard C++) and placed a D-sized buffer on the stack.
    std::vector<acc_t> output_buf;
    for (int t = 0; t < T; ++t) {
      const auto D_begin = D_offsets_data[t];
      const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
      const auto table_begin = weights_offsets_data[t];
      // Row count of table t: the first non-zero step of hash_size_cumsum
      // after entry t (presumably consecutive features sharing one table have
      // equal cumsum entries). The scan stops at the end of the array; if no
      // non-zero step exists, hash_size stays 0 and every index is reported
      // as invalid instead of reading out of bounds.
      int64_t hash_size = 0;
      for (int64_t t_temp = t + 1;
           hash_size == 0 && t_temp < num_hash_size_cumsum;
           ++t_temp) {
        hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
      }
      bool success = true;
      if (use_fbgemm) {
        using fbgemm_weight_t = typename std::conditional<
            std::is_same<weights_t, at::Half>::value,
            fbgemm::float16,
            weights_t>::type;
        auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
            fbgemm_weight_t,
            /*IndexType=*/int64_t,
            /*OffsetType=*/int64_t>(
            D,
            has_indice_weights,
            is_mean_pooling,
            /*prefetch=*/16,
            /*is_weight_positional=*/false,
            /*use_offsets=*/true,
            output_stride);
        auto offsets_begin_ptr = offsets_data + t * B + b_begin;
        auto indices_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
        success = kernel(
            b_end - b_begin,
            indices_size,
            hash_size,
            reinterpret_cast<const fbgemm_weight_t*>(
                weights_data + table_begin),
            indices_data + *offsets_begin_ptr,
            offsets_begin_ptr,
            has_indice_weights
                ? reinterpret_cast<const float*>(
                      indice_weights_data + *offsets_begin_ptr)
                : nullptr,
            reinterpret_cast<float*>(
                output_data + b_begin * output_stride + D_begin));
      } else {
        for (int64_t b = b_begin; b < b_end; ++b) {
          const auto pool_begin = offsets_data[t * B + b];
          const auto pool_end = offsets_data[t * B + b + 1];
          const auto L = pool_end - pool_begin;
          // Reset the accumulator; assign() reuses the existing capacity.
          output_buf.assign(static_cast<size_t>(D), acc_t(0));
          for (auto p = pool_begin; p < pool_end; ++p) {
            const int64_t idx = indices_data[p];
            if (idx < 0 || idx >= hash_size) {
              success = false;
              break;
            }
            const int64_t embedding_begin = table_begin + idx * D;
            for (int64_t d = 0; d < D; ++d) {
              const auto w =
                  static_cast<acc_t>(weights_data[embedding_begin + d]);
              output_buf[d] += has_indice_weights
                  ? w * static_cast<acc_t>(indice_weights_data[p])
                  : w;
            }
          }
          // NOTE: MEAN pooling will not work with indice_weights!
          const double scale_factor =
              (is_mean_pooling && !has_indice_weights && L > 0) ? 1.0 / L
                                                                : 1.0;
          for (int64_t d = 0; d < D; ++d) {
            output_data[b * output_stride + D_begin + d] =
                scale_factor * output_buf[d];
          }
          if (!success) {
            break;
          }
        } // for each b
      } // !use_fbgemm
      if (!success) {
        fbgemm_gpu::report_embedding_error(
            t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
      } // !success
    } // for each t
  }); // parallel for
}
// Host-side entry point of the CPU split embedding-bag forward pass.
//
// Allocates the [B, total_D] result, where T = D_offsets.numel() - 1 and
// B = (offsets.size(0) - 1) / T, and fills it with
// split_embedding_forward_cpu_kernel.
//
// output_dtype is a SparseType code: FP32 -> float, FP16 -> Half,
// BF16 -> BFloat16; any other code keeps the dtype of `weights`.
// The kernel is dispatched on the output dtype (float/double) and on the
// weights dtype (float/double/Half/Byte). indice_weights, when defined, must
// not be Half; it is read as double for double weights and as float otherwise.
//
// NOTE(review): the outer dispatch is AT_DISPATCH_FLOATING_TYPES, so an
// FP16/BF16 output_dtype presumably raises a "not implemented" error — confirm
// whether that is intended.
Tensor split_embedding_codegen_forward_cpu(
    Tensor weights,
    Tensor weights_offsets,
    Tensor D_offsets,
    int64_t total_D,
    Tensor hash_size_cumsum,
    Tensor indices,
    Tensor offsets,
    int64_t pooling_mode,
    Tensor indice_weights,
    int64_t output_dtype) {
  int64_t T = D_offsets.numel() - 1;
  TORCH_CHECK(T > 0);
  // offsets holds T * B + 1 entries.
  int64_t B = (offsets.size(0) - 1) / T;
  TORCH_CHECK(B >= 0);

  // Start from the weights' options and override only the dtype when one of
  // the supported floating-point output codes was requested.
  auto output_options = weights.options();
  if (output_dtype == static_cast<int64_t>(SparseType::FP32)) {
    output_options = output_options.dtype(at::kFloat);
  } else if (output_dtype == static_cast<int64_t>(SparseType::FP16)) {
    output_options = output_options.dtype(at::kHalf);
  } else if (output_dtype == static_cast<int64_t>(SparseType::BF16)) {
    output_options = output_options.dtype(at::kBFloat16);
  }
  Tensor output = at::empty({B, total_D}, output_options);

  // indice_weights are expected to be float (or double); Half is rejected.
  TORCH_CHECK(
      !indice_weights.defined() || indice_weights.scalar_type() != at::kHalf);

  AT_DISPATCH_FLOATING_TYPES(
      output.scalar_type(), "split_embedding_cpu_forward", [&] {
        using output_t = scalar_t;
        AT_DISPATCH_FLOATING_TYPES_AND2(
            at::ScalarType::Half,
            at::ScalarType::Byte,
            weights.scalar_type(),
            "split_embedding_cpu_forward",
            [&] {
              using weights_t = scalar_t;
              // double weights pair with double indice_weights; every other
              // weight type reads indice_weights as float.
              using ind_weights_t = std::conditional<
                  std::is_same<weights_t, double>::value,
                  double,
                  float>::type;
              split_embedding_forward_cpu_kernel<
                  weights_t,
                  ind_weights_t,
                  output_t>(
                  weights,
                  weights_offsets,
                  D_offsets,
                  total_D,
                  hash_size_cumsum,
                  indices,
                  offsets,
                  pooling_mode,
                  indice_weights,
                  output);
            });
      });
  return output;
}
// Backward kernel producing the gradient w.r.t. the per-index weights
// (indice_weights) of the CPU split embedding-bag forward pass.
//
// For every table t, sample b and bag element p of that sample, accumulates
//   grad_indice_weights[p] += sum_d grad_output[b][D_offsets[t] + d] *
//                             weights[weights_offsets[t] + indices[p] * D + d]
// where D = D_offsets[t + 1] - D_offsets[t] and offsets has T * B + 1 entries.
// Tables whose feature_requires_grad entry is zero are skipped. Their
// grad_indice_weights entries are left untouched, because this kernel only
// accumulates and never zeroes the output.
//
// Products are formed in at::acc_type<grad_t, true>, the element type of the
// grad_indice_weights accumulator. The gradient used to be cast to the
// accumulate type of weights_t instead, which truncated double gradients to
// float when the weights were float or Half.
//
// NOTE(review): indices are not range-checked here.
template <typename weights_t, typename grad_t>
void split_embedding_grad_indice_weights_cpu_kernel(
    Tensor grad_output,
    Tensor weights,
    Tensor weights_offsets,
    Tensor D_offsets,
    Tensor indices,
    Tensor offsets,
    Tensor feature_requires_grad,
    Tensor grad_indice_weights) {
  int64_t T = D_offsets.numel() - 1;
  TORCH_CHECK(T > 0);
  // offsets = [T x B + 1]
  int64_t B = (offsets.size(0) - 1) / T;
  TORCH_CHECK(B >= 0);
  using acc_t = at::acc_type<grad_t, true>;
  const auto D_offsets_data = D_offsets.accessor<int, 1>();
  const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
  const auto offsets_data = offsets.accessor<int64_t, 1>();
  const auto indices_data = indices.accessor<int64_t, 1>();
  const auto weights_data = weights.accessor<weights_t, 1>();
  const auto grad_output_data = grad_output.accessor<grad_t, 2>();
  auto grad_indice_weights_data = grad_indice_weights.accessor<acc_t, 1>();

  // Evaluate the per-table requires-grad flags once, up front. Previously the
  // tensor was indexed (creating a temporary Tensor) for every table in every
  // parallel chunk.
  std::vector<char> table_requires_grad(static_cast<size_t>(T), 1);
  if (feature_requires_grad.defined()) {
    for (int64_t t = 0; t < T; ++t) {
      table_requires_grad[t] = feature_requires_grad[t].is_nonzero() ? 1 : 0;
    }
  }

  at::parallel_for(0, B, 0, [&](int64_t b_begin, int64_t b_end) {
    for (int64_t t = 0; t < T; ++t) {
      if (!table_requires_grad[t]) {
        // NOTE: skip if the table does not require gradient computation!
        continue;
      }
      const auto D_begin = D_offsets_data[t];
      const auto D = D_offsets_data[t + 1] - D_offsets_data[t];
      const auto table_begin = weights_offsets_data[t];
      for (int64_t b = b_begin; b < b_end; ++b) {
        const auto pool_begin = offsets_data[t * B + b];
        const auto pool_end = offsets_data[t * B + b + 1];
        for (auto p = pool_begin; p < pool_end; ++p) {
          const int64_t embedding_begin = table_begin + indices_data[p] * D;
          for (int64_t d = 0; d < D; ++d) {
            grad_indice_weights_data[p] +=
                static_cast<acc_t>(grad_output_data[b][D_begin + d]) *
                weights_data[embedding_begin + d];
          }
        }
      }
    } // for each t
  }); // parallel for
}
// Host-side entry point for the gradient w.r.t. indice_weights.
//
// Returns a zero-initialized tensor shaped like `indices`. Its dtype is
// at::toAccumulateType(grad_output.scalar_type(), true). The tensor is filled
// by split_embedding_grad_indice_weights_cpu_kernel, dispatched on the
// grad_output dtype and then the weights dtype (float, double or Half each).
Tensor split_embedding_codegen_grad_indice_weights_cpu(
    Tensor grad_output,
    Tensor weights,
    Tensor weights_offsets,
    Tensor D_offsets,
    Tensor indices,
    Tensor offsets,
    Tensor feature_requires_grad) {
  const auto grad_dtype = grad_output.scalar_type();
  const auto accumulate_dtype = at::toAccumulateType(grad_dtype, true);
  Tensor grad_indice_weights =
      zeros_like(indices, indices.options().dtype(accumulate_dtype));

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_dtype, "split_embedding_grad_indice_weights_cpu_outer", [&] {
        using grad_t = scalar_t;
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            weights.scalar_type(),
            "split_embedding_grad_indice_weights_cpu",
            [&] {
              // Inner scalar_t is the weights dtype.
              split_embedding_grad_indice_weights_cpu_kernel<scalar_t, grad_t>(
                  grad_output,
                  weights,
                  weights_offsets,
                  D_offsets,
                  indices,
                  offsets,
                  feature_requires_grad,
                  grad_indice_weights);
            });
      });
  return grad_indice_weights;
}
namespace internal {
// Converts the CSR-style bags of ONE table into the
// BatchedHyperCompressedSparseColumn layout stored in `batched_csc`. It groups
// the bag elements by embedding row: each column segment collects the
// (sample, feature) occurrences of one row.
//
// num_tables is fixed to 1, so t is always 0. The table covers features
// [table_to_feature_offset[0], table_to_feature_offset[1]); more than one
// feature marks it as a shared table.
//
// Parameters:
//   batched_csc             - output. Its buffers are allocated here with
//                             fbgemm::fbgemmAlignedAlloc. The caller presumably
//                             owns and frees them — TODO confirm.
//                             NOTE(review): batched_csc.weights is assumed to
//                             be nullptr on entry, because the path below is
//                             chosen by testing it.
//   B                       - samples per feature.
//   batched_csr_offsets     - bag boundaries, indexed by feature * B + b.
//   batched_csr_indices     - embedding-row index of each bag element.
//   batched_csr_weights     - per-element weights; treated as absent when its
//                             data pointer is null.
//   pooling_mode            - PoolingMode value. MEAN scales each element by
//                             1 / bag length, but only without weights.
//   table_to_feature_offset - feature range of each table.
//   num_embeddings          - forwarded to fbgemm_gpu::radix_sort_parallel
//                             (presumably the key upper bound / row count).
//
// Two paths:
//   * Unweighted, non-MEAN (batched_csc.weights stays null): radix-sorts
//     (row, feature_offset * B + b) pairs by row. Consecutive equal rows form
//     one column segment. row_indices[i] = value % B and
//     column_segment_ids[i] = value / B are written PER ELEMENT i.
//   * Weighted or MEAN: groups elements in a hash map keyed by row, with one
//     segment per (row, feature). column_segment_ids[k] is written PER
//     SEGMENT k. Column order follows the hash map's iteration order.
// NOTE(review): column_segment_ids is indexed per element in the first path
// but per segment in the second — confirm which indexing consumers expect.
template <typename scalar_t>
void batched_csr2csc(
BatchedHyperCompressedSparseColumn& batched_csc,
int B,
// NOTE(review): the following 3 parameters are already TensorAccessors; the
// old "use accessor" TODO here looks stale.
const at::TensorAccessor<int64_t, 1>& batched_csr_offsets,
const at::TensorAccessor<int64_t, 1>& batched_csr_indices,
const at::TensorAccessor<scalar_t, 1>& batched_csr_weights,
int64_t pooling_mode,
const int* table_to_feature_offset,
int64_t num_embeddings) {
int num_tables = 1;
batched_csc.num_tables = num_tables;
batched_csc.table_ptr = static_cast<int*>(
fbgemm::fbgemmAlignedAlloc(64, (num_tables + 1) * sizeof(int)));
batched_csc.table_ptr[0] = 0;
// Total number of bag elements across all features of the table.
int64_t nnz = batched_csr_offsets[table_to_feature_offset[num_tables] * B] -
batched_csr_offsets[table_to_feature_offset[0] * B];
if (nnz == 0) {
// Empty table: zero column segments; no other buffers are allocated.
batched_csc.table_ptr[1] = 0;
return;
}
batched_csc.row_indices =
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(int)));
bool has_weights = batched_csr_weights.data() != nullptr;
// A weights buffer is needed both for real per-element weights and for the
// 1 / L factors of MEAN pooling; its presence selects the path below.
if (has_weights ||
static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN) {
batched_csc.weights = static_cast<float*>(
fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(float)));
}
// Running write position in row_indices / weights.
int column_ptr_curr = 0;
// Only one table is processed, so t is always 0.
int t = 0;
bool is_shared_table =
table_to_feature_offset[t + 1] > table_to_feature_offset[t] + 1;
// Number of elements in table t (equal to nnz since there is one table).
auto NS = batched_csr_offsets[table_to_feature_offset[t + 1] * B] -
batched_csr_offsets[table_to_feature_offset[t] * B];
int num_non_empty_segments = 0;
if (!batched_csc.weights) {
// ---- Unweighted, non-MEAN path: radix sort by embedding row. ----
batched_csc.column_segment_ids =
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(int)));
// Two key/value buffer pairs for the radix sort's ping-pong passes.
int* tmpBufKeys =
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
int* tmpBufValues =
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
int* tmpBuf1Keys =
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
int* tmpBuf1Values =
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
// Offset of the table's first element, used to rebase p to [0, NS).
const auto FBo = batched_csr_offsets[table_to_feature_offset[t] * B];
// key = embedding row, value = (feature offset within table) * B + b.
for (int feature = table_to_feature_offset[t];
feature < table_to_feature_offset[t + 1];
++feature) {
const auto FBs = (feature - table_to_feature_offset[t]) * B;
#pragma omp parallel for
for (int b = 0; b < B; ++b) {
const auto FBb = feature * B + b;
int64_t pool_begin = batched_csr_offsets[FBb];
int64_t pool_end = batched_csr_offsets[FBb + 1];
for (int64_t p = pool_begin; p < pool_end; ++p) {
tmpBufKeys[p - FBo] = batched_csr_indices[p];
tmpBufValues[p - FBo] = FBs + b;
}
}
}
// The sorted result may end up in either buffer pair; the returned pointers
// say which (the swap further below relies on this).
int* sorted_col_row_index_keys = nullptr;
int* sorted_col_row_index_values = nullptr;
std::tie(sorted_col_row_index_keys, sorted_col_row_index_values) =
fbgemm_gpu::radix_sort_parallel(
tmpBufKeys,
tmpBufValues,
tmpBuf1Keys,
tmpBuf1Values,
NS,
num_embeddings);
// num_uniq[tid][0] counts key changes seen by thread tid and is later
// turned into an inclusive prefix sum. The 64-int row stride presumably
// keeps each counter on its own cache line(s) to avoid false sharing.
// NOTE(review): this is a variable-length array, a compiler extension rather
// than standard C++.
int max_thds = omp_get_max_threads();
int num_uniq[max_thds][64];
// U = number of column segments (distinct rows).
int U = 0;
if (at::get_num_threads() > 1) {
// This block is not needed for single thread
// The per-thread counts must cover the same iterations that the static
// schedule gives each thread in the write loop below.
#pragma omp parallel
{
int tid = omp_get_thread_num();
num_uniq[tid][0] = 0;
#pragma omp for schedule(static)
for (int i = 1; i < NS; i++) {
if (sorted_col_row_index_keys[i] !=
sorted_col_row_index_keys[i - 1]) {
num_uniq[tid][0]++;
}
}
}
// +1 for the segment starting at element 0, then an inclusive prefix sum
// gives each thread's starting segment.
num_uniq[0][0] += 1;
for (int i = 1; i < max_thds; i++)
num_uniq[i][0] += num_uniq[i - 1][0];
U = num_uniq[max_thds - 1][0];
}
batched_csc.column_segment_ptr = static_cast<int*>(
fbgemm::fbgemmAlignedAlloc(64, (NS + 1) * sizeof(int)));
batched_csc.column_segment_indices =
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
// Element 0 always starts segment 0.
batched_csc.column_segment_ptr[0] = 0;
batched_csc.row_indices[0] = sorted_col_row_index_values[0] % B;
batched_csc.column_segment_indices[0] = sorted_col_row_index_keys[0];
batched_csc.column_segment_ids[0] = sorted_col_row_index_values[0] / B;
// NOTE(review): the single-thread branch leaves num_uniq uninitialized, and
// threads with tid > 0 read it below. This relies on at::get_num_threads()
// matching the OpenMP team size — confirm.
#pragma omp parallel
{
int tid = omp_get_thread_num();
// Per-thread write cursors into the segment arrays.
int* tstart =
(tid == 0
? batched_csc.column_segment_indices + 1
: batched_csc.column_segment_indices + num_uniq[tid - 1][0]);
int* t_offs =
(tid == 0 ? batched_csc.column_segment_ptr + 1
: batched_csc.column_segment_ptr + num_uniq[tid - 1][0]);
if (!is_shared_table) {
// For non shared table, no need for computing modulo.
// As an optimization, pointer swap instead of copying.
// Only the master thread swaps. The originally allocated row_indices buffer
// ends up in a tmpBuf*Values slot and is freed with it at the end.
// NOTE(review): column_segment_ids[i] for i >= 1 is not written on this
// branch.
#pragma omp master
std::swap(
batched_csc.row_indices,
sorted_col_row_index_values == tmpBufValues ? tmpBufValues
: tmpBuf1Values);
} else {
// Shared table: split value into feature offset (q) and sample (v % B).
#ifdef FBCODE_CAFFE2
libdivide::divider<int> divisor(B);
#endif
#pragma omp for schedule(static)
for (int i = 1; i < NS; ++i) {
int v = sorted_col_row_index_values[i];
#ifdef FBCODE_CAFFE2
int q = v / divisor;
#else
int q = v / B;
#endif
batched_csc.column_segment_ids[i] = q;
batched_csc.row_indices[i] = v - q * B;
}
}
// Record the start of each new segment, i.e. each position where the sorted
// key changes.
#pragma omp for schedule(static)
for (int i = 1; i < NS; ++i) {
if (sorted_col_row_index_keys[i] != sorted_col_row_index_keys[i - 1]) {
*tstart = sorted_col_row_index_keys[i];
*t_offs = i;
tstart++;
t_offs++;
}
}
if (at::get_num_threads() == 1 && tid == 0) {
// Special handling of single thread case
U = t_offs - batched_csc.column_segment_ptr;
}
} // omp parallel
batched_csc.table_ptr[t + 1] = batched_csc.table_ptr[t] + U;
batched_csc.column_segment_ptr[U] = NS;
column_ptr_curr += NS;
fbgemm::fbgemmAlignedFree(tmpBufKeys);
fbgemm::fbgemmAlignedFree(tmpBufValues);
fbgemm::fbgemmAlignedFree(tmpBuf1Keys);
fbgemm::fbgemmAlignedFree(tmpBuf1Values);
} else {
// ---- Weighted or MEAN path (batched_csc.weights was allocated). ----
// Map: embedding row -> per-feature list of (sample b, scaled weight).
#ifdef FBCODE_CAFFE2
folly::F14FastMap<
#else
std::unordered_map<
#endif
int64_t,
std::vector<std::vector<std::pair<int, scalar_t>>>>
non_empty_columns;
int f_begin = table_to_feature_offset[t];
int f_end = table_to_feature_offset[t + 1];
for (int feature = f_begin; feature < f_end; ++feature) {
for (int b = 0; b < B; ++b) {
int64_t pool_begin = batched_csr_offsets[feature * B + b];
int64_t pool_end = batched_csr_offsets[feature * B + b + 1];
int64_t L = pool_end - pool_begin;
// MEAN pooling will not work with indice_weights!
double scale_factor =
(static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN &&
!has_weights && L > 0)
? 1.0 / L
: 1.0;
for (int64_t p = pool_begin; p < pool_end; ++p) {
auto itr = non_empty_columns.find(batched_csr_indices[p]);
if (itr == non_empty_columns.end()) {
itr = non_empty_columns
.emplace(
batched_csr_indices[p],
std::vector<std::vector<std::pair<int, scalar_t>>>(
f_end - f_begin))
.first;
}
// Each first hit of a (row, feature) pair opens a new column segment.
if (itr->second[feature - f_begin].empty()) {
++num_non_empty_segments;
}
itr->second[feature - f_begin].emplace_back(
b, scale_factor * (has_weights ? batched_csr_weights[p] : 1.0f));
}
}
} // for each feature
batched_csc.table_ptr[t + 1] =
batched_csc.table_ptr[t] + num_non_empty_segments;
batched_csc.column_segment_ptr = static_cast<int*>(
fbgemm::fbgemmAlignedAlloc(64, (NS + 1) * sizeof(int)));
batched_csc.column_segment_ptr[0] = 0;
batched_csc.column_segment_indices =
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
batched_csc.column_segment_ids =
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
// Flatten the map: segment k - 1 covers
// [column_segment_ptr[k - 1], column_segment_ptr[k]).
int k = 1;
for (auto const& column : non_empty_columns) {
int feature = f_begin;
for (auto const& column_segment : column.second) {
if (!column_segment.empty()) {
batched_csc.column_segment_ptr[k] =
column_ptr_curr + column_segment.size();
batched_csc.column_segment_indices[k - 1] = column.first;
batched_csc.column_segment_ids[k - 1] = feature - f_begin;
k++;
for (auto const& non_zero : column_segment) {
batched_csc.row_indices[column_ptr_curr] = non_zero.first;
batched_csc.weights[column_ptr_curr] = non_zero.second;
++column_ptr_curr;
}
}
++feature;
} // for each column segment
} // for each column
} // end of weighted / MEAN path (batched_csc.weights != nullptr)
// Sanity check: every element was emitted exactly once (debug builds only).
assert(column_ptr_curr == nnz);
}
// Explicit instantiations for float and double weights.
template void batched_csr2csc<float>(
BatchedHyperCompressedSparseColumn& batched_csc,
int B,
const at::TensorAccessor<int64_t, 1>& batched_csr_offsets,
const at::TensorAccessor<int64_t, 1>& batched_csr_indices,
const at::TensorAccessor<float, 1>& batched_csr_weights,
int64_t pooling_mode,
const int* table_to_feature_offset,
int64_t num_embeddings);
template void batched_csr2csc<double>(
BatchedHyperCompressedSparseColumn& batched_csc,
int B,
const at::TensorAccessor<int64_t, 1>& batched_csr_offsets,
const at::TensorAccessor<int64_t, 1>& batched_csr_indices,
const at::TensorAccessor<double, 1>& batched_csr_weights,
int64_t pooling_mode,
const int* table_to_feature_offset,
int64_t num_embeddings);
} // namespace internal