Skip to content

Commit

Permalink
Deprecate shuffle prop in favor of PowerSampling (baal-org#260)
Browse files Browse the repository at this point in the history
* Deprecate shuffle prop in favor of PowerSampling

* Update baal/active/heuristics/heuristics.py

Co-authored-by: Parmida Atighehchian <[email protected]>

---------

Co-authored-by: Parmida Atighehchian <[email protected]>
  • Loading branch information
Dref360 and parmidaatg authored May 19, 2023
1 parent b218a56 commit 149718d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
31 changes: 21 additions & 10 deletions baal/active/heuristics/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@

from baal.utils.array_utils import to_prob

# Sentinel used as the default value of the deprecated `shuffle_prop` argument;
# constructors compare against it to tell whether a value was supplied.
DEPRECATED = "DEPRECATED"

# Message passed to `warnings.warn` (as a DeprecationWarning) when `shuffle_prop`
# is still used. The leading and trailing newlines are part of the message.
SHUFFLE_PROP_DEPRECATION_NOTICE = (
    "\n"
    "`shuffle_prop` is deprecated and will be removed in Baal 1.9.0.\n"
    "For better batch uncertainty estimation, use "
    "`baal.active.heuristics.stochastics.PowerSampling`.\n"
    "See `https://baal.readthedocs.io/en/latest/user_guide/heuristics/` for more details.\n"
)

available_reductions = {
"max": lambda x: np.max(x, axis=tuple(range(1, x.ndim))),
"min": lambda x: np.min(x, axis=tuple(range(1, x.ndim))),
Expand Down Expand Up @@ -139,7 +146,11 @@ class AbstractHeuristic:
reduction (Union[str, Callable]): Reduction used after computing the score.
"""

def __init__(self, shuffle_prop=0.0, reverse=False, reduction="none"):
def __init__(self, shuffle_prop=DEPRECATED, reverse=False, reduction="none"):
if shuffle_prop != DEPRECATED and shuffle_prop < 1.0:
warnings.warn(SHUFFLE_PROP_DEPRECATION_NOTICE, DeprecationWarning)
else:
shuffle_prop = 0.0
self.shuffle_prop = shuffle_prop
self.reversed = reverse
assert reduction in available_reductions or callable(reduction)
Expand Down Expand Up @@ -272,7 +283,7 @@ class BALD(AbstractHeuristic):
https://arxiv.org/abs/1703.02910
"""

def __init__(self, shuffle_prop=0.0, reduction="none"):
def __init__(self, shuffle_prop=DEPRECATED, reduction="none"):
super().__init__(shuffle_prop=shuffle_prop, reverse=True, reduction=reduction)

@require_single_item
Expand Down Expand Up @@ -324,7 +335,7 @@ class BatchBALD(BALD):
Not tested on 4+ dims.
"""

def __init__(self, num_samples, num_draw=500, shuffle_prop=0.0, reduction="none"):
def __init__(self, num_samples, num_draw=500, shuffle_prop=DEPRECATED, reduction="none"):
self.epsilon = 1e-5
self.num_samples = num_samples
self.num_draw = num_draw
Expand Down Expand Up @@ -508,7 +519,7 @@ class Variance(AbstractHeuristic):
reduction (Union[str, callable]): function that aggregates the results (default: `mean`).
"""

def __init__(self, shuffle_prop=0.0, reduction="mean"):
def __init__(self, shuffle_prop=DEPRECATED, reduction="mean"):
_help = "Need to reduce the output from [n_sample, n_class] to [n_sample]"
assert reduction != "none", _help
super().__init__(shuffle_prop=shuffle_prop, reverse=True, reduction=reduction)
Expand All @@ -529,7 +540,7 @@ class Entropy(AbstractHeuristic):
reduction (Union[str, callable]): function that aggregates the results (default: `none`).
"""

def __init__(self, shuffle_prop=0.0, reduction="none"):
def __init__(self, shuffle_prop=DEPRECATED, reduction="none"):
super().__init__(shuffle_prop=shuffle_prop, reverse=True, reduction=reduction)

@require_single_item
Expand All @@ -551,7 +562,7 @@ class Margin(AbstractHeuristic):
(default: `none`).
"""

def __init__(self, shuffle_prop=0.0, reduction="none"):
def __init__(self, shuffle_prop=DEPRECATED, reduction="none"):
super().__init__(shuffle_prop=shuffle_prop, reverse=False, reduction=reduction)

@require_single_item
Expand All @@ -571,7 +582,7 @@ class Certainty(AbstractHeuristic):
reduction (Union[str, callable]): function that aggregates the results.
"""

def __init__(self, shuffle_prop=0.0, reduction="none"):
def __init__(self, shuffle_prop=DEPRECATED, reduction="none"):
super().__init__(shuffle_prop=shuffle_prop, reverse=False, reduction=reduction)

@require_single_item
Expand All @@ -588,7 +599,7 @@ class Precomputed(AbstractHeuristic):
reverse (Bool): Sort from lowest to highest if False.
"""

def __init__(self, shuffle_prop=0.0, reverse=False):
def __init__(self, shuffle_prop=DEPRECATED, reverse=False):
super().__init__(shuffle_prop, reverse=reverse)

def compute_score(self, predictions):
Expand All @@ -604,7 +615,7 @@ class Random(Precomputed):
seed (Optional[int]): If provided, will seed the random generator.
"""

def __init__(self, shuffle_prop=0.0, reduction="none", seed=None):
def __init__(self, shuffle_prop=DEPRECATED, reduction="none", seed=None):
super().__init__(1.0, False)
if seed is not None:
self.rng = np.random.RandomState(seed)
Expand Down Expand Up @@ -643,7 +654,7 @@ class CombineHeuristics(AbstractHeuristic):
"""

def __init__(self, heuristics: List, weights: List, reduction="mean", shuffle_prop=0.0):
def __init__(self, heuristics: List, weights: List, reduction="mean", shuffle_prop=DEPRECATED):
super(CombineHeuristics, self).__init__(reduction=reduction, shuffle_prop=shuffle_prop)
self.composed_heuristic = heuristics
self.weights = weights
Expand Down
14 changes: 14 additions & 0 deletions tests/active/heuristic_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import numpy as np
import pytest
from hypothesis import given
Expand Down Expand Up @@ -332,6 +334,18 @@ def test_heuristic_reductio_check(distributions):
heuristic(distributions)
assert "Can't order sequence with more than 1 dimension." in str(e_info.value)

def test_shuffle_prop_warning():
    """Check that only an explicitly passed `shuffle_prop` triggers the deprecation warning.

    `BALD()` must stay silent, `BALD(shuffle_prop=0.1)` must emit exactly one
    `DeprecationWarning` mentioning `shuffle_prop`, and `Random()` must stay silent.
    """

    def shuffle_prop_warnings(records):
        # Count only the warning under test so that unrelated warnings raised by
        # dependencies cannot change the outcome.
        return [
            record
            for record in records
            if issubclass(record.category, DeprecationWarning)
            and "shuffle_prop" in str(record.message)
        ]

    with warnings.catch_warnings(record=True) as tape:
        # Without this, recording depends on the ambient warning filters and on the
        # once-per-location registry; "always" makes the capture deterministic.
        warnings.simplefilter("always")
        _ = BALD()
        assert len(shuffle_prop_warnings(tape)) == 0
        _ = BALD(shuffle_prop=0.1)
        assert len(shuffle_prop_warnings(tape)) == 1

    # Random doesn't raise the warning
    with warnings.catch_warnings(record=True) as tape:
        warnings.simplefilter("always")
        _ = Random()
        assert len(shuffle_prop_warnings(tape)) == 0

if __name__ == '__main__':
pytest.main()

0 comments on commit 149718d

Please sign in to comment.