forked from facebookresearch/ReAgent
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path normalizers.py
61 lines (49 loc) · 1.88 KB
/
normalizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import collections
import logging
import numpy as np
from reagent.core.parameters import NormalizationParameters
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def _bound_kind(value):
    """Classify a min/max bound argument.

    Returns "none" for ``None``, "scalar" for a real number (Python or numpy
    int/float), "sequence" for a list/tuple/ndarray, and ``None`` when the
    value is of an unsupported type.
    """
    if value is None:
        return "none"
    # bool is a subclass of int; the previous exact-type check rejected it,
    # so keep rejecting it here.
    if isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(
        value, bool
    ):
        return "scalar"
    if isinstance(value, (list, tuple, np.ndarray)):
        return "sequence"
    return None


def _expand_bounds(value, kind, num_feats, name):
    """Return a list holding one bound per feature.

    A scalar or ``None`` is repeated ``num_feats`` times; a sequence is copied
    and must contain exactly ``num_feats`` entries.

    Raises:
        ValueError: if a sequence's length differs from ``num_feats``.
    """
    if kind != "sequence":
        return [value] * num_feats
    values = list(value)
    if len(values) != num_feats:
        # Explicit raise (not assert) so a misaligned bound list is never
        # silently truncated, even when running with ``python -O``.
        raise ValueError(
            f"{name} has {len(values)} entries but there are {num_feats} features."
        )
    return values


def normalizer_helper(feats, feature_type, min_value=None, max_value=None):
    """Build a NormalizationParameters entry for every feature in ``feats``.

    Each entry uses ``mean=0`` and ``stddev=1``, with ``boxcox_lambda``,
    ``boxcox_shift``, ``possible_values`` and ``quantiles`` all set to None.

    Args:
        feats: Iterable whose items become the keys of the returned dict,
            in iteration order.
        feature_type: One of "DISCRETE_ACTION", "CONTINUOUS",
            "CONTINUOUS_ACTION".
        min_value, max_value: Either both None, both real scalars (applied to
            every feature; int and float may be mixed), or both sequences
            (list/tuple/ndarray) with exactly one entry per feature. Sequence
            entries may be None. Non-None bounds are converted with ``float``.

    Returns:
        collections.OrderedDict mapping each item of ``feats`` to its
        NormalizationParameters.

    Raises:
        AssertionError: invalid ``feature_type``, or bounds of unsupported or
            mismatched kinds (these are ``assert`` checks, skipped under -O).
        ValueError: a bounds sequence whose length differs from ``feats``.
    """
    assert feature_type in (
        "DISCRETE_ACTION",
        "CONTINUOUS",
        "CONTINUOUS_ACTION",
    ), f"invalid feature type: {feature_type}."
    min_kind = _bound_kind(min_value)
    max_kind = _bound_kind(max_value)
    # Both bounds must be of the same kind (both None, both scalars, or both
    # sequences); exact Python types no longer have to match.
    assert (
        min_kind is not None and min_kind == max_kind
    ), f"invalid {type(min_value)}, {type(max_value)}"

    feats = list(feats)
    min_values = _expand_bounds(min_value, min_kind, len(feats), "min_value")
    max_values = _expand_bounds(max_value, max_kind, len(feats), "max_value")

    return collections.OrderedDict(
        (
            feat,
            NormalizationParameters(
                feature_type=feature_type,
                boxcox_lambda=None,
                boxcox_shift=None,
                mean=0,
                stddev=1,
                possible_values=None,
                quantiles=None,
                min_value=None if low is None else float(low),
                max_value=None if high is None else float(high),
            ),
        )
        for feat, low, high in zip(feats, min_values, max_values)
    )
def discrete_action_normalizer(feats):
    """Return per-feature normalization parameters of type "DISCRETE_ACTION".

    No bounds are passed, so ``normalizer_helper`` uses its defaults of None
    for both ``min_value`` and ``max_value``.
    """
    return normalizer_helper(feats=feats, feature_type="DISCRETE_ACTION")
def only_continuous_normalizer(feats, min_value=None, max_value=None):
    """Return per-feature normalization parameters of type "CONTINUOUS".

    ``min_value`` and ``max_value`` are forwarded unchanged to
    ``normalizer_helper``.
    """
    return normalizer_helper(
        feats,
        feature_type="CONTINUOUS",
        min_value=min_value,
        max_value=max_value,
    )
def only_continuous_action_normalizer(feats, min_value=None, max_value=None):
    """Return per-feature normalization parameters of type "CONTINUOUS_ACTION".

    ``min_value`` and ``max_value`` are forwarded unchanged to
    ``normalizer_helper``.
    """
    return normalizer_helper(
        feats,
        feature_type="CONTINUOUS_ACTION",
        min_value=min_value,
        max_value=max_value,
    )