Skip to content

Commit

Permalink
[Feature] Graceful handling of exceptions thrown within OpenMP blocks (
Browse files Browse the repository at this point in the history
…dmlc#3353)

* graceful c++ exception in OpenMP

* credits

* add test

Co-authored-by: Jinjing Zhou <[email protected]>
  • Loading branch information
BarclayII and VoVAllen authored Sep 22, 2021
1 parent bc14829 commit a04a8d0
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 5 deletions.
14 changes: 13 additions & 1 deletion include/dgl/runtime/parallel_for.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <algorithm>
#include <string>
#include <cstdlib>
#include <exception>
#include <atomic>

namespace {
int64_t divup(int64_t x, int64_t y) {
Expand Down Expand Up @@ -67,6 +69,9 @@ void parallel_for(

#ifdef _OPENMP
auto num_threads = compute_num_threads(begin, end, grain_size);
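// (BarclayII) The exception handling below is borrowed from PyTorch: the first
// exception thrown by any thread is captured inside the parallel region and
// rethrown after the region ends.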
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
std::exception_ptr eptr;

#pragma omp parallel num_threads(num_threads)
{
Expand All @@ -75,9 +80,16 @@ void parallel_for(
auto begin_tid = begin + tid * chunk_size;
if (begin_tid < end) {
auto end_tid = std::min(end, chunk_size + begin_tid);
f(begin_tid, end_tid);
try {
f(begin_tid, end_tid);
} catch (...) {
if (!err_flag.test_and_set())
eptr = std::current_exception();
}
}
}
if (eptr)
std::rethrow_exception(eptr);
#else
f(begin, end);
#endif
Expand Down
6 changes: 4 additions & 2 deletions src/graph/sampling/randomwalks/metapath_randomwalk.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
TerminatePredicate<IdxType> terminate) {
int64_t max_num_steps = metapath->shape[0];
const IdxType *metapath_data = static_cast<IdxType *>(metapath->data);
const int64_t begin_ntype = hg->meta_graph()->FindEdge(metapath_data[0]).first;
const int64_t max_nodes = hg->NumVertices(begin_ntype);

// Prefetch all edges.
// This forces the heterograph to materialize all OutCSR's before the OpenMP loop;
Expand Down Expand Up @@ -206,15 +208,15 @@ std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
return MetapathRandomWalkStep<XPU, IdxType>(
data, curr, len, edges_by_type, csr_has_data, metapath_data, prob, terminate);
};
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step);
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step, max_nodes);
} else {
StepFunc<IdxType> step =
[&edges_by_type, &csr_has_data, metapath_data, &prob, terminate]
(IdxType *data, dgl_id_t curr, int64_t len) {
return MetapathRandomWalkStepUniform<XPU, IdxType>(
data, curr, len, edges_by_type, csr_has_data, metapath_data, prob, terminate);
};
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step);
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step, max_nodes);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/graph/sampling/randomwalks/node2vec_randomwalk.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ std::pair<IdArray, IdArray> Node2vecRandomWalk(
edges, csr_has_data, prob, terminate);
};

return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step);
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step, g->NumVertices(0));
}

}; // namespace
Expand Down
6 changes: 5 additions & 1 deletion src/graph/sampling/randomwalks/randomwalks_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@ namespace {
* edge type in the metapath.
* \param max_num_steps The maximum number of steps of a random walk path.
* \param step The random walk step function with type \c StepFunc.
* \param max_nodes The exclusive upper bound on seed node IDs; an error is raised if any
*        value in \c seeds is greater than or equal to this argument.
* \return A 2D array of shape (len(seeds), max_num_steps + 1) with node IDs.
* \note The graph itself should be bounded in the closure of \c step.
*/
template<DLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> GenericRandomWalk(
const IdArray seeds,
int64_t max_num_steps,
StepFunc<IdxType> step) {
StepFunc<IdxType> step,
int64_t max_nodes) {
int64_t num_seeds = seeds->shape[0];
int64_t trace_length = max_num_steps + 1;
IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, seeds->ctx);
Expand All @@ -54,6 +56,8 @@ std::pair<IdArray, IdArray> GenericRandomWalk(
dgl_id_t curr = seed_data[seed_id];
traces_data[seed_id * trace_length] = curr;

CHECK_LT(curr, max_nodes) << "Seed node ID exceeds the maximum number of nodes.";

for (i = 0; i < max_num_steps; ++i) {
const auto &succ = step(traces_data + seed_id * max_num_steps, curr, i);
traces_data[seed_id * trace_length + i + 1] = curr = std::get<0>(succ);
Expand Down
6 changes: 6 additions & 0 deletions tests/compute/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def test_random_walk():

traces, eids, ntypes = dgl.sampling.random_walk(g1, [0, 1, 2, 0, 1, 2], length=4, return_eids=True)
check_random_walk(g1, ['follow'] * 4, traces, ntypes, trace_eids=eids)
try:
dgl.sampling.random_walk(g1, [0, 1, 2, 10], length=4, return_eids=True)
fail = False # shouldn't abort
except:
fail = True
assert fail
traces, eids, ntypes = dgl.sampling.random_walk(g1, [0, 1, 2, 0, 1, 2], length=4, restart_prob=0., return_eids=True)
check_random_walk(g1, ['follow'] * 4, traces, ntypes, trace_eids=eids)
traces, ntypes = dgl.sampling.random_walk(
Expand Down

0 comments on commit a04a8d0

Please sign in to comment.