-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathscheduler_apis.cc
44 lines (39 loc) · 1.32 KB
/
scheduler_apis.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/*!
* Copyright (c) 2018 by Contributors
* \file scheduler/scheduler_apis.cc
* \brief DGL scheduler APIs
*/
#include <dgl/array.h>
#include <dgl/graph.h>
#include <dgl/scheduler.h>
#include "../c_api_common.h"
#include "../array/cpu/array_utils.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLRetValue;
using dgl::runtime::NDArray;
namespace dgl {

// Packed-function entry points for the scheduler's degree-bucketing routines.
// Both registrations below follow one pattern:
//   1. read three IdArrays from positional arguments 0, 1 and 2;
//   2. require all three to share the dtype of the first argument;
//   3. dispatch on that id dtype to the templated routine in `sched::`;
//   4. return the routine's result wrapped by ConvertNDArrayVectorToPackedFunc.
// NOTE(review): the `_deprecate.` prefix in the registered names suggests these
// are legacy entry points kept for compatibility -- confirm before extending.
// NOTE(review): the local variable names are deliberately left as-is; the
// CHECK_* macros likely stringify their arguments into the error text, so
// renaming them could change user-visible messages -- confirm in the macro
// definitions before touching them.

// Positional args: (msg_ids, vids, nids), all IdArrays of one dtype.
// Result: sched::DegreeBucketing<IdType>(msg_ids, vids, nids), packaged via
// ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("_deprecate.runtime.degree_bucketing._CAPI_DGLDegreeBucketing")
.set_body([](DGLArgs args, DGLRetValue* rv) {
    const IdArray msg_ids = args[0];
    const IdArray vids = args[1];
    const IdArray nids = args[2];
    // The switch below inspects only msg_ids->dtype, so the other inputs
    // must match it before dispatch.
    CHECK_SAME_DTYPE(msg_ids, vids);
    CHECK_SAME_DTYPE(msg_ids, nids);
    ATEN_ID_TYPE_SWITCH(msg_ids->dtype, IdType, {
      const auto buckets = sched::DegreeBucketing<IdType>(msg_ids, vids, nids);
      *rv = ConvertNDArrayVectorToPackedFunc(buckets);
    });
  });

// Positional args: (uids, vids, eids), all IdArrays of one dtype.
// Result: sched::GroupEdgeByNodeDegree<IdType>(uids, vids, eids), packaged via
// ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("_deprecate.runtime.degree_bucketing._CAPI_DGLGroupEdgeByNodeDegree")
.set_body([](DGLArgs args, DGLRetValue* rv) {
    const IdArray uids = args[0];
    const IdArray vids = args[1];
    const IdArray eids = args[2];
    // Dispatch keys off uids->dtype; reject mismatched id types up front.
    CHECK_SAME_DTYPE(uids, vids);
    CHECK_SAME_DTYPE(uids, eids);
    ATEN_ID_TYPE_SWITCH(uids->dtype, IdType, {
      const auto groups = sched::GroupEdgeByNodeDegree<IdType>(uids, vids, eids);
      *rv = ConvertNDArrayVectorToPackedFunc(groups);
    });
  });

}  // namespace dgl