Skip to content

Commit

Permalink
[Feature] Add dgl.utils.is_sorted_srcdst() (dmlc#2685)
Browse files Browse the repository at this point in the history
* Add dgl.utils.is_sorted_srcdst

* Fix linting issues

* delete blank line

* Specify datatype to index tensor in test

* Force integer conversion

Co-authored-by: Minjie Wang <[email protected]>
Co-authored-by: Quan (Andy) Gan <[email protected]>
  • Loading branch information
3 people authored Jul 2, 2021
1 parent 95f8ec8 commit a0390dd
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 0 deletions.
40 changes: 40 additions & 0 deletions python/dgl/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import absolute_import, division

from ..base import DGLError
from .._ffi.function import _init_api
from .. import backend as F

def prepare_tensor(g, data, name):
Expand Down Expand Up @@ -166,3 +167,42 @@ def check_valid_idtype(idtype):
if idtype not in [None, F.int32, F.int64]:
raise DGLError('Expect idtype to be a framework object of int32/int64, '
'got {}'.format(idtype))

def is_sorted_srcdst(src, dst, num_src=None, num_dst=None):
    """Report whether an edge list is in ascending src-major order (i.e.,
    sorted by ``src`` first and then by ``dst``).

    Parameters
    ----------
    src : IdArray
        The tensor of source nodes for each edge.
    dst : IdArray
        The tensor of destination nodes for each edge.
    num_src : int, optional
        The number of source nodes. If not given, ``max(src) + 1`` is used.
    num_dst : int, optional
        The number of destination nodes. If not given, ``max(dst) + 1``
        is used.

    Returns
    -------
    bool, bool
        Whether ``src`` is in ascending order, and whether ``dst`` is
        in ascending order with respect to ``src``.
    """
    # Some versions of MXNet and TensorFlow report the inferred node counts
    # as floats, so the inferred values are explicitly converted to int.
    src_count = (
        num_src if num_src is not None
        else int(F.as_scalar(F.max(src, dim=0)+1)))
    dst_count = (
        num_dst if num_dst is not None
        else int(F.as_scalar(F.max(dst, dim=0)+1)))

    status = _CAPI_DGLCOOIsSorted(
        F.zerocopy_to_dgl_ndarray(src),
        F.zerocopy_to_dgl_ndarray(dst),
        src_count,
        dst_count)

    # The C API encodes the result as a single integer: anything above 0
    # means ``src`` is sorted, anything above 1 means ``dst`` is sorted too.
    return status > 0, status > 1

# Presumably binds the C API functions registered under the "utils.checks"
# namespace (e.g. _CAPI_DGLCOOIsSorted, called above) into this module --
# TODO confirm against dgl._ffi.function._init_api.
_init_api("dgl.utils.checks")
27 changes: 27 additions & 0 deletions src/runtime/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@

#include <dmlc/omp.h>


#include <dgl/aten/coo.h>
#include <dgl/packed_func_ext.h>
#include <utility>

#include "../c_api_common.h"
#include "../array/array_op.h"


using namespace dgl::runtime;
using namespace dgl::aten::impl;

namespace dgl {

Expand All @@ -19,4 +26,24 @@ DGL_REGISTER_GLOBAL("utils.internal._CAPI_DGLSetOMPThreads")
omp_set_num_threads(num_threads);
});


DGL_REGISTER_GLOBAL("utils.checks._CAPI_DGLCOOIsSorted")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  // Arguments, in order: source ids, destination ids, number of source
  // nodes, number of destination nodes.
  IdArray row_ids = args[0];
  IdArray col_ids = args[1];
  int64_t num_rows = args[2];
  int64_t num_cols = args[3];

  // Wrap the edge list in a COO matrix so the array-level check can be used.
  aten::COOMatrix coo(num_rows, num_cols, row_ids, col_ids);

  bool rows_sorted = false;
  bool cols_sorted = false;
  std::tie(rows_sorted, cols_sorted) = COOIsSorted(coo);

  // Columns can only be reported as sorted when the rows are sorted as well.
  assert(rows_sorted || !cols_sorted);

  // Encode the result as one integer:
  // 0 for unsorted, 1 for row sorted, 2 for row and col sorted.
  *rv = static_cast<int64_t>(rows_sorted) + static_cast<int64_t>(cols_sorted);
});

} // namespace dgl
26 changes: 26 additions & 0 deletions tests/compute/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,36 @@ def test_empty_data_initialized():
assert "ha" in g.ndata
assert len(g.ndata["ha"]) == 1

def test_is_sorted():
    """Check ``dgl.utils.is_sorted_srcdst`` on sorted and unsorted edge lists."""
    # edge_pair_input(False) / edge_pair_input(True) are expected to produce
    # an unsorted / sorted (src, dst) pair respectively, as the assertions
    # below require.
    u_src, u_dst = edge_pair_input(False)
    s_src, s_dst = edge_pair_input(True)

    # Build the index tensors with an explicit integer dtype.
    u_src = F.tensor(u_src, dtype=F.int32)
    u_dst = F.tensor(u_dst, dtype=F.int32)
    s_src = F.tensor(s_src, dtype=F.int32)
    s_dst = F.tensor(s_dst, dtype=F.int32)

    # Unsorted src and dst: neither flag is set.
    src_sorted, dst_sorted = dgl.utils.is_sorted_srcdst(u_src, u_dst)
    assert not src_sorted
    assert not dst_sorted

    # Sorted src and dst: both flags are set.
    src_sorted, dst_sorted = dgl.utils.is_sorted_srcdst(s_src, s_dst)
    assert src_sorted
    assert dst_sorted

    # NOTE(review): this repeats the first case verbatim; it may have been
    # meant to cover a different src/dst combination -- confirm with author.
    src_sorted, dst_sorted = dgl.utils.is_sorted_srcdst(u_src, u_dst)
    assert not src_sorted
    assert not dst_sorted

    # Sorted src paired with the unsorted dst: only the src flag is set.
    src_sorted, dst_sorted = dgl.utils.is_sorted_srcdst(s_src, u_dst)
    assert src_sorted
    assert not dst_sorted

if __name__ == "__main__":
    # When executed as a script, run the listed tests one after another.
    for run_test in (
            test_query,
            test_mutation,
            test_scipy_adjmat,
            test_incmat,
            test_find_edges,
            test_hypersparse_query,
            test_is_sorted,
    ):
        run_test()

0 comments on commit a0390dd

Please sign in to comment.