Feature constraints (databrickslabs#257)
* feature for constraints * lint related updates * work in progress * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * updated build actions * Update CHANGELOG.md * wip * wip * wip * wip
1 parent b28602d, commit 82ce5ce
Showing 20 changed files with 1,125 additions and 8 deletions.
dbldatagen/constraints/__init__.py
@@ -0,0 +1,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This package defines the constraints classes for the `dbldatagen` library.

These classes define predefined constraints that may be applied to constrain the generated data.

Constraining the generated data is implemented in several ways:

- Rejection of rows that do not meet the criteria
- Modification of the generated data to meet the constraint (including modifying the data generation parameters)

Some constraints may be implemented using a combination of the above approaches.

For implementations using the rejection approach, the data generation process may generate fewer than the
requested number of rows.

In the current implementation, most constraint strategies are implemented using rejection-based criteria.
"""

from .chained_relation import ChainedRelation
from .constraint import Constraint
from .literal_range_constraint import LiteralRange
from .literal_relation_constraint import LiteralRelation
from .negative_values import NegativeValues
from .positive_values import PositiveValues
from .ranged_values_constraint import RangedValues
from .sql_expr import SqlExpr
from .unique_combinations import UniqueCombinations

__all__ = ["chained_relation",
           "constraint",
           "negative_values",
           "literal_range_constraint",
           "literal_relation_constraint",
           "positive_values",
           "ranged_values_constraint",
           "sql_expr",
           "unique_combinations"]
dbldatagen/constraints/chained_relation.py
@@ -0,0 +1,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the ChainedRelation class
"""
import pyspark.sql.functions as F

from .constraint import Constraint, NoPrepareTransformMixin


class ChainedRelation(NoPrepareTransformMixin, Constraint):
    """ChainedRelation constraint

    Constrains one or more columns so that each column has a relationship to the next.

    For example, if the constraint is defined as `ChainedRelation(['a', 'b', 'c'], "<")`, then only rows that
    satisfy the condition `a < b < c` will be included in the output
    (where `a`, `b` and `c` represent the data values for the rows).

    This can be used to model time-related transactions (for example, in retail, where the purchaseDate,
    shippingDate and returnDate all have a specific relationship).

    Relations supported include <, <=, >=, >, !=, ==, =, <>

    :param columns: column name or list of column names as string or list of strings
    :param relation: operator to check - should be one of <, <=, >=, >, !=, ==, =, <>
    """
    def __init__(self, columns, relation):
        super().__init__(supportsStreaming=True)
        self._relation = relation
        self._columns = self._columnsFromListOrString(columns)

        if relation not in self.SUPPORTED_OPERATORS:
            raise ValueError(f"Parameter `relation` should be one of the operators: {self.SUPPORTED_OPERATORS}")

        if not isinstance(self._columns, list) or len(self._columns) <= 1:
            raise ValueError("ChainedRelation constraints must be defined across more than one column")

    def _generateFilterExpression(self):
        """ Generate composite filter expression for chained set of filter expressions

        I.e., if columns is ['a', 'b', 'c'] and relation is '<',
        create the set of filters [col('a') < col('b'), col('b') < col('c')]
        and combine them into a single expression using the logical `and` operation.

        :return: filter expression for chained expressions
        """
        expressions = [F.col(colname) for colname in self._columns]

        filters = []
        # build set of filters for chained expressions
        for ix in range(1, len(expressions)):
            filters.append(self._generate_relation_expression(expressions[ix - 1], self._relation, expressions[ix]))

        # ... and combine them using logical `and` operation
        return self.mkCombinedConstraintExpression(filters)
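Since `ChainedRelation` uses `NoPrepareTransformMixin`, its only effect is the rejection filter it produces. A small sketch of how that filter behaves when applied by hand, assuming an existing SparkSession `spark`; only the classes shown above are used:

from dbldatagen.constraints import ChainedRelation

df = spark.createDataFrame(
    [(1, 2, 3), (5, 4, 6), (1, 1, 2)],
    ["purchaseDay", "shippingDay", "returnDay"],
)

constraint = ChainedRelation(["purchaseDay", "shippingDay", "returnDay"], "<")

# filterExpression here is equivalent to:
#   (col('purchaseDay') < col('shippingDay')) & (col('shippingDay') < col('returnDay'))
df.where(constraint.filterExpression).show()
# only (1, 2, 3) survives; (5, 4, 6) fails the first link, (1, 1, 2) fails the strict `<`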
dbldatagen/constraints/constraint.py
@@ -0,0 +1,186 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the Constraint class and its helper mixins
"""
import types
from abc import ABC, abstractmethod
from pyspark.sql import Column


class Constraint(ABC):
    """ Constraint object - base class for predefined and custom constraints

    This class is meant for internal use only.
    """
    SUPPORTED_OPERATORS = ["<", ">", ">=", "!=", "==", "=", "<=", "<>"]

    def __init__(self, supportsStreaming=False):
        """
        Initialize the constraint object
        """
        self._filterExpression = None
        self._calculatedFilterExpression = False
        self._supportsStreaming = supportsStreaming

    @staticmethod
    def _columnsFromListOrString(columns):
        """ Get columns as a list of column names from a string or list-like of strings

        :param columns: string or list of strings representing column names
        """
        if isinstance(columns, str):
            return [columns]
        elif isinstance(columns, (list, set, tuple, types.GeneratorType)):
            return list(columns)
        else:
            raise ValueError("Columns must be a string or list of strings")

    @staticmethod
    def _generate_relation_expression(column, relation, valueExpression):
        """ Generate comparison expression

        :param column: Column to generate comparison against
        :param relation: relation to implement
        :param valueExpression: expression to compare to
        :return: relation expression as variation of Pyspark SQL columns
        """
        if relation == ">":
            return column > valueExpression
        elif relation == ">=":
            return column >= valueExpression
        elif relation == "<":
            return column < valueExpression
        elif relation == "<=":
            return column <= valueExpression
        elif relation in ["!=", "<>"]:
            return column != valueExpression
        elif relation in ["=", "=="]:
            return column == valueExpression
        else:
            raise ValueError(f"Unsupported relation type '{relation}'")

    @staticmethod
    def mkCombinedConstraintExpression(constraintExpressions):
        """ Generate a SQL expression that combines multiple constraints using AND

        :param constraintExpressions: list of Pyspark SQL Column constraint expression objects
        :return: combined constraint expression as Pyspark SQL Column object (or None if no valid expressions)
        """
        assert constraintExpressions is not None and isinstance(constraintExpressions, list), \
            "Constraints must be a list of Pyspark SQL Column instances"

        assert all(expr is None or isinstance(expr, Column) for expr in constraintExpressions), \
            "Constraint expressions must be Pyspark SQL columns or None"

        valid_constraint_expressions = [expr for expr in constraintExpressions if expr is not None]

        if len(valid_constraint_expressions) > 0:
            combined_constraint_expression = valid_constraint_expressions[0]

            for additional_constraint in valid_constraint_expressions[1:]:
                combined_constraint_expression = combined_constraint_expression & additional_constraint

            return combined_constraint_expression
        else:
            return None

    @abstractmethod
    def prepareDataGenerator(self, dataGenerator):
        """ Prepare the data generator to generate data that matches the constraint

        This method may modify the data generation rules to meet the constraint

        :param dataGenerator: Data generation object that will generate the dataframe
        :return: modified or unmodified data generator
        """
        raise NotImplementedError("Method prepareDataGenerator must be implemented in derived class")

    @abstractmethod
    def transformDataframe(self, dataGenerator, dataFrame):
        """ Transform the dataframe to make data conform to constraint if possible

        This method should not modify the dataGenerator - but may modify the dataframe

        :param dataGenerator: Data generation object that generated the dataframe
        :param dataFrame: generated dataframe
        :return: modified or unmodified Spark dataframe
        """
        raise NotImplementedError("Method transformDataframe must be implemented in derived class")

    @abstractmethod
    def _generateFilterExpression(self):
        """ Generate a Pyspark SQL expression that may be used for filtering"""
        raise NotImplementedError("Method _generateFilterExpression must be implemented in derived class")

    @property
    def supportsStreaming(self):
        """ Return True if the constraint supports streaming dataframes"""
        return self._supportsStreaming

    @property
    def filterExpression(self):
        """ Return the filter expression (as an instance of type Column that evaluates to True or non-True)

        The expression is computed lazily on first access and cached thereafter.
        """
        if not self._calculatedFilterExpression:
            self._filterExpression = self._generateFilterExpression()
            self._calculatedFilterExpression = True
        return self._filterExpression


class NoFilterMixin:
    """ Mixin class to indicate that the constraint has no filter expression

    Intended to be used in implementation of the concrete constraint classes.

    Use of the mixin class is optional but when used with the Constraint class and multiple inheritance,
    it will provide a default implementation of the `_generateFilterExpression` method that satisfies
    the abstract method requirement of the Constraint class.

    When using mixins, place the mixin class first in the list of base classes.
    """
    def _generateFilterExpression(self):
        """ Generate a Pyspark SQL expression that may be used for filtering"""
        return None


class NoPrepareTransformMixin:
    """ Mixin class to indicate that the constraint needs no preparation or dataframe transformation steps

    Intended to be used in implementation of the concrete constraint classes.

    Use of the mixin class is optional but when used with the Constraint class and multiple inheritance,
    it will provide default implementations of the `prepareDataGenerator` and `transformDataframe` methods
    that satisfy the abstract method requirements of the Constraint class.

    When using mixins, place the mixin class first in the list of base classes.
    """
    def prepareDataGenerator(self, dataGenerator):
        """ Prepare the data generator to generate data that matches the constraint

        This method may modify the data generation rules to meet the constraint

        :param dataGenerator: Data generation object that will generate the dataframe
        :return: modified or unmodified data generator
        """
        return dataGenerator

    def transformDataframe(self, dataGenerator, dataFrame):
        """ Transform the dataframe to make data conform to constraint if possible

        This method should not modify the dataGenerator - but may modify the dataframe

        The default transformation returns the dataframe unmodified

        :param dataGenerator: Data generation object that generated the dataframe
        :param dataFrame: generated dataframe
        :return: modified or unmodified Spark dataframe
        """
        return dataFrame
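The mixin pattern above makes a minimal custom constraint a one-method class. A sketch of a hypothetical `NonNullValues` constraint (not part of this commit) that rejects rows containing nulls in the named columns, assuming the module path implied by the package imports above:

import pyspark.sql.functions as F

from dbldatagen.constraints import Constraint
# NoPrepareTransformMixin lives alongside Constraint in the `constraint` module
from dbldatagen.constraints.constraint import NoPrepareTransformMixin


class NonNullValues(NoPrepareTransformMixin, Constraint):
    """Hypothetical example constraint: rejects rows where any named column is null.

    The mixin is listed first, as the mixin docstrings require, so its concrete
    `prepareDataGenerator` / `transformDataframe` methods satisfy the abstract
    method requirements of `Constraint`.
    """
    def __init__(self, columns):
        super().__init__(supportsStreaming=True)
        self._columns = self._columnsFromListOrString(columns)

    def _generateFilterExpression(self):
        # one isNotNull test per column, AND-ed together by the base class helper
        filters = [F.col(colname).isNotNull() for colname in self._columns]
        return self.mkCombinedConstraintExpression(filters)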
dbldatagen/constraints/literal_range_constraint.py
@@ -0,0 +1,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the LiteralRange class
"""
import pyspark.sql.functions as F

from .constraint import Constraint, NoPrepareTransformMixin


class LiteralRange(NoPrepareTransformMixin, Constraint):
    """ LiteralRange Constraint object - validates that column value(s) are between two literal values

    :param columns: Name of column or list of column names
    :param lowValue: Tests that columns have values greater than the low value (greater or equal if `strict` is False)
    :param highValue: Tests that columns have values less than the high value (less or equal if `strict` is False)
    :param strict: If True, excludes the low and high values from the range. Defaults to False

    Note `lowValue` and `highValue` must be values that can be converted to a literal expression using the
    `pyspark.sql.functions.lit` function
    """

    def __init__(self, columns, lowValue, highValue, strict=False):
        super().__init__(supportsStreaming=True)
        self._columns = self._columnsFromListOrString(columns)
        self._lowValue = lowValue
        self._highValue = highValue
        self._strict = strict

    def _generateFilterExpression(self):
        """ Generate a SQL filter expression that may be used for filtering"""
        expressions = [F.col(colname) for colname in self._columns]
        minValue = F.lit(self._lowValue)
        maxValue = F.lit(self._highValue)

        # build ranged comparison expressions
        if self._strict:
            filters = [(column_expr > minValue) & (column_expr < maxValue) for column_expr in expressions]
        else:
            filters = [column_expr.between(minValue, maxValue) for column_expr in expressions]

        # ... and combine them using logical `and` operation
        return self.mkCombinedConstraintExpression(filters)
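A quick sketch of the `strict` flag's effect, applying the generated filter directly (again assuming an existing SparkSession `spark`):

from dbldatagen.constraints import LiteralRange

df = spark.range(0, 11).withColumnRenamed("id", "score")  # scores 0..10

inclusive = LiteralRange(["score"], 0, 10)               # uses between(): 0 <= score <= 10
exclusive = LiteralRange(["score"], 0, 10, strict=True)  # (score > 0) & (score < 10)

df.where(inclusive.filterExpression).count()  # 11: endpoints kept
df.where(exclusive.filterExpression).count()  # 9: endpoints 0 and 10 rejected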
dbldatagen/constraints/literal_relation_constraint.py
@@ -0,0 +1,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the LiteralRelation class
"""
import pyspark.sql.functions as F

from .constraint import Constraint, NoPrepareTransformMixin


class LiteralRelation(NoPrepareTransformMixin, Constraint):
    """LiteralRelation constraint

    Constrains one or more columns so that each column has a relationship to a constant value

    :param columns: column name or list of column names
    :param relation: operator to check - should be one of <, <=, >=, >, !=, ==, =, <>
    :param value: a literal value to compare against
    """

    def __init__(self, columns, relation, value):
        super().__init__(supportsStreaming=True)
        self._columns = self._columnsFromListOrString(columns)
        self._relation = relation
        self._value = value

        if relation not in self.SUPPORTED_OPERATORS:
            raise ValueError(f"Parameter `relation` should be one of the operators: {self.SUPPORTED_OPERATORS}")

    def _generateFilterExpression(self):
        expressions = [F.col(colname) for colname in self._columns]
        literalValue = F.lit(self._value)
        filters = [self._generate_relation_expression(col, self._relation, literalValue) for col in expressions]

        return self.mkCombinedConstraintExpression(filters)
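And a matching sketch for `LiteralRelation`, which applies the same relation check on each listed column against a single literal (assuming an existing SparkSession `spark`):

from dbldatagen.constraints import LiteralRelation

df = spark.createDataFrame([(50, 70), (90, 40), (95, 99)], ["math", "physics"])

# keep rows where *both* columns are >= 60; the per-column tests are AND-ed together
passing = LiteralRelation(["math", "physics"], ">=", 60)

df.where(passing.filterExpression).show()
# only (95, 99) qualifies; each of the other rows fails on one of the two columns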