Skip to content

Commit

Permalink
Add partition selection strategy for UtilityReport (#478)
Browse files Browse the repository at this point in the history
  • Loading branch information
dvadym authored Aug 7, 2023
1 parent 0dee86d commit bde02b7
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
19 changes: 18 additions & 1 deletion analysis/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import copy
import dataclasses
from typing import Iterable, Optional, Sequence
from typing import Iterable, Optional, Sequence, Union

import pipeline_dp
from pipeline_dp import input_validators
Expand Down Expand Up @@ -132,3 +132,20 @@ def get_aggregate_params(
for i in range(multi_param_configuration.size):
yield multi_param_configuration.get_aggregate_params(
options.aggregate_params, i)


def get_partition_selection_strategy(
    options: UtilityAnalysisOptions
) -> Sequence[pipeline_dp.PartitionSelectionStrategy]:
    """Returns partition selection strategies for different configurations.

    Args:
      options: utility analysis options. Only
        options.multi_param_configuration and
        options.aggregate_params.partition_selection_strategy are read.

    Returns:
      A new list with one partition selection strategy per parameter
      configuration. When options.multi_param_configuration is None, the
      list has exactly 1 element: the strategy from
      options.aggregate_params.
    """
    multi_configuration = options.multi_param_configuration
    if multi_configuration is None:
        # Single configuration, which uses the strategy of aggregate_params.
        return [options.aggregate_params.partition_selection_strategy]

    per_configuration = multi_configuration.partition_selection_strategy
    if per_configuration is not None:
        # Different parameter configurations have different partition
        # selection strategies. Return a copy, so that modifying the result
        # can not change the configuration stored in options.
        return list(per_configuration)

    # The same partition selection strategy for all configurations.
    return ([options.aggregate_params.partition_selection_strategy] *
            multi_configuration.size)
3 changes: 2 additions & 1 deletion analysis/tests/utility_analysis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def test_wo_public_partitions(self, pre_aggregated: bool):
num_dataset_partitions=10,
num_non_public_partitions=None,
num_empty_partitions=None,
strategy=None,
strategy=pipeline_dp.PartitionSelectionStrategy.
TRUNCATED_GEOMETRIC,
kept_partitions=metrics.MeanVariance(mean=3.51622411,
var=2.2798409)),
metric_errors=[
Expand Down
17 changes: 16 additions & 1 deletion analysis/utility_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pipeline_dp
from pipeline_dp import pipeline_backend
import analysis
from analysis import data_structures
from analysis import metrics
from analysis import utility_analysis_engine
from analysis import cross_partition_combiners
Expand Down Expand Up @@ -114,6 +115,21 @@ def perform_utility_analysis(
"Compute cross-partition metrics")
# ((configuration_index, bucket), UtilityReport)

if public_partitions is None:
# Add partition selection strategy for private partitions.
strategies = data_structures.get_partition_selection_strategy(options)

def add_partition_selection_strategy(report: metrics.UtilityReport):
# Beam does not allow mutating map inputs, so work on a copy.
report = copy.deepcopy(report)
report.partitions_info.strategy = strategies[
report.configuration_index]
return report

cross_partition_metrics = backend.map_values(
cross_partition_metrics, add_partition_selection_strategy,
"Add Partition Selection Strategy")

cross_partition_metrics = backend.map_tuple(
cross_partition_metrics, lambda key, value: (key[0], (key[1], value)),
"Rekey")
Expand All @@ -125,7 +141,6 @@ def perform_utility_analysis(
result = backend.map_tuple(cross_partition_metrics, _group_utility_reports,
"Group utility reports")
# result: (UtilityReport)

return result, per_partition_result


Expand Down

0 comments on commit bde02b7

Please sign in to comment.