Skip to content

Commit

Permalink
[Bugfix] Fixes the redundancy parameter being used wrong in global ne…
Browse files Browse the repository at this point in the history
…gative sampling (dmlc#3657)

* oops

* test
  • Loading branch information
BarclayII authored Jan 17, 2022
1 parent 48cbea7 commit 77f4287
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/array/cpu/negative_sampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
double redundancy) {
const int64_t num_row = csr.num_rows;
const int64_t num_col = csr.num_cols;
const int64_t num_actual_samples = static_cast<int64_t>(num_samples * redundancy);
const int64_t num_actual_samples = static_cast<int64_t>(num_samples * (1 + redundancy));
IdArray row = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);
IdArray col = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);
IdType* row_data = row.Ptr<IdType>();
Expand Down
2 changes: 1 addition & 1 deletion src/array/cuda/negative_sampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
auto dtype = csr.indptr->dtype;
const int64_t num_row = csr.num_rows;
const int64_t num_col = csr.num_cols;
const int64_t num_actual_samples = static_cast<int64_t>(num_samples * redundancy);
const int64_t num_actual_samples = static_cast<int64_t>(num_samples * (1 + redundancy));
IdArray row = Full<IdType>(-1, num_actual_samples, ctx);
IdArray col = Full<IdType>(-1, num_actual_samples, ctx);
IdArray out_row = IdArray::Empty({num_actual_samples}, dtype, ctx);
Expand Down
8 changes: 4 additions & 4 deletions tests/compute/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,10 +892,10 @@ def test_sample_neighbors_exclude_edges_homoG(dtype):

@pytest.mark.parametrize('dtype', ['int32', 'int64'])
def test_global_uniform_negative_sampling(dtype):
g = dgl.graph((np.random.randint(0, 20, (10,)), np.random.randint(0, 20, (10,)))).to(F.ctx())
src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, True)
assert len(src) > 0
assert len(dst) > 0
g = dgl.graph(([], []), num_nodes=1000).to(F.ctx())
src, dst = dgl.sampling.global_uniform_negative_sampling(g, 2000, False, True)
assert len(src) == 2000
assert len(dst) == 2000

g = dgl.graph((np.random.randint(0, 20, (300,)), np.random.randint(0, 20, (300,)))).to(F.ctx())
src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, True)
Expand Down

0 comments on commit 77f4287

Please sign in to comment.