Skip to content

Commit

Permalink
[Bug Fix] edge sample hotfix (dmlc#1152)
Browse files Browse the repository at this point in the history
* hot fix

* Fix docs

* Fix ArrayHeap float overflow bug

* Fix

* Clean some dead code

* Fix

* Fix

* Add some comments

* run test
  • Loading branch information
classicsong authored and zheng-da committed Dec 31, 2019
1 parent 6731ea3 commit 61b78e6
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 15 deletions.
10 changes: 6 additions & 4 deletions python/dgl/contrib/sampling/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,10 @@ class EdgeSampler(object):
edges and reset the replacement state. If it is set to false, the sampler will only
generate num_edges/batch_size samples.
Note: If node_weight is extremely imbalanced, the sampler will take much longer
time to return a minibatch, as sampled negative nodes must not be duplicated for
one corrupted positive edge.
Parameters
----------
g : DGLGraph
Expand Down Expand Up @@ -737,11 +741,9 @@ def fetch(self, current_index):
def __iter__(self):
it = SamplerIter(self)
if self._is_uniform:
subgs = _CAPI_ResetUniformEdgeSample(
self._sampler)
_CAPI_ResetUniformEdgeSample(self._sampler)
else:
subgs = _CAPI_ResetWeightedEdgeSample(
self._sampler)
_CAPI_ResetWeightedEdgeSample(self._sampler)

if self._num_prefetch:
return self._prefetching_wrapper_class(it, self._num_prefetch)
Expand Down
30 changes: 19 additions & 11 deletions src/graph/sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,13 @@ class ArrayHeap {
*/
void Delete(size_t index) {
size_t i = index + limit_;
ValueType w = heap_[i];
for (int j = bit_len_; j >= 0; --j) {
heap_[i] -= w;
i = i >> 1;
heap_[i] = 0;
i /= 2;
for (int j = bit_len_-1; j >= 0; --j) {
// Using heap_[i] = heap_[i] - w will lose some precision in float.
// Using addition to re-calculate the weight layer by layer.
heap_[i] = heap_[i << 1] + heap_[(i << 1) + 1];
i /= 2;
}
}

Expand Down Expand Up @@ -1480,12 +1483,15 @@ class UniformEdgeSamplerObject: public EdgeSamplerObject {
sizeof(dgl_id_t) * start);
} else {
std::vector<dgl_id_t> seeds;
const dgl_id_t *seed_edge_ids = static_cast<const dgl_id_t *>(seed_edges_->data);
// sampling of each edge is a standalone event
for (int64_t i = 0; i < num_edges; ++i) {
seeds.push_back(RandomEngine::ThreadLocal()->RandInt(num_seeds_));
int64_t seed = static_cast<const int64_t>(
RandomEngine::ThreadLocal()->RandInt(num_seeds_));
seeds.push_back(seed_edge_ids[seed]);
}

worker_seeds = aten::VecToIdArray(seeds);
worker_seeds = aten::VecToIdArray(seeds, seed_edges_->dtype.bits);
}

EdgeArray arr = gptr_->FindEdges(worker_seeds);
Expand Down Expand Up @@ -1674,7 +1680,6 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
curr_batch_id_ = 0;
// handle int64 overflow here
max_batch_id_ = (num_edges + batch_size - 1) / batch_size;

// TODO(song): Tricky thing here to make sure gptr_ has coo cache
gptr_->FindEdge(0);
}
Expand All @@ -1697,9 +1702,12 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
size_t n = batch_size_;
size_t num_ids = 0;
#pragma omp critical
num_ids = edge_selector_->SampleWithoutReplacement(n, &edge_ids);
while (edge_ids.size() > num_ids) {
edge_ids.pop_back();
{
num_ids = edge_selector_->SampleWithoutReplacement(n, &edge_ids);
}
edge_ids.resize(num_ids);
for (size_t i = 0; i < num_ids; ++i) {
edge_ids[i] = seed_edge_ids[edge_ids[i]];
}
} else {
// sampling of each edge is a standalone event
Expand All @@ -1708,6 +1716,7 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
edge_ids[i] = seed_edge_ids[edge_id];
}
}

auto worker_seeds = aten::VecToIdArray(edge_ids, seed_edges_->dtype.bits);

EdgeArray arr = gptr_->FindEdges(worker_seeds);
Expand All @@ -1716,7 +1725,6 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
std::vector<dgl_id_t> src_vec(src_ids, src_ids + batch_size_);
std::vector<dgl_id_t> dst_vec(dst_ids, dst_ids + batch_size_);
// TODO(zhengda) what if there are duplicates in the src and dst vectors.

Subgraph subg = gptr_->EdgeSubgraph(worker_seeds, false);
positive_subgs[i] = ConvertRef(subg);
// For PBG negative sampling, we accept "PBG-head" for corrupting head
Expand Down
33 changes: 33 additions & 0 deletions tests/compute/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,38 @@ def check_weighted_negative_sampler(mode, exclude_positive, neg_size):
assert np.allclose(node_rate, node_rate_a * 5, atol=0.002)
assert np.allclose(node_rate_a, node_rate_b, atol=0.0002)

def check_positive_edge_sampler():
    """Check that a weighted EdgeSampler with reset=False visits every edge
    exactly once per pass, even when the weight distribution is extremely
    skewed (one edge weighted num_edges**2 while all others are 1).

    Runs the check twice: once with the sampler's default iteration order
    and once with shuffle=True.
    """
    g = generate_rand_graph(1000)
    num_edges = g.number_of_edges()
    edge_weight = F.copy_to(F.tensor(np.full((num_edges,), 1, dtype=np.float32)), F.cpu())
    # Make the distribution highly imbalanced: a correct without-replacement
    # sampler must still emit each edge exactly once per epoch.
    edge_weight[num_edges-1] = num_edges ** 2
    EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')

    batch_size = 128

    def _count_sampled(shuffle):
        # Count how many times each parent edge id is emitted over one full
        # (non-resetting) pass of the sampler.
        sampled = np.zeros((num_edges,), dtype=np.int32)
        kwargs = dict(reset=False, edge_weight=edge_weight)
        if shuffle:
            kwargs['shuffle'] = True
        for pos_edges in EdgeSampler(g, batch_size, **kwargs):
            _, _, pos_leid = pos_edges.all_edges(form='all', order='eid')
            # np.add.at is unbuffered, so repeated ids within one batch are
            # all counted (plain fancy-index += would drop duplicates).
            np.add.at(sampled, F.asnumpy(pos_edges.parent_eid[pos_leid]), 1)
        return sampled

    # Correctness check: every edge sampled exactly once, shuffled or not.
    truth = np.ones((num_edges,), dtype=np.int32)
    assert np.array_equal(truth, _count_sampled(shuffle=False))
    assert np.array_equal(truth, _count_sampled(shuffle=True))


@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support item assignment")
def test_negative_sampler():
Expand All @@ -674,6 +706,7 @@ def test_negative_sampler():
check_weighted_negative_sampler('PBG-head', False, 10)
check_weighted_negative_sampler('head', True, 10)
check_weighted_negative_sampler('head', False, 10)
check_positive_edge_sampler()
#disable this check for now. It might take too long time.
#check_negative_sampler('head', False, 100)

Expand Down

0 comments on commit 61b78e6

Please sign in to comment.