Skip to content

Commit

Permalink
baal-org#130 Add mypy and step to test imports (baal-org#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
Frédéric Branchaud-Charron authored Oct 12, 2021
1 parent a9cc003 commit 82df3d4
Show file tree
Hide file tree
Showing 22 changed files with 320 additions and 189 deletions.
24 changes: 24 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,29 @@ jobs:
name: run tests
command: |
make test
test-import:
docker:
# specify the version you desire here
- image: circleci/python:3.7-stretch
steps:
- checkout

- run:
name: install dependencies no-dev
command: |
pip config set global.progress_bar off
pip install --upgrade pip
pip install cmake "poetry==1.1.7"
poetry config virtualenvs.create true
poetry config virtualenvs.in-project true
poetry install --no-interaction --no-ansi --remove-untracked --no-dev
- run:
name: "test import"
command: |
poetry run python -c "import baal; import baal.active.dataset; \
import baal.active.heuristics; import baal.active.active_loop; \
import baal.bayesian; import baal.calibration; import baal.modelwrapper"
fossa:
docker:
# specify the version you desire here
Expand Down Expand Up @@ -80,3 +103,4 @@ workflows:
branches:
only:
- master
- test-import
23 changes: 21 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,11 +1,30 @@
lint:
.PHONY: lint
lint: check-mypy-error-count
poetry run flake8 baal

.PHONY: test
test: lint
poetry run pytest tests --cov=baal

.PHONY: format
format:
poetry run black baal

.PHONY: requirements.txt
requirements.txt: poetry.lock
poetry export --without-hashes -f requirements.txt > requirements.txt
poetry export --without-hashes -f requirements.txt > requirements.txt

.PHONY: mypy
mypy:
poetry run mypy --show-error-codes baal


.PHONY: check-mypy-error-count
check-mypy-error-count: MYPY_INFO = $(shell expr `poetry run mypy baal | grep ": error" | wc -l`)
check-mypy-error-count: MYPY_ERROR_COUNT = 16

