Skip to content

Commit

Permalink
[Feature] Sparse-sparse matrix multiplication, addition, and masking (d…
Browse files Browse the repository at this point in the history
…mlc#2753)

* test

* more stuff

* add test

* fixes

* optimize algo

* replace unordered_map with arrays

* lint

* lint x2

* oops

* disable gpu csrmm tests

* remove gpu invocation

* optimize with openmp

* remove python functions

* add back with docstrings

* lint

* lint

* update python interface

* functionize

* functionize

* lint

* lint
  • Loading branch information
BarclayII authored Mar 24, 2021
1 parent d04d59e commit 929d863
Show file tree
Hide file tree
Showing 19 changed files with 989 additions and 104 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ if(BUILD_CPP_TEST)
include_directories("third_party/dlpack/include")
include_directories("third_party/xbyak")
include_directories("third_party/dmlc-core/include")
include_directories("third_party/phmap")
file(GLOB_RECURSE TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/tests/cpp/*.cc)
add_executable(runUnitTests ${TEST_SRC_FILES})
target_link_libraries(runUnitTests gtest gtest_main)
Expand Down
30 changes: 15 additions & 15 deletions include/dgl/aten/spmat.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ enum class SparseFormat {
/*!
* \brief Sparse format codes
*/
const dgl_format_code_t all_code = 0x7;
const dgl_format_code_t coo_code = 0x1;
const dgl_format_code_t csr_code = 0x2;
const dgl_format_code_t csc_code = 0x4;
const dgl_format_code_t ALL_CODE = 0x7;
const dgl_format_code_t COO_CODE = 0x1;
const dgl_format_code_t CSR_CODE = 0x2;
const dgl_format_code_t CSC_CODE = 0x4;

// Parse sparse format from string.
inline SparseFormat ParseSparseFormat(const std::string& name) {
Expand All @@ -55,11 +55,11 @@ inline std::string ToStringSparseFormat(SparseFormat sparse_format) {

inline std::vector<SparseFormat> CodeToSparseFormats(dgl_format_code_t code) {
std::vector<SparseFormat> ret;
if (code & coo_code)
if (code & COO_CODE)
ret.push_back(SparseFormat::kCOO);
if (code & csr_code)
if (code & CSR_CODE)
ret.push_back(SparseFormat::kCSR);
if (code & csc_code)
if (code & CSC_CODE)
ret.push_back(SparseFormat::kCSC);
return ret;
}
Expand All @@ -70,13 +70,13 @@ SparseFormatsToCode(const std::vector<SparseFormat> &formats) {
for (auto format : formats) {
switch (format) {
case SparseFormat::kCOO:
ret |= coo_code;
ret |= COO_CODE;
break;
case SparseFormat::kCSR:
ret |= csr_code;
ret |= CSR_CODE;
break;
case SparseFormat::kCSC:
ret |= csc_code;
ret |= CSC_CODE;
break;
default:
LOG(FATAL) << "Only support COO/CSR/CSC formats.";
Expand All @@ -87,19 +87,19 @@ SparseFormatsToCode(const std::vector<SparseFormat> &formats) {

inline std::string CodeToStr(dgl_format_code_t code) {
std::string ret = "";
if (code & coo_code)
if (code & COO_CODE)
ret += "coo ";
if (code & csr_code)
if (code & CSR_CODE)
ret += "csr ";
if (code & csc_code)
if (code & CSC_CODE)
ret += "csc ";
return ret;
}

inline SparseFormat DecodeFormat(dgl_format_code_t code) {
if (code & coo_code)
if (code & COO_CODE)
return SparseFormat::kCOO;
if (code & csc_code)
if (code & CSC_CODE)
return SparseFormat::kCSC;
return SparseFormat::kCSR;
}
Expand Down
18 changes: 9 additions & 9 deletions include/dgl/base_heterograph.h
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ HeteroGraphPtr CreateHeteroGraph(
*/
HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, dgl_format_code_t formats = all_code);
IdArray row, IdArray col, dgl_format_code_t formats = ALL_CODE);

/*!
* \brief Create a heterograph from COO input.
Expand All @@ -620,7 +620,7 @@ HeteroGraphPtr CreateFromCOO(
*/
HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat,
dgl_format_code_t formats = all_code);
dgl_format_code_t formats = ALL_CODE);

/*!
* \brief Create a heterograph from CSR input.
Expand All @@ -636,7 +636,7 @@ HeteroGraphPtr CreateFromCOO(
HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
dgl_format_code_t formats = all_code);
dgl_format_code_t formats = ALL_CODE);

/*!
* \brief Create a heterograph from CSR input.
Expand All @@ -647,7 +647,7 @@ HeteroGraphPtr CreateFromCSR(
*/
HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = all_code);
dgl_format_code_t formats = ALL_CODE);

/*!
* \brief Create a heterograph from CSC input.
Expand All @@ -663,7 +663,7 @@ HeteroGraphPtr CreateFromCSR(
HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
dgl_format_code_t formats = all_code);
dgl_format_code_t formats = ALL_CODE);

/*!
* \brief Create a heterograph from CSC input.
Expand All @@ -674,7 +674,7 @@ HeteroGraphPtr CreateFromCSC(
*/
HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = all_code);
dgl_format_code_t formats = ALL_CODE);

/*!
* \brief Extract the subgraph of the in edges of the given nodes.
Expand Down Expand Up @@ -830,13 +830,13 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph);
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states);

#define FORMAT_HAS_CSC(format) \
((format) & csc_code)
((format) & CSC_CODE)

#define FORMAT_HAS_CSR(format) \
((format) & csr_code)
((format) & CSR_CODE)

#define FORMAT_HAS_COO(format) \
((format) & coo_code)
((format) & COO_CODE)

} // namespace dgl

Expand Down
24 changes: 24 additions & 0 deletions include/dgl/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <string>
#include <vector>
#include <utility>

#include "array.h"
#include "./bcast.h"
Expand Down Expand Up @@ -51,6 +52,29 @@ void SDDMM(const std::string& op,
NDArray efeat,
NDArray out);

/*!
* \brief Sparse-sparse matrix multiplication.
*
* \note B is transposed (i.e. in CSC format).
*/
std::pair<CSRMatrix, NDArray> CSRMM(
CSRMatrix A,
NDArray A_weights,
CSRMatrix B,
NDArray B_weights);

/*!
* \brief Sparse-sparse matrix summation.
*/
std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A,
const std::vector<NDArray>& A_weights);

