Skip to content

Commit

Permalink
[CPU][Sampling][Performance] Improve sampling on the CPU. (dmlc#3274)
Browse files: browse the repository at this point in the history
* Optimize sampling

* Stop initializing the output arrays (allocate them uninitialized instead of filling with -1)

* Fix includes for linting

* Move comment

* Fix handling of the `replace` (sampling with replacement) option

Co-authored-by: Da Zheng <[email protected]>
  • Loading branch information
nv-dlasalle and zheng-da authored Aug 31, 2021
1 parent a53783c commit 8e525da
Showing 1 changed file with 83 additions and 50 deletions.
133 changes: 83 additions & 50 deletions src/array/cpu/rowwise_pick.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
#define DGL_ARRAY_CPU_ROWWISE_PICK_H_

#include <dgl/array.h>
#include <omp.h>
#include <functional>
#include <algorithm>
#include <string>
#include <vector>
#include <memory>

namespace dgl {
namespace aten {
Expand Down Expand Up @@ -92,68 +94,99 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
//
// [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism is more
// significant. (minjie)
IdArray picked_row = Full(-1, num_rows * num_picks, sizeof(IdxType) * 8, ctx);
IdArray picked_col = Full(-1, num_rows * num_picks, sizeof(IdxType) * 8, ctx);
IdArray picked_idx = Full(-1, num_rows * num_picks, sizeof(IdxType) * 8, ctx);
IdArray picked_row = NDArray::Empty({num_rows * num_picks},
DLDataType{kDLInt, 8*sizeof(IdxType), 1},
ctx);
IdArray picked_col = NDArray::Empty({num_rows * num_picks},
DLDataType{kDLInt, 8*sizeof(IdxType), 1},
ctx);
IdArray picked_idx = NDArray::Empty({num_rows * num_picks},
DLDataType{kDLInt, 8*sizeof(IdxType), 1},
ctx);
IdxType* picked_rdata = static_cast<IdxType*>(picked_row->data);
IdxType* picked_cdata = static_cast<IdxType*>(picked_col->data);
IdxType* picked_idata = static_cast<IdxType*>(picked_idx->data);

bool all_has_fanout = true;
#pragma omp parallel for reduction(&&:all_has_fanout)
for (int64_t i = 0; i < num_rows; ++i) {
const IdxType rid = rows_data[i];
const IdxType len = indptr[rid + 1] - indptr[rid];
// If a node has no neighbor then all_has_fanout must be false even if replace is
// true.
all_has_fanout = all_has_fanout && (len >= (replace ? 1 : num_picks));
}
const int num_threads = omp_get_max_threads();
std::vector<int64_t> global_prefix(num_threads+1, 0);

#pragma omp parallel for
for (int64_t i = 0; i < num_rows; ++i) {
const IdxType rid = rows_data[i];
CHECK_LT(rid, mat.num_rows);
const IdxType off = indptr[rid];
const IdxType len = indptr[rid + 1] - off;
if (len == 0)
continue;
#pragma omp parallel num_threads(num_threads)
{
const int thread_id = omp_get_thread_num();

if (len <= num_picks && !replace) {
// nnz <= num_picks and w/o replacement, take all nnz
for (int64_t j = 0; j < len; ++j) {
picked_rdata[i * num_picks + j] = rid;
picked_cdata[i * num_picks + j] = indices[off + j];
picked_idata[i * num_picks + j] = data? data[off + j] : off + j;
const int64_t start_i = thread_id * (num_rows/num_threads) +
std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);
const int64_t end_i = (thread_id + 1) * (num_rows/num_threads) +
std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);
assert(thread_id + 1 < num_threads || end_i == num_rows);

const int64_t num_local = end_i - start_i;

// make sure we don't have to pay initialization cost
std::unique_ptr<int64_t[]> local_prefix(new int64_t[num_local + 1]);
local_prefix[0] = 0;
for (int64_t i = start_i; i < end_i; ++i) {
// build prefix-sum
const int64_t local_i = i-start_i;
const IdxType rid = rows_data[i];
IdxType len;
if (replace) {
len = indptr[rid+1] == indptr[rid] ? 0 : num_picks;
} else {
len = std::min(
static_cast<IdxType>(num_picks), indptr[rid + 1] - indptr[rid]);
}
} else {
pick_fn(rid, off, len,
indices, data,
picked_idata + i * num_picks);
for (int64_t j = 0; j < num_picks; ++j) {
const IdxType picked = picked_idata[i * num_picks + j];
picked_rdata[i * num_picks + j] = rid;
picked_cdata[i * num_picks + j] = indices[picked];
picked_idata[i * num_picks + j] = data? data[picked] : picked;
local_prefix[local_i + 1] = local_prefix[local_i] + len;
}
global_prefix[thread_id + 1] = local_prefix[num_local];

#pragma omp barrier
#pragma omp master
{
for (int t = 0; t < num_threads; ++t) {
global_prefix[t+1] += global_prefix[t];
}
}
}
#pragma omp barrier
const IdxType thread_offset = global_prefix[thread_id];

for (int64_t i = start_i; i < end_i; ++i) {
const IdxType rid = rows_data[i];

const IdxType off = indptr[rid];
const IdxType len = indptr[rid + 1] - off;
if (len == 0)
continue;

if (!all_has_fanout) {
// correct the array by remove_if
IdxType* new_row_end = std::remove_if(picked_rdata, picked_rdata + num_rows * num_picks,
[] (IdxType i) { return i == -1; });
IdxType* new_col_end = std::remove_if(picked_cdata, picked_cdata + num_rows * num_picks,
[] (IdxType i) { return i == -1; });
IdxType* new_idx_end = std::remove_if(picked_idata, picked_idata + num_rows * num_picks,
[] (IdxType i) { return i == -1; });
const int64_t new_len = (new_row_end - picked_rdata);
CHECK_EQ(new_col_end - picked_cdata, new_len);
CHECK_EQ(new_idx_end - picked_idata, new_len);
picked_row = picked_row.CreateView({new_len}, picked_row->dtype);
picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);
const int64_t local_i = i - start_i;
const int64_t row_offset = thread_offset + local_prefix[local_i];

if (len <= num_picks && !replace) {
// nnz <= num_picks and w/o replacement, take all nnz
for (int64_t j = 0; j < len; ++j) {
picked_rdata[row_offset + j] = rid;
picked_cdata[row_offset + j] = indices[off + j];
picked_idata[row_offset + j] = data? data[off + j] : off + j;
}
} else {
pick_fn(rid, off, len,
indices, data,
picked_idata + row_offset);
for (int64_t j = 0; j < num_picks; ++j) {
const IdxType picked = picked_idata[row_offset + j];
picked_rdata[row_offset + j] = rid;
picked_cdata[row_offset + j] = indices[picked];
picked_idata[row_offset + j] = data? data[picked] : picked;
}
}
}
}

const int64_t new_len = global_prefix.back();
picked_row = picked_row.CreateView({new_len}, picked_row->dtype);
picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);

return COOMatrix(mat.num_rows, mat.num_cols,
picked_row, picked_col, picked_idx);
}
Expand Down

0 comments on commit 8e525da

Please sign in to comment.