forked from microsoft/project-azua
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmissforest.py
59 lines (47 loc) · 1.85 KB
/
missforest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import Any, Dict, Optional, Callable, Tuple
import numpy as np
# Explicitly enable experimental IterativeImputer (new in scikit-learn 0.22.2)
from sklearn.experimental import enable_iterative_imputer # noqa F401
from sklearn.impute import IterativeImputer
from sklearn.ensemble import ExtraTreesRegressor
from ..baselines.sk_learn_imputer import SKLearnImputer
from ..datasets.dataset import Dataset
from ..datasets.variables import Variables
class MissForest(SKLearnImputer):
    """Imputation baseline built on scikit-learn's ``IterativeImputer``.

    The imputer is configured with an ``ExtraTreesRegressor`` as its
    estimator. Storage and use of the imputer (``self._imputer``) and the
    ``fill_mask`` helper are not defined in this class.
    NOTE(review): both presumably come from ``SKLearnImputer`` - confirm.
    """

    def __init__(
        self,
        model_id: str,
        variables: Variables,
        save_dir,
        max_iter: int = 10,
        initial_strategy: str = "mean",
        random_seed: int = 0,
    ) -> None:
        """Build the underlying ``IterativeImputer`` and hand it to the base class.

        Args:
            model_id: Forwarded unchanged to the base-class constructor.
            variables: Forwarded unchanged to the base-class constructor.
            save_dir: Forwarded unchanged to the base-class constructor.
            max_iter: Passed to ``IterativeImputer`` as ``max_iter``.
            initial_strategy: Passed to ``IterativeImputer`` as
                ``initial_strategy``.
            random_seed: Passed to ``IterativeImputer`` as ``random_state``.
        """
        imputer = IterativeImputer(
            max_iter=max_iter,
            initial_strategy=initial_strategy,
            random_state=random_seed,
            estimator=ExtraTreesRegressor(),
        )
        super().__init__(model_id, variables, save_dir, imputer)

    @classmethod
    def name(cls) -> str:
        """Return the string identifier of this model, ``"missforest"``."""
        return "missforest"

    def run_train(
        self,
        dataset: Dataset,
        train_config_dict: Optional[Dict[str, Any]] = None,
        report_progress_callback: Optional[Callable[[str, int, int], None]] = None,
    ) -> None:
        """Fit the imputer on the training split of ``dataset``.

        Args:
            dataset: Source of the training data and mask, read from its
                ``train_data_and_mask`` attribute.
            train_config_dict: Not used by this implementation.
            report_progress_callback: Not used by this implementation.
        """
        data, mask = dataset.train_data_and_mask
        # Pre-process the data with the mask before fitting.
        # NOTE(review): presumably marks unobserved entries as missing - confirm
        # against fill_mask's definition.
        data = self.fill_mask(data, mask)
        self._imputer.fit(data)

    def impute(
        self,
        data: np.ndarray,
        mask: np.ndarray,
        impute_config_dict: Optional[Dict[str, int]] = None,
        vamp_prior_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        average: bool = True,
    ) -> np.ndarray:
        """Impute ``data`` with the fitted imputer.

        Args:
            data: Array to impute.
            mask: Mask passed together with ``data`` to ``fill_mask``.
            impute_config_dict: Not used by this implementation.
            vamp_prior_data: Not used by this implementation.
            average: When False, the result gets an extra leading axis of
                length 1; when True, the imputer output is returned as is.

        Returns:
            The output of the imputer's ``transform``, optionally with a
            leading singleton axis (see ``average``).
        """
        data = self.fill_mask(data, mask)
        imputed = self._imputer.transform(data)
        if not average:
            # Add a leading axis in the position that would index samples.
            imputed = np.expand_dims(imputed, axis=0)
        return imputed