diff --git a/src/array/cpu/negative_sampling.cc b/src/array/cpu/negative_sampling.cc index 87560577d6b7..5c8129febfd5 100644 --- a/src/array/cpu/negative_sampling.cc +++ b/src/array/cpu/negative_sampling.cc @@ -27,7 +27,7 @@ std::pair CSRGlobalUniformNegativeSampling( double redundancy) { const int64_t num_row = csr.num_rows; const int64_t num_col = csr.num_cols; - const int64_t num_actual_samples = static_cast(num_samples * redundancy); + const int64_t num_actual_samples = static_cast(num_samples * (1 + redundancy)); IdArray row = Full(-1, num_actual_samples, csr.indptr->ctx); IdArray col = Full(-1, num_actual_samples, csr.indptr->ctx); IdType* row_data = row.Ptr(); diff --git a/src/array/cuda/negative_sampling.cu b/src/array/cuda/negative_sampling.cu index 12feef1341e5..e043ef410a6b 100644 --- a/src/array/cuda/negative_sampling.cu +++ b/src/array/cuda/negative_sampling.cu @@ -140,7 +140,7 @@ std::pair CSRGlobalUniformNegativeSampling( auto dtype = csr.indptr->dtype; const int64_t num_row = csr.num_rows; const int64_t num_col = csr.num_cols; - const int64_t num_actual_samples = static_cast(num_samples * redundancy); + const int64_t num_actual_samples = static_cast(num_samples * (1 + redundancy)); IdArray row = Full(-1, num_actual_samples, ctx); IdArray col = Full(-1, num_actual_samples, ctx); IdArray out_row = IdArray::Empty({num_actual_samples}, dtype, ctx); diff --git a/tests/compute/test_sampling.py b/tests/compute/test_sampling.py index ba763936af90..2938d4d902f2 100644 --- a/tests/compute/test_sampling.py +++ b/tests/compute/test_sampling.py @@ -892,10 +892,10 @@ def test_sample_neighbors_exclude_edges_homoG(dtype): @pytest.mark.parametrize('dtype', ['int32', 'int64']) def test_global_uniform_negative_sampling(dtype): - g = dgl.graph((np.random.randint(0, 20, (10,)), np.random.randint(0, 20, (10,)))).to(F.ctx()) - src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, True) - assert len(src) > 0 - assert len(dst) > 0 + g = dgl.graph(([], []), num_nodes=1000).to(F.ctx()) + src, dst = dgl.sampling.global_uniform_negative_sampling(g, 2000, False, True) + assert len(src) == 2000 + assert len(dst) == 2000 g = dgl.graph((np.random.randint(0, 20, (300,)), np.random.randint(0, 20, (300,)))).to(F.ctx()) src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, True)