Skip to content

Commit

Permalink
[Kernel] Matrix Union (dmlc#1752)
Browse files Browse the repository at this point in the history
* Matrix union

* Pass test

* Fix lint

* return map for unionCOO/unionCSR

* Revert "return map for unionCOO/unionCSR"

This reverts commit 28e96c4.

* Update

* lint

* lint

* Fix doc

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
classicsong and Ubuntu authored Jul 9, 2020
1 parent 29e6c93 commit 27cad32
Show file tree
Hide file tree
Showing 6 changed files with 715 additions and 9 deletions.
34 changes: 33 additions & 1 deletion include/dgl/aten/coo.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,39 @@ COOMatrix COORowWiseTopk(
bool ascending = false);

/*!
* \brief Union a list of COOMatrix into one COOMatrix.
*
* At least two matrices must be given and all of them must have the same shape.
* The entries of the inputs are concatenated in input order; an entry that is
* present in several inputs is kept once per input (it is not merged). The
* matrices in the example below therefore show the number of entries stored at
* each position.
*
* If at least one input carries an explicit data (edge id) array, the ids of
* each input are shifted by the total number of entries of the inputs that
* precede it; otherwise the result has no data array.
*
* Example:
*
* A = [[0, 0, 1, 0],
* [1, 0, 1, 1],
* [0, 1, 0, 0]]
*
* B = [[0, 1, 1, 0],
* [0, 0, 0, 1],
* [0, 0, 1, 0]]
*
* COOMatrix_A.num_rows : 3
* COOMatrix_A.num_cols : 4
* COOMatrix_B.num_rows : 3
* COOMatrix_B.num_cols : 4
*
* C = UnionCoo({A, B});
*
* C = [[0, 1, 2, 0],
* [1, 0, 1, 2],
* [0, 1, 1, 0]]
*
* COOMatrix_C.num_rows : 3
* COOMatrix_C.num_cols : 4
*
* \param coos The COO matrices to union.
* \return The unioned COOMatrix, with the same shape as the inputs.
*/
COOMatrix UnionCoo(
const std::vector<COOMatrix>& coos);

/*!
* \brief DisjointUnion a list COOMatrix into one COOMatrix.
*
* Examples:
*
Expand Down
32 changes: 32 additions & 0 deletions include/dgl/aten/csr.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,38 @@ COOMatrix CSRRowWiseTopk(
FloatArray weight,
bool ascending = false);

/*!
* \brief Union a list of CSRMatrix into one CSRMatrix.
*
* At least two matrices must be given and all of them must have the same shape.
* Row i of the result holds the entries of row i of every input; an entry that
* is present in several inputs is kept once per input (it is not merged). The
* matrices in the example below therefore show the number of entries stored at
* each position.
*
* The data (edge id) values of each input are shifted by the total number of
* entries of the inputs that precede it. The result is flagged as sorted only
* when every input is flagged as sorted.
*
* Example:
*
* A = [[0, 0, 1, 0],
* [1, 0, 1, 1],
* [0, 1, 0, 0]]
*
* B = [[0, 1, 1, 0],
* [0, 0, 0, 1],
* [0, 0, 1, 0]]
*
* CSRMatrix_A.num_rows : 3
* CSRMatrix_A.num_cols : 4
* CSRMatrix_B.num_rows : 3
* CSRMatrix_B.num_cols : 4
*
* C = UnionCsr({A, B});
*
* C = [[0, 1, 2, 0],
* [1, 0, 1, 2],
* [0, 1, 1, 0]]
*
* CSRMatrix_C.num_rows : 3
* CSRMatrix_C.num_cols : 4
*
* \param csrs The CSR matrices to union.
* \return The unioned CSRMatrix, with the same shape as the inputs.
*/
CSRMatrix UnionCsr(
const std::vector<CSRMatrix>& csrs);

/*!
* \brief Union a list CSRMatrix into one CSRMatrix.
*
Expand Down
99 changes: 91 additions & 8 deletions src/array/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,14 +465,6 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArr
return ret;
}

// Public entry point for COOReorder.
// Forwards `coo`, `new_row_ids` and `new_col_ids` unchanged to
// impl::COOReorder<XPU, IdType> inside ATEN_COO_SWITCH and returns its result.
// NOTE(review): XPU/IdType are presumably chosen by the macro from `coo`'s
// device and index type -- confirm against the macro definition.
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, "COOReorder", {
ret = impl::COOReorder<XPU, IdType>(coo, new_row_ids, new_col_ids);
});
return ret;
}

CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRRemove", {
Expand Down Expand Up @@ -509,6 +501,27 @@ COOMatrix CSRRowWiseTopk(
return ret;
}


// Public entry point for the union of a list of CSR matrices.
//
// Validates the inputs and then forwards the whole list to
// impl::UnionCsr<XPU, IdType> inside ATEN_CSR_SWITCH, keyed on csrs[0].
// Fails a CHECK when:
//   - fewer than two matrices are given;
//   - any matrix differs from csrs[0] in num_rows or num_cols;
//   - any matrix's indptr differs from csrs[0].indptr in context or dtype.
// Only `indptr` is compared; `indices` and `data` are not validated here.
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
CSRMatrix ret;
CHECK_GT(csrs.size(), 1) << "UnionCsr creates a union of multiple CSRMatrixes";
// sanity check: every matrix is compared against the first one
for (size_t i = 1; i < csrs.size(); ++i) {
CHECK_EQ(csrs[0].num_rows, csrs[i].num_rows) <<
"UnionCsr requires both CSRMatrix have same number of rows";
CHECK_EQ(csrs[0].num_cols, csrs[i].num_cols) <<
"UnionCsr requires both CSRMatrix have same number of cols";
CHECK_SAME_CONTEXT(csrs[0].indptr, csrs[i].indptr);
CHECK_SAME_DTYPE(csrs[0].indptr, csrs[i].indptr);
}

// NOTE(review): XPU/IdType are presumably chosen by the macro from csrs[0]'s
// device and index type -- confirm against the macro definition.
ATEN_CSR_SWITCH(csrs[0], XPU, IdType, "UnionCsr", {
ret = impl::UnionCsr<XPU, IdType>(csrs);
});
return ret;
}


std::tuple<CSRMatrix, IdArray, IdArray>
CSRToSimple(const CSRMatrix& csr) {
std::tuple<CSRMatrix, IdArray, IdArray> ret;
Expand Down Expand Up @@ -645,6 +658,14 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
return ret;
}

// Public entry point for COOReorder.
// Forwards `coo`, `new_row_ids` and `new_col_ids` unchanged to
// impl::COOReorder<XPU, IdType> inside ATEN_COO_SWITCH and returns its result.
// NOTE(review): XPU/IdType are presumably chosen by the macro from `coo`'s
// device and index type -- confirm against the macro definition.
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, "COOReorder", {
ret = impl::COOReorder<XPU, IdType>(coo, new_row_ids, new_col_ids);
});
return ret;
}

COOMatrix COORemove(COOMatrix coo, IdArray entries) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, "COORemove", {
Expand Down Expand Up @@ -689,6 +710,68 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
return ret;
}


/*!
 * \brief Union a list of COO matrices by concatenating their entries.
 *
 * The row and column arrays of the inputs are concatenated in input order.
 * If at least one input has an explicit data (edge id) array, a data array is
 * built for the result in which the ids of every input after the first are
 * shifted by the number of entries that precede it; inputs without data get a
 * consecutive id range starting at that same offset. If no input has data,
 * the result has none either.
 *
 * \param coos The matrices to union; more than one, all of the same shape,
 *             with row arrays on the same context and of the same dtype.
 * \return A COOMatrix with the shape of the inputs, flagged as neither
 *         row-sorted nor column-sorted.
 */
COOMatrix UnionCoo(const std::vector<COOMatrix>& coos) {
  CHECK_GT(coos.size(), 1) << "UnionCoo creates a union of multiple COOMatrixes";
  // Every matrix has to agree with the first one on shape, device and id type.
  for (size_t i = 1; i < coos.size(); ++i) {
    CHECK_EQ(coos[0].num_rows, coos[i].num_rows) <<
      "UnionCoo requires both COOMatrix have same number of rows";
    CHECK_EQ(coos[0].num_cols, coos[i].num_cols) <<
      "UnionCoo requires both COOMatrix have same number of cols";
    CHECK_SAME_CONTEXT(coos[0].row, coos[i].row);
    CHECK_SAME_DTYPE(coos[0].row, coos[i].row);
  }

  // The list of matrices is expected to be short, so plain vectors suffice.
  std::vector<IdArray> row_parts;
  std::vector<IdArray> col_parts;
  bool any_has_data = false;
  for (const COOMatrix& coo : coos) {
    row_parts.push_back(coo.row);
    col_parts.push_back(coo.col);
    any_has_data |= COOHasData(coo);
  }

  IdArray row = Concat(row_parts);
  IdArray col = Concat(col_parts);
  IdArray data = NullArray();

  if (any_has_data) {
    std::vector<IdArray> eid_parts;
    int64_t edge_offset = 0;  // number of entries in the matrices handled so far
    for (size_t i = 0; i < coos.size(); ++i) {
      const int64_t nnz = coos[i].row->shape[0];
      if (!COOHasData(coos[i])) {
        // No explicit ids: use the consecutive range following the offset.
        eid_parts.push_back(Range(edge_offset,
                                  edge_offset + nnz,
                                  coos[i].row->dtype.bits,
                                  coos[i].row->ctx));
      } else if (i == 0) {
        // The ids of the first matrix are taken over unchanged.
        eid_parts.push_back(coos[i].data);
      } else {
        eid_parts.push_back(coos[i].data + edge_offset);
      }
      edge_offset += nnz;
    }
    data = Concat(eid_parts);
  }

  return COOMatrix(coos[0].num_rows, coos[0].num_cols, row, col, data, false, false);
}


std::tuple<COOMatrix, IdArray, IdArray>
COOToSimple(const COOMatrix& coo) {
// coo column sorted
Expand Down
6 changes: 6 additions & 0 deletions src/array/array_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ template <DLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);

// Union a list of CSRMatrix into one CSRMatrix.
// Device- and index-type-specific implementation that the public
// aten::UnionCsr dispatches to after validating its inputs.
template <DLDeviceType XPU, typename IdType>
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);

