Skip to content

Commit

Permalink
Fix scaling partition info (#467)
Browse files (browse the repository at this point in the history)
  • Loading branch information
dvadym authored Jul 14, 2023
1 parent 959be14 commit 4193dd2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
22 changes: 11 additions & 11 deletions analysis/cross_partition_combiners.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,22 +258,22 @@ def _average_utility_report(report: metrics.UtilityReport,
public_partitions: bool,
sums_actual: Tuple) -> None:
"""Averages fields of the 'report' across partitions."""
if not report.metric_errors:
return
partitions = report.partitions_info
if public_partitions:
num_output_partitions = partitions.num_dataset_partitions + partitions.num_empty_partitions
else:
num_output_partitions = partitions.kept_partitions.mean
_multiply_float_dataclasses_field(report.partitions_info,
1.0 / num_output_partitions)
if report.metric_errors:
for sum_actual, metric_error in zip(sums_actual, report.metric_errors):
_multiply_float_dataclasses_field(
metric_error,
1.0 / num_output_partitions,
fields_to_ignore=["noise_std", "ratio_data_dropped"])
scaling_factor = 1 if sum_actual == 0 else 1.0 / sum_actual
_multiply_float_dataclasses_field(metric_error.ratio_data_dropped,
scaling_factor)

for sum_actual, metric_error in zip(sums_actual, report.metric_errors):
_multiply_float_dataclasses_field(
metric_error,
1.0 / num_output_partitions,
fields_to_ignore=["noise_std", "ratio_data_dropped"])
scaling_factor = 1 if sum_actual == 0 else 1.0 / sum_actual
_multiply_float_dataclasses_field(metric_error.ratio_data_dropped,
scaling_factor)


class CrossPartitionCombiner(pipeline_dp.combiners.Combiner):
Expand Down
4 changes: 2 additions & 2 deletions analysis/tests/utility_analysis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def test_wo_public_partitions(self, pre_aggregated: bool):
num_non_public_partitions=None,
num_empty_partitions=None,
strategy=None,
kept_partitions=metrics.MeanVariance(mean=1.0,
var=0.648377588998337)),
kept_partitions=metrics.MeanVariance(mean=3.51622411,
var=2.2798409)),
metric_errors=[
metrics.MetricUtility(
metric=pipeline_dp.Metrics.COUNT,
Expand Down

0 comments on commit 4193dd2

Please sign in to comment.