Skip to content

Commit

Permalink
feat: support custom (de)normalization of property values in RT (GT4S…
Browse files Browse the repository at this point in the history
  • Loading branch information
jannisborn authored May 4, 2023
1 parent 6e050ed commit 5d76aa4
Showing 1 changed file with 46 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,14 @@ def load_inference(self, resources_path: str) -> None:
data.get("property_ranges", {}).get(p, [0, 1])[1]
for p in self.properties
]
# In case custom normalization/denormalization is required
self.normalization_fns = [
data.get("normalization_fns", {}).get(p, None) for p in self.properties
]
self.denormalization_fns = [
data.get("denormalization_fns", {}).get(p, None)
for p in self.properties
]
self.metadata = data

# If tolerance dict is given, ensure it is well-formed
Expand Down Expand Up @@ -211,14 +219,25 @@ def denormalize(self, x: float, idx: int, precision: int = 4) -> float:
Returns:
float: Value in regular scale.
"""

# If the property was not normalized, return the value
if not self.do_normalize[idx]:
return x

return round(
x * (self._maxs[idx] - self._mins[idx]) + self._mins[idx], precision
)
# The default normalization reverts a linear transformation to [0,1] scale
if self.denormalization_fns[idx] is None:
return round(
x * (self._maxs[idx] - self._mins[idx]) + self._mins[idx], precision
)

# This allows to revert arbitrarily complex preprocessing functions
fn = self.denormalization_fns[idx]
try:
denormed = eval(fn)(x)
return round(denormed, precision)
except SyntaxError:
raise SyntaxError(
f"Custom denormalization function {fn} seems improperly formatted"
)

def normalize(self, x: str, idx: int, precision: int = 3) -> float:
"""
Expand All @@ -237,13 +256,32 @@ def normalize(self, x: str, idx: int, precision: int = 3) -> float:
raise TypeError(f"{x} is not a float and cant safely be casted.")

x_float = float(x)
if x_float < self._mins[idx] or x_float > self._maxs[idx]:
raise ValueError(
f"Property value {x_float} for {self.properties[idx]} is outside of "
f"model's range [{self._mins[idx]}, {self._maxs[idx]}]."
)
# If this property does not require normalization, return it
if not self.do_normalize[idx]:
return x_float
normed = round(
(x_float - self._mins[idx]) / (self._maxs[idx] - self._mins[idx]), precision
)
return normed

# This performs a standard linear normalization to [0,1]
if self.normalization_fns[idx] is None:
normed = round(
(x_float - self._mins[idx]) / (self._maxs[idx] - self._mins[idx]),
precision,
)
return normed

# Allows to apply arbitrary preprocessing functions
fn = self.normalization_fns[idx]
try:
normed = eval(fn)(x_float)
return round(normed, precision)
except SyntaxError:
raise SyntaxError(
f"Custom normalization function {fn} seems improperly formatted"
)

def validate_input(self, x: str) -> None:
"""
Expand Down

0 comments on commit 5d76aa4

Please sign in to comment.