/*!
* \brief Return a sparse matrix with the values of A but nonzero entry locations of B.
*/
NDArray CSRMask(const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B);

} // namespace aten
} // namespace dgl

Expand Down
121 changes: 120 additions & 1 deletion python/dgl/sparse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Module for sparse matrix operators."""
# pylint: disable= invalid-name
from __future__ import absolute_import
import dgl.ndarray as nd
from . import ndarray as nd
from ._ffi.function import _init_api
from .base import DGLError
from . import backend as F
Expand Down Expand Up @@ -366,5 +366,124 @@ def _bwd_segment_cmp(feat, arg, m):
to_dgl_nd_for_write(out))
return out

class CSRMatrix(object):
"""Device- and backend-agnostic sparse matrix in CSR format.
Parameters
----------
data : Tensor
The data array.
indices : Tensor
The column indices array.
indptr : Tensor
The row index pointer array.
num_rows : int
The number of rows.
num_cols : int
The number of columns.
"""
def __init__(self, data, indices, indptr, num_rows, num_cols):
self.indptr = indptr
self.indices = indices
self.data = data
self.shape = (num_rows, num_cols)

def csrmm(A, B):
"""Sparse-sparse matrix multiplication.
This is an internal function whose interface is subject to changes.
Parameters
----------
A : dgl.sparse.CSRMatrix
The left operand
B : dgl.sparse.CSRMatrix
The right operand
Returns
-------
dgl.sparse.CSRMatrix
The result
"""
A_indptr = F.zerocopy_from_numpy(A.indptr)
A_indices = F.zerocopy_from_numpy(A.indices)
A_data = F.zerocopy_from_numpy(A.data)
B_indptr = F.zerocopy_from_numpy(B.indptr)
B_indices = F.zerocopy_from_numpy(B.indices)
B_data = F.zerocopy_from_numpy(B.data)
C_indptr, C_indices, C_data = _CAPI_DGLCSRMM(
A.shape[0], A.shape[1], B.shape[1],
F.to_dgl_nd(A_indptr),
F.to_dgl_nd(A_indices),
F.to_dgl_nd(A_data),
F.to_dgl_nd(B_indptr),
F.to_dgl_nd(B_indices),
F.to_dgl_nd(B_data))
return CSRMatrix(
F.from_dgl_nd(C_data),
F.from_dgl_nd(C_indices),
F.from_dgl_nd(C_indptr),
A.shape[0],
B.shape[1])

def csrsum(As):
"""Sparse-sparse matrix summation.
This is an internal function whose interface is subject to changes.
Parameters
----------
As : List[dgl.sparse.CSRMatrix]
List of scipy sparse matrices in CSR format.
Returns
-------
dgl.sparse.CSRMatrix
The result
"""
A_indptr = [F.zerocopy_from_numpy(x.indptr) for x in As]
A_indices = [F.zerocopy_from_numpy(x.indices) for x in As]
A_data = [F.zerocopy_from_numpy(x.data) for x in As]
C_indptr, C_indices, C_data = _CAPI_DGLCSRSum(
As[0].shape[0], As[0].shape[1],
[F.to_dgl_nd(x) for x in A_indptr],
[F.to_dgl_nd(x) for x in A_indices],
[F.to_dgl_nd(x) for x in A_data])
return CSRMatrix(
F.from_dgl_nd(C_data),
F.from_dgl_nd(C_indices),
F.from_dgl_nd(C_indptr),
As[0].shape[0], As[0].shape[1])

def csrmask(A, B):
"""Sparse-sparse matrix masking operation that computes ``A[B != 0]``.
This is an internal function whose interface is subject to changes.
Parameters
----------
A : dgl.sparse.CSRMatrix
The left operand
B : dgl.sparse.CSRMatrix
The right operand
Returns
-------
Tensor
The result
"""
A_indptr = F.zerocopy_from_numpy(A.indptr)
A_indices = F.zerocopy_from_numpy(A.indices)
A_data = F.zerocopy_from_numpy(A.data)
B_indptr = F.zerocopy_from_numpy(B.indptr)
B_indices = F.zerocopy_from_numpy(B.indices)
B_data = _CAPI_DGLCSRMask(
A.shape[0], A.shape[1],
F.to_dgl_nd(A_indptr),
F.to_dgl_nd(A_indices),
F.to_dgl_nd(A_data),
F.to_dgl_nd(B_indptr),
F.to_dgl_nd(B_indices))
return F.from_dgl_nd(B_data)

_init_api("dgl.sparse")
Loading

0 comments on commit 929d863

Please sign in to comment.