Skip to content

Commit

Permalink
Bug fix for HBC by feature CPU implementation (pytorch#881)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#881

Add a unit test that compares the CPU and GPU implementations on randomized values; this reveals the bug often enough.

Reviewed By: jianyuh

Differential Revision: D33721542

fbshipit-source-id: 9b0500311562f51e9fe227c4f8b1f091ff9100ac
  • Loading branch information
jasonjk-park authored and facebook-github-bot committed Jan 25, 2022
1 parent 3bc185b commit d41b8aa
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 1 deletion.
2 changes: 1 addition & 1 deletion fbgemm_gpu/src/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1194,7 +1194,7 @@ void _generic_histogram_binning_calibration_by_feature_cpu_kernel(

const int curr_bin_id =
std::lower_bound(
bin_boundaries, bin_boundaries + num_bins, uncalibrated) -
bin_boundaries, bin_boundaries + num_bins - 1, uncalibrated) -
bin_boundaries;

const int64_t curr_segment_value =
Expand Down
85 changes: 85 additions & 0 deletions fbgemm_gpu/test/sparse_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,91 @@ def test_generic_histogram_binning_calibration_by_feature(
)
)

@unittest.skipIf(*gpu_unavailable)
# pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument
@given(
    data_type=st.sampled_from([torch.half, torch.float32]),
)
@settings(verbosity=Verbosity.verbose, deadline=None)
def test_generic_histogram_binning_calibration_by_feature_cpu_gpu(
    self,
    data_type: torch.dtype,
) -> None:
    """Check that the CPU and GPU results of the by-feature op agree.

    Randomized inputs are built once on the CPU, passed to
    ``generic_histogram_binning_calibration_by_feature`` as-is and again
    after moving every tensor to CUDA. The calibrated predictions must
    match within tolerance and the bin ids must match exactly.
    """
    # Randomized scalar configuration.
    num_logits = random.randint(8, 16)
    num_bins = random.randint(3, 8)
    num_segments = random.randint(3, 8)
    positive_weight = random.uniform(0.1, 1.0)
    bin_ctr_in_use_after = random.randint(0, 10)
    bin_ctr_weight_value = random.random()

    logit = torch.randn(num_logits).type(data_type)

    # `lengths` has one entry per logit: 1 for the first `num_values`
    # entries (one segment value each) and 0 for the remainder.
    num_values = random.randint(0, num_logits - 1)
    segment_value = torch.randint(0, num_segments, (num_values,))
    lengths = torch.tensor([1] * num_values + [0] * (num_logits - num_values))

    num_interval = num_bins * (num_segments + 1)
    bin_num_positives = torch.randint(0, 10, (num_interval,)).double()
    # Adding a non-negative count keeps examples >= positives in every bin.
    bin_num_examples = (
        bin_num_positives + torch.randint(0, 10, (num_interval,)).double()
    )

    # Evenly spaced bin edges starting one bin width above `lower_bound`
    # and stopping short of `upper_bound`.
    lower_bound, upper_bound = 0.0, 1.0
    bin_width = (upper_bound - lower_bound) / num_bins
    bin_boundaries = torch.arange(
        lower_bound + bin_width,
        upper_bound - bin_width / 2,
        bin_width,
        dtype=torch.float64,
    )

    # Tensor arguments are kept apart from the scalars so the GPU call can
    # move exactly the tensors to CUDA.
    tensor_args = {
        "logit": logit,
        "segment_value": segment_value,
        "segment_lengths": lengths,
        "bin_num_examples": bin_num_examples,
        "bin_num_positives": bin_num_positives,
        "bin_boundaries": bin_boundaries,
    }
    scalar_args = {
        "num_segments": num_segments,
        "positive_weight": positive_weight,
        "bin_ctr_in_use_after": bin_ctr_in_use_after,
        "bin_ctr_weight_value": bin_ctr_weight_value,
    }
    hbc_by_feature = (
        torch.ops.fbgemm.generic_histogram_binning_calibration_by_feature
    )

    calibrated_prediction_cpu, bin_ids_cpu = hbc_by_feature(
        **tensor_args,
        **scalar_args,
    )
    calibrated_prediction_gpu, bin_ids_gpu = hbc_by_feature(
        **{name: tensor.cuda() for name, tensor in tensor_args.items()},
        **scalar_args,
    )

    # Predictions are floating point, so compare with a tolerance.
    torch.testing.assert_allclose(
        calibrated_prediction_cpu,
        calibrated_prediction_gpu.cpu(),
        rtol=1e-03,
        atol=1e-03,
    )

    # Bin ids must be identical between the two implementations.
    self.assertTrue(torch.equal(bin_ids_cpu, bin_ids_gpu.cpu()))

@settings(verbosity=Verbosity.verbose, deadline=None)
def test_segment_sum_csr(self) -> None:
segment_sum_cpu = torch.ops.fbgemm.segment_sum_csr(
Expand Down

0 comments on commit d41b8aa

Please sign in to comment.