template <DLDeviceType XPU, typename IdType>
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr);

Expand Down Expand Up @@ -224,6 +228,8 @@ template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);

///////////////////////// Graph Traverse routines //////////////////////////

template <DLDeviceType XPU, typename IdType>
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);

Expand Down
117 changes: 117 additions & 0 deletions src/array/cpu/csr_union.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/csr_union.cc
* \brief CSR union
*/
#include <dgl/array.h>

#include <numeric>
#include <algorithm>
#include <vector>
#include <iterator>

namespace dgl {
namespace aten {
namespace impl {

/*!
 * \brief CPU implementation of the union of a list of CSR matrices.
 *
 * Row i of the result contains the entries of row i of every input. The data
 * (edge id) values of input k are shifted by the total number of entries of
 * inputs 0..k-1; an input without an explicit data array is treated as having
 * consecutive ids starting at that offset.
 *
 * When every input is flagged as sorted, the rows are merged so that column
 * indices are non-decreasing (ties are resolved in favour of the input that
 * comes first in \a csrs). Otherwise the rows of the inputs are simply
 * concatenated in input order.
 *
 * \param csrs The matrices to union. The caller is expected to have verified
 *             that the list is non-empty and that all matrices share shape,
 *             context and index dtype.
 * \return The unioned CSRMatrix; flagged as sorted iff all inputs are.
 */
template <DLDeviceType XPU, typename IdType>
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
  const size_t num_csrs = csrs.size();
  const int64_t num_rows = csrs[0].num_rows;

  // some preprocess
  // we assume the number of csrs is not large in common cases
  std::vector<IdArray> data;
  std::vector<IdType *> data_data;
  std::vector<IdType *> indptr_data;
  std::vector<IdType *> indices_data;
  data.reserve(num_csrs);
  data_data.reserve(num_csrs);
  indptr_data.reserve(num_csrs);
  indices_data.reserve(num_csrs);

  int64_t num_edges = 0;
  bool sorted = true;
  for (size_t i = 0; i < num_csrs; ++i) {
    // eids of csrs[0] remains unchanged
    // eids of csrs[1] will be increased by number of edges of csrs[0], etc.
    data.push_back(CSRHasData(csrs[i]) ?
                   csrs[i].data + num_edges :
                   Range(num_edges,
                         num_edges + csrs[i].indices->shape[0],
                         csrs[i].indptr->dtype.bits,
                         csrs[i].indptr->ctx));
    data_data.push_back(data[i].Ptr<IdType>());
    indptr_data.push_back(csrs[i].indptr.Ptr<IdType>());
    indices_data.push_back(csrs[i].indices.Ptr<IdType>());
    num_edges += csrs[i].indices->shape[0];
    sorted &= csrs[i].sorted;
  }

  std::vector<IdType> res_indptr(num_rows + 1, 0);
  std::vector<IdType> res_indices(num_edges);
  std::vector<IdType> res_data(num_edges);

  // NOTE: both loops below run serially on purpose. Every row starts writing
  // at res_indptr[i-1], which is produced by the previous iteration, so the
  // iterations must not be distributed over threads.
  if (sorted) {  // all csrs are sorted
    // Per-input read cursor into the current row; reused across rows.
    std::vector<int64_t> indices_off(num_csrs);
    for (int64_t i = 1; i <= num_rows; ++i) {
      IdType row_end = 0;
      for (size_t j = 0; j < num_csrs; ++j) {
        row_end += indptr_data[j][i];
        indices_off[j] = indptr_data[j][i-1];
      }
      res_indptr[i] = row_end;

      // k-way merge of row i of all inputs by column index.
      IdType off = res_indptr[i-1];
      while (off < res_indptr[i]) {
        // Sentinel larger than any valid column index.
        IdType min = csrs[0].num_cols + 1;
        int64_t min_idx = -1;
        for (size_t j = 0; j < num_csrs; ++j) {
          // Skip inputs whose row i is exhausted; a strict comparison keeps
          // the earliest input on ties.
          if (indices_off[j] < indptr_data[j][i] &&
              indices_data[j][indices_off[j]] < min) {
            min = indices_data[j][indices_off[j]];
            min_idx = j;
          }
        }

        res_indices[off] = min;
        res_data[off] = data_data[min_idx][indices_off[min_idx]];
        indices_off[min_idx] += 1;
        ++off;
      }  // while
    }  // for each row
  } else {  // some csrs are not sorted
    for (int64_t i = 1; i <= num_rows; ++i) {
      IdType off = res_indptr[i-1];

      for (size_t j = 0; j < num_csrs; ++j) {
        const IdType row_begin = indptr_data[j][i-1];
        const IdType row_len = indptr_data[j][i] - row_begin;
        // Pointer arithmetic on data() stays valid when `off` equals the
        // vector size (empty trailing rows / empty inputs), unlike
        // operator[] on a one-past-the-end index.
        std::copy_n(indices_data[j] + row_begin, row_len, res_indices.data() + off);
        std::copy_n(data_data[j] + row_begin, row_len, res_data.data() + off);
        off += row_len;
      }
      res_indptr[i] = off;
    }  // for each row
  }

  return CSRMatrix(
    csrs[0].num_rows,
    csrs[0].num_cols,
    IdArray::FromVector(res_indptr),
    IdArray::FromVector(res_indices),
    IdArray::FromVector(res_data),
    sorted);
}

// Explicit instantiations of the CPU implementation for 64-bit and 32-bit
// index types.
template CSRMatrix UnionCsr<kDLCPU, int64_t>(const std::vector<CSRMatrix>&);
template CSRMatrix UnionCsr<kDLCPU, int32_t>(const std::vector<CSRMatrix>&);

} // namespace impl
} // namespace aten
} // namespace dgl
Loading

0 comments on commit 27cad32

Please sign in to comment.