check-mypy-error-count:
@if [ ${MYPY_INFO} -gt ${MYPY_ERROR_COUNT} ]; then \
echo mypy error count $(MYPY_INFO) is more than $(MYPY_ERROR_COUNT); \
false; \
fi
3 changes: 2 additions & 1 deletion baal/active/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_heuristic(
Returns:
AbstractHeuristic object.
"""
return {
heuristic: heuristics.AbstractHeuristic = {
"random": heuristics.Random,
"certainty": heuristics.Certainty,
"entropy": heuristics.Entropy,
Expand All @@ -31,3 +31,4 @@ def get_heuristic(
"precomputed": heuristics.Precomputed,
"batch_bald": heuristics.BatchBALD,
}[name](shuffle_prop=shuffle_prop, reduction=reduction, **kwargs)
return heuristic
185 changes: 105 additions & 80 deletions baal/active/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from copy import deepcopy
from itertools import zip_longest
from typing import Union, Optional, Callable, Tuple, List, Any
from typing import Union, Optional, Callable, Tuple, List, Any, Dict

import numpy as np
import torch
Expand All @@ -13,41 +13,98 @@ def _identity(x):
return x


class ActiveLearningDataset(torchdata.Dataset):
class SplittedDataset(torchdata.Dataset):
"""Abstract class for Dataset that can be splitted."""

labelled: np.ndarray
random_state: np.random.RandomState
label: Callable

def is_labelled(self, idx: int) -> bool:
"""Check if a datapoint is labelled."""
return bool(self.labelled[idx].item() == 1)

def __len__(self) -> int:
"""Return how many actual data / label pairs we have."""
return int(self.labelled.sum())

@property
def n_unlabelled(self):
"""The number of unlabelled data points."""
return (~self.labelled).sum()

@property
def n_labelled(self):
"""The number of labelled data points."""
return self.labelled.sum()

def label_randomly(self, n: int = 1) -> None:
"""
Label `n` data-points randomly.
Args:
n (int): Number of samples to label.
"""
for i in range(n):
"""Making multiple call to self.n_unlabelled is inefficient, but
self.label changes the available length and it may lead to
IndexError if not done this way."""
self.label(self.random_state.choice(self.n_unlabelled, 1).item())

""" This returns one or zero, if it is labelled or not, no index is returned.
"""

def _labelled_to_oracle_index(self, index: int) -> int:
return int(self.labelled.nonzero()[0][index].squeeze().item())

def _pool_to_oracle_index(self, index: Union[int, List[int]]) -> List[int]:
if isinstance(index, np.int64) or isinstance(index, int):
index = [index]

lbl_nz = (~self.labelled).nonzero()[0]
return [int(lbl_nz[idx].squeeze().item()) for idx in index]

def _oracle_to_pool_index(self, index: Union[int, List[int]]) -> List[int]:
if isinstance(index, int):
index = [index]

# Pool indices are the unlabelled, starts at 0
lbl_cs = np.cumsum(~self.labelled) - 1
return [int(lbl_cs[idx].squeeze().item()) for idx in index]


class ActiveLearningDataset(SplittedDataset):
"""A dataset that allows for active learning.
Args:
dataset (torch.data.Dataset): The baseline dataset.
labelled (Union[np.ndarray, torch.Tensor]):
An array/tensor that acts as a boolean mask which is True for every
dataset: The baseline dataset.
labelled: An array that acts as a boolean mask which is True for every
data point that is labelled, and False for every data point that is not
labelled.
make_unlabelled (Callable): The function that returns an
make_unlabelled: The function that returns an
unlabelled version of a datum so that it can still be used in the DataLoader.
random_state (None, int, RandomState): Set the random seed for label_randomly().
pool_specifics (Optional[Dict]): Attributes to set when creating the pool.
random_state: Set the random seed for label_randomly().
pool_specifics: Attributes to set when creating the pool.
Useful to remove data augmentation.
"""

def __init__(
self,
dataset: torchdata.Dataset,
labelled: Union[np.ndarray, torch.Tensor] = None,
labelled: Optional[np.ndarray] = None,
make_unlabelled: Callable = _identity,
random_state=None,
pool_specifics: Optional[dict] = None,
) -> None:
self._dataset = dataset
if labelled is not None:
if isinstance(labelled, torch.Tensor):
labelled = labelled.numpy()
self.labelled = labelled.astype(bool)
self.labelled: np.ndarray = labelled.astype(bool)
else:
self.labelled = np.zeros(len(self._dataset), dtype=bool)

if pool_specifics is None:
pool_specifics = {}
self.pool_specifics = pool_specifics
self.pool_specifics: Dict[str, Any] = pool_specifics

self.make_unlabelled = make_unlabelled
# For example, FileDataset has a method 'label'. This is useful when we're in prod.
Expand Down Expand Up @@ -85,14 +142,10 @@ def check_dataset_can_label(self):
)
return False

def __getitem__(self, index: int) -> Tuple[torch.Tensor, ...]:
def __getitem__(self, index: int) -> Any:
"""Return stuff from the original dataset."""
return self._dataset[self._labelled_to_oracle_index(index)]

def __len__(self) -> int:
"""Return how many actual data / label pairs we have."""
return self.labelled.sum()

class ActiveIter:
"""Iterator over an ActiveLearningDataset."""

Expand All @@ -114,16 +167,6 @@ def __next__(self):
def __iter__(self):
return self.ActiveIter(self)

@property
def n_unlabelled(self):
"""The number of unlabelled data points."""
return (~self.labelled).sum()

@property
def n_labelled(self):
"""The number of labelled data points."""
return self.labelled.sum()

@property
def pool(self) -> torchdata.Dataset:
"""Returns a new Dataset made from unlabelled samples.
Expand All @@ -145,35 +188,14 @@ def pool(self) -> torchdata.Dataset:
ald = ActiveLearningPool(pool_dataset, make_unlabelled=self.make_unlabelled)
return ald

""" This returns one or zero, if it is labelled or not, no index is returned.
"""

def _labelled_to_oracle_index(self, index: int) -> int:
return self.labelled.nonzero()[0][index].squeeze().item()

def _pool_to_oracle_index(self, index: Union[int, List[int]]) -> List[int]:
if isinstance(index, np.int64) or isinstance(index, int):
index = [index]

lbl_nz = (~self.labelled).nonzero()[0]
return [int(lbl_nz[idx].squeeze().item()) for idx in index]

def _oracle_to_pool_index(self, index: Union[int, List[int]]) -> List[int]:
if isinstance(index, int):
index = [index]

# Pool indices are the unlabelled, starts at 0
lbl_cs = np.cumsum(~self.labelled) - 1
return [int(lbl_cs[idx].squeeze().item()) for idx in index]

def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
"""
Label data points.
The index should be relative to the pool, not the overall data.
Args:
index (Union[list,int]): one or many indices to label.
value (Optional[Any]): The label value. If not provided, no modification
index: one or many indices to label.
value: The label value. If not provided, no modification
to the underlying dataset is done.
"""
if isinstance(index, int):
Expand All @@ -183,7 +205,7 @@ def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
indexes = self._pool_to_oracle_index(index)
for index, val in zip_longest(indexes, value, fillvalue=None):
if self.can_label and val is not None:
self._dataset.label(index, val)
self._dataset.label(index, val) # type: ignore
self.labelled[index] = 1
elif self.can_label and val is None:
warnings.warn(
Expand All @@ -204,32 +226,15 @@ def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
UserWarning,
)

def label_randomly(self, n: int = 1) -> None:
"""
Label `n` data-points randomly.
Args:
n (int): Number of samples to label.
"""
for i in range(n):
"""Making multiple call to self.n_unlabelled is inefficient, but
self.label changes the available length and it may lead to
IndexError if not done this way."""
self.label(self.random_state.choice(self.n_unlabelled, 1).item())

def reset_labeled(self):
"""Reset the label map."""
self.labelled = np.zeros(len(self._dataset), dtype=np.bool)

def is_labelled(self, idx: int) -> bool:
"""Check if a datapoint is labelled."""
return self.labelled[idx] == 1

def get_raw(self, idx: int) -> None:
def get_raw(self, idx: int) -> Any:
"""Get a datapoint from the underlying dataset."""
return self._dataset[idx]

def state_dict(self):
def state_dict(self) -> Dict:
"""Return the state_dict, ie. the labelled map and random_state."""
return {"labelled": self.labelled, "random_state": self.random_state}

Expand All @@ -250,10 +255,10 @@ class ActiveLearningPool(torchdata.Dataset):
"""

def __init__(self, dataset: torchdata.Dataset, make_unlabelled: Callable = _identity) -> None:
self._dataset = dataset
self._dataset: torchdata.Dataset = dataset
self.make_unlabelled = make_unlabelled

def __getitem__(self, index: int) -> Tuple[torch.Tensor, ...]:
def __getitem__(self, index: int) -> Any:
# This datum is marked as unlabelled, so clear the label.
return self.make_unlabelled(self._dataset[index])

Expand All @@ -262,7 +267,7 @@ def __len__(self) -> int:
return len(self._dataset)


class ActiveNumpyArray(ActiveLearningDataset):
class ActiveNumpyArray(SplittedDataset):
"""
Active dataset for numpy arrays. Useful when using sklearn.
Expand All @@ -277,16 +282,15 @@ class ActiveNumpyArray(ActiveLearningDataset):
def __init__(
self,
dataset: Tuple[np.ndarray, np.ndarray],
labelled: Union[np.ndarray, torch.Tensor] = None,
labelled: Optional[np.ndarray] = None,
) -> None:

self.random_state = np.random.RandomState()
if labelled is not None:
if isinstance(labelled, torch.Tensor):
labelled = labelled.numpy()
labelled = labelled.astype(bool)
else:
labelled = np.zeros(len(dataset[0]), dtype=bool)
super().__init__(dataset, labelled=labelled)
self._dataset: Tuple[np.ndarray, np.ndarray] = dataset
self.labelled: np.ndarray = labelled

@property
def pool(self):
Expand All @@ -298,8 +302,29 @@ def dataset(self):
"""Return the labelled portion of the dataset."""
return self._dataset[0][self.labelled], self._dataset[1][self.labelled]

def get_raw(self, idx: int) -> None:
def get_raw(self, idx: int) -> Any:
return self._dataset[0][idx], self._dataset[1][idx]

def __iter__(self):
return zip(*self._dataset)

def __getitem__(self, item):
return self._dataset[0][item], self._dataset[1][item]

def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
"""
Label data points.
The index should be relative to the pool, not the overall data.
Args:
index (Union[list,int]): one or many indices to label.
value (Optional[Any]): The label value. If not provided, no modification
to the underlying dataset is done.
"""
if isinstance(index, int):
index = [index]
if not isinstance(value, (list, tuple)):
value = [value]
indexes = self._pool_to_oracle_index(index)
for index, val in zip_longest(indexes, value, fillvalue=None):
self.labelled[index] = 1
Loading

0 comments on commit 82df3d4

Please sign in to comment.