Skip to content

Commit

Permalink
speed up random walks (dmlc#3158)
Browse files Browse the repository at this point in the history
Co-authored-by: Jinjing Zhou <[email protected]>
  • Loading branch information
BarclayII and VoVAllen authored Jul 20, 2021
1 parent d1cc096 commit ddc92f8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
25 changes: 16 additions & 9 deletions src/graph/sampling/randomwalks/metapath_randomwalk.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
dgl_id_t curr,
int64_t len,
const std::vector<CSRMatrix> &edges_by_type,
const std::vector<bool> &csr_has_data,
const IdxType *metapath_data,
const std::vector<FloatArray> &prob,
TerminatePredicate<IdxType> terminate) {
Expand All @@ -70,7 +71,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
const CSRMatrix &csr = edges_by_type[etype];
const IdxType *offsets = csr.indptr.Ptr<IdxType>();
const IdxType *all_succ = csr.indices.Ptr<IdxType>();
const IdxType *all_eids = CSRHasData(csr) ? csr.data.Ptr<IdxType>() : nullptr;
const IdxType *all_eids = csr_has_data[etype] ? csr.data.Ptr<IdxType>() : nullptr;
const IdxType *succ = all_succ + offsets[curr];
const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;

Expand Down Expand Up @@ -124,6 +125,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
dgl_id_t curr,
int64_t len,
const std::vector<CSRMatrix> &edges_by_type,
const std::vector<bool> &csr_has_data,
const IdxType *metapath_data,
const std::vector<FloatArray> &prob,
TerminatePredicate<IdxType> terminate) {
Expand All @@ -137,7 +139,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
const CSRMatrix &csr = edges_by_type[etype];
const IdxType *offsets = csr.indptr.Ptr<IdxType>();
const IdxType *all_succ = csr.indices.Ptr<IdxType>();
const IdxType *all_eids = CSRHasData(csr) ? csr.data.Ptr<IdxType>() : nullptr;
const IdxType *all_eids = csr_has_data[etype] ? csr.data.Ptr<IdxType>() : nullptr;
const IdxType *succ = all_succ + offsets[curr];
const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;

Expand Down Expand Up @@ -179,9 +181,14 @@ std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
// Force the heterograph to materialize all of its out-CSRs before the OpenMP loop;
// otherwise data races will happen.
// TODO(BarclayII): should we later on materialize COO/CSR/CSC anyway unless told otherwise?
std::vector<CSRMatrix> edges_by_type;
for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype)
edges_by_type.push_back(hg->GetCSRMatrix(etype));
int64_t num_etypes = hg->NumEdgeTypes();
std::vector<CSRMatrix> edges_by_type(num_etypes);
std::vector<bool> csr_has_data(num_etypes);
for (int64_t etype = 0; etype < num_etypes; ++etype) {
const CSRMatrix &csr = hg->GetCSRMatrix(etype);
edges_by_type[etype] = csr;
csr_has_data[etype] = CSRHasData(csr);
}

// Hoist the check for uniform vs. non-uniform edge distribution
// so that it stays off the hot path.
Expand All @@ -194,18 +201,18 @@ std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
}
if (!isUniform) {
StepFunc<IdxType> step =
[&edges_by_type, metapath_data, &prob, terminate]
[&edges_by_type, &csr_has_data, metapath_data, &prob, terminate]
(IdxType *data, dgl_id_t curr, int64_t len) {
return MetapathRandomWalkStep<XPU, IdxType>(
data, curr, len, edges_by_type, metapath_data, prob, terminate);
data, curr, len, edges_by_type, csr_has_data, metapath_data, prob, terminate);
};
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step);
} else {
StepFunc<IdxType> step =
[&edges_by_type, metapath_data, &prob, terminate]
[&edges_by_type, &csr_has_data, metapath_data, &prob, terminate]
(IdxType *data, dgl_id_t curr, int64_t len) {
return MetapathRandomWalkStepUniform<XPU, IdxType>(
data, curr, len, edges_by_type, metapath_data, prob, terminate);
data, curr, len, edges_by_type, csr_has_data, metapath_data, prob, terminate);
};
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step);
}
Expand Down
9 changes: 5 additions & 4 deletions src/graph/sampling/randomwalks/node2vec_randomwalk.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ bool has_edge_between(const CSRMatrix &csr, dgl_id_t u,
template <DLDeviceType XPU, typename IdxType>
std::tuple<dgl_id_t, dgl_id_t, bool> Node2vecRandomWalkStep(
IdxType *data, dgl_id_t curr, dgl_id_t pre, const double p, const double q,
int64_t len, const CSRMatrix &csr, const FloatArray &probs,
int64_t len, const CSRMatrix &csr, bool csr_has_data, const FloatArray &probs,
TerminatePredicate<IdxType> terminate) {
const IdxType *offsets = csr.indptr.Ptr<IdxType>();
const IdxType *all_succ = csr.indices.Ptr<IdxType>();
const IdxType *all_eids = CSRHasData(csr) ? csr.data.Ptr<IdxType>() : nullptr;
const IdxType *all_eids = csr_has_data ? csr.data.Ptr<IdxType>() : nullptr;
const IdxType *succ = all_succ + offsets[curr];
const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;

Expand Down Expand Up @@ -153,13 +153,14 @@ std::pair<IdArray, IdArray> Node2vecRandomWalk(
const int64_t max_num_steps, const FloatArray &prob,
TerminatePredicate<IdxType> terminate) {
const CSRMatrix &edges = g->GetCSRMatrix(0); // homogeneous graph.
bool csr_has_data = CSRHasData(edges);

StepFunc<IdxType> step =
[&edges, &prob, p, q, terminate]
[&edges, csr_has_data, &prob, p, q, terminate]
(IdxType *data, dgl_id_t curr, int64_t len) {
dgl_id_t pre = (len != 0) ? data[len - 1] : curr;
return Node2vecRandomWalkStep<XPU, IdxType>(data, curr, pre, p, q, len,
edges, prob, terminate);
edges, csr_has_data, prob, terminate);
};

return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step);
Expand Down

0 comments on commit ddc92f8

Please sign in to comment.