Skip to content

Commit

Permalink
[NodeFlow] Non-uniform neighbor sampling (dmlc#711)
Browse files Browse the repository at this point in the history
* nonuniform sampler

* unit test

* test on out neighbors

* error checks

* lint

* fix

* clarification

* use macro switcher

* use empty array for uniform sampling

* oops

* Revert "oops"

This reverts commit a11f9ae.

* Revert "use empty array for uniform sampling"

This reverts commit 8526ce4.

* re-reverting

* use a method
  • Loading branch information
BarclayII authored Aug 7, 2019
1 parent 742d79a commit 1606192
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 54 deletions.
13 changes: 8 additions & 5 deletions include/dgl/sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,16 @@ class SamplerOp {
* \param num_hops the number of hops to sample neighbors.
* \param expand_factor the max number of neighbors to sample.
* \param add_self_loop whether to add self loop to the sampled subgraph
* \param probability the transition probability (float/double).
* \return a NodeFlow graph.
*/
static NodeFlow NeighborUniformSample(const ImmutableGraph *graph,
const std::vector<dgl_id_t>& seeds,
const std::string &edge_type,
int num_hops, int expand_factor,
const bool add_self_loop);
template<typename ValueType>
static NodeFlow NeighborSample(const ImmutableGraph *graph,
const std::vector<dgl_id_t>& seeds,
const std::string &edge_type,
int num_hops, int expand_factor,
const bool add_self_loop,
const ValueType *probability);

/*!
* \brief Sample a graph from the seed vertices with layer sampling.
Expand Down
47 changes: 38 additions & 9 deletions python/dgl/contrib/sampling/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def batch_size(self):
return self._batch_size

class NeighborSampler(NodeFlowSampler):
'''Create a sampler that samples neighborhood.
r'''Create a sampler that samples neighborhood.
It returns a generator of :class:`~dgl.NodeFlow`. This can be viewed as
an analogy of *mini-batch training* on graph data -- the given graph represents
Expand Down Expand Up @@ -258,10 +258,30 @@ class NeighborSampler(NodeFlowSampler):
* "out": the neighbors on the out-edges.
Default: "in"
node_prob : Tensor, optional
A 1D tensor for the probability that a neighbor node is sampled.
None means uniform sampling. Otherwise, the number of elements
should be equal to the number of vertices in the graph.
transition_prob : Tensor or str, optional
A 1D tensor containing the (unnormalized) transition probability.
The probability of a neighbor u being sampled for a node v is proportional to
the weight of the edge (u, v), normalized by the sum of the edge weights grouped
by the destination node.
In other words, given a node v, the probability of node u and edge (u, v)
included in the NodeFlow layer preceding that of v is given by:
.. math::
p(u, v) = \frac{w_{u, v}}{\sum_{u', (u', v) \in E} w_{u', v}}
If neighbor type is "out", then the probability is instead normalized by the sum
grouping by source node:
.. math::
p(v, u) = \frac{w_{v, u}}{\sum_{u', (v, u') \in E} w_{v, u'}}
If a str is given, the edge weight will be loaded from the edge feature column with
the same name. The feature column must be a scalar column in this case.
Default: None
seed_nodes : Tensor, optional
A 1D tensor of the seed nodes from which NodeFlows are sampled.
Expand All @@ -287,7 +307,7 @@ def __init__(
expand_factor=None,
num_hops=1,
neighbor_type='in',
node_prob=None,
transition_prob=None,
seed_nodes=None,
shuffle=False,
num_workers=1,
Expand All @@ -299,17 +319,24 @@ def __init__(

assert g.is_readonly, "NeighborSampler doesn't support mutable graphs. " + \
"Please turn it into an immutable graph with DGLGraph.readonly"
assert node_prob is None, 'non-uniform node probability not supported'
assert isinstance(expand_factor, Integral), 'non-int expand_factor not supported'

self._expand_factor = int(expand_factor)
self._num_hops = int(num_hops)
self._add_self_loop = add_self_loop
self._num_workers = int(num_workers)
self._neighbor_type = neighbor_type
self._transition_prob = transition_prob

def fetch(self, current_nodeflow_index):
nfobjs = _CAPI_UniformSampling(
if self._transition_prob is None:
prob = F.tensor([], F.float32)
elif isinstance(self._transition_prob, str):
prob = self.g.edata[self._transition_prob]
else:
prob = self._transition_prob

nfobjs = _CAPI_NeighborSampling(
self.g._graph,
self.seed_nodes.todgltensor(),
current_nodeflow_index, # start batch id
Expand All @@ -318,7 +345,9 @@ def fetch(self, current_nodeflow_index):
self._expand_factor,
self._num_hops,
self._neighbor_type,
self._add_self_loop)
self._add_self_loop,
F.zerocopy_to_dgl_ndarray(prob))

nflows = [NodeFlow(self.g, obj) for obj in nfobjs]
return nflows

Expand Down
31 changes: 24 additions & 7 deletions src/array/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
namespace dgl {
namespace aten {

#define ATEN_XPU_SWITCH(val, XPU, ...) \
#define ATEN_XPU_SWITCH(val, XPU, ...) do { \
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Device type: " << (val) << " is not supported."; \
}
} \
} while (0)

#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) \
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) do { \
CHECK_EQ((val).code, kDLInt) << "ID must be integer type"; \
if ((val).bits == 32) { \
typedef int32_t IdType; \
Expand All @@ -26,10 +27,25 @@ namespace aten {
typedef int64_t IdType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "ID can Only be int32 or int64"; \
}
LOG(FATAL) << "ID can only be int32 or int64"; \
} \
} while (0)

#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) \
/*!
 * \brief Dispatch a code fragment on the bit width of a floating-point dtype.
 *
 * First checks that (val).code equals kDLFloat; the failure message is
 * prefixed with val_name. Then expands the trailing statements (__VA_ARGS__)
 * inside a scope where FloatType is a typedef for:
 *   - float  when (val).bits == 32
 *   - double when (val).bits == 64
 * Any other bit width is reported through LOG(FATAL).
 *
 * The expansion is wrapped in do { ... } while (0), so an invocation behaves
 * as a single statement and must be terminated with a semicolon by the caller.
 *
 * Note that val and val_name each appear more than once in the expansion, so
 * arguments with side effects may be evaluated repeatedly.
 *
 * \param val a dtype descriptor exposing .code and .bits members.
 * \param FloatType the type alias name made available to the statements.
 * \param val_name a streamable name of the value, used only in error messages.
 * \param ... the statements to instantiate for the selected FloatType.
 */
#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) do { \
CHECK_EQ((val).code, kDLFloat) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be float32 or float64"; \
} \
} while (0)

#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) do { \
if ((val).code == kDLInt && (val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
Expand All @@ -38,7 +54,8 @@ namespace aten {
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "CSR matrix data can only be int32 or int64"; \
}
} \
} while (0)

// Macro to dispatch according to device context, index type and data type
// TODO(minjie): In our current use cases, data type and id type are the
Expand Down
Loading

0 comments on commit 1606192

Please sign in to comment.