Commit

Feature constraints (databrickslabs#257)
* feature for constraints

* lint related updates

* work in progress

* updated build actions

* Update CHANGELOG.md

ronanstokes-db authored May 28, 2024
1 parent b28602d commit 82ce5ce
Showing 20 changed files with 1,125 additions and 8 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -5,9 +5,13 @@ All notable changes to the Databricks Labs Data Generator will be documented in

### Unreleased

-#### Changed
+### Changed
+* Modified data generator to allow specification of constraints to the data generation process
 * Updated documentation for generating text data.
+
+### Added
+* Added classes for constraints on the data generation via new package `dbldatagen.constraints`


### Version 0.3.6 Post 1

40 changes: 40 additions & 0 deletions dbldatagen/constraints/__init__.py
@@ -0,0 +1,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This package defines the constraint classes for the `dbldatagen` library.
The constraint classes define predefined constraints that may be used to constrain the generated data.
Constraining the generated data is implemented in several ways:
- Rejecting rows that do not meet the criteria
- Modifying the generated data to meet the constraint (including modifying the data generation parameters)
Some constraints may be implemented using a combination of the above approaches.
For implementations using the rejection approach, the data generation process may produce fewer than the
requested number of rows.
In the current implementation, most constraint strategies use rejection-based criteria.
"""

from .chained_relation import ChainedRelation
from .constraint import Constraint
from .literal_range_constraint import LiteralRange
from .literal_relation_constraint import LiteralRelation
from .negative_values import NegativeValues
from .positive_values import PositiveValues
from .ranged_values_constraint import RangedValues
from .sql_expr import SqlExpr
from .unique_combinations import UniqueCombinations

__all__ = ["chained_relation",
"constraint",
"negative_values",
"literal_range_constraint",
"literal_relation_constraint",
"positive_values",
"ranged_values_constraint",
"unique_combinations"]
59 changes: 59 additions & 0 deletions dbldatagen/constraints/chained_relation.py
@@ -0,0 +1,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the ChainedRelation class
"""
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
from .constraint import Constraint, NoPrepareTransformMixin


class ChainedRelation(NoPrepareTransformMixin, Constraint):
"""ChainedRelation constraint
Constrains one or more columns so that each column has a relationship to the next.
For example, if the constraint is defined as `ChainedRelation(['a', 'b', 'c'], "<")`, then only rows that
satisfy the condition `a < b < c` will be included in the output
(where `a`, `b` and `c` represent the data values for the rows).
This can be used to model time-related transactions (for example, in retail, where the purchaseDate,
shippingDate and returnDate all have a specific relationship).
Relations supported include `<`, `<=`, `>=`, `>`, `!=`, `==`
:param columns: column name or list of column names as string or list of strings
:param relation: operator to check - should be one of `<`, `>`, `=`, `>=`, `<=`, `==`, `!=`
"""
def __init__(self, columns, relation):
super().__init__(supportsStreaming=True)
self._relation = relation
self._columns = self._columnsFromListOrString(columns)

if relation not in self.SUPPORTED_OPERATORS:
raise ValueError(f"Parameter `relation` should be one of the operators :{self.SUPPORTED_OPERATORS}")

if not isinstance(self._columns, list) or len(self._columns) <= 1:
raise ValueError("ChainedRelation constraints must be defined across more than one column")

def _generateFilterExpression(self):
""" Generated composite filter expression for chained set of filter expressions
I.e if columns is ['a', 'b', 'c'] and relation is '<'
create set of filters [ col('a') < col('b'), col('b') < col('c')]
and combine them as single expression using logical and operation
:return: filter expression for chained expressions
"""
expressions = [F.col(colname) for colname in self._columns]

filters = []
# build set of filters for chained expressions
for ix in range(1, len(expressions)):
filters.append(self._generate_relation_expression(expressions[ix - 1], self._relation, expressions[ix]))

# ... and combine them using logical `and` operation
return self.mkCombinedConstraintExpression(filters)
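A short sketch of how the chained relation expands, assuming an existing DataFrame `df` with the three date columns; the constraint below produces the filter `(purchaseDate < shippingDate) & (shippingDate < returnDate)`:

    from dbldatagen.constraints import ChainedRelation

    constraint = ChainedRelation(["purchaseDate", "shippingDate", "returnDate"], "<")

    # filterExpression is a Pyspark Column; rows violating the chain are rejected
    df_valid = df.where(constraint.filterExpression)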
186 changes: 186 additions & 0 deletions dbldatagen/constraints/constraint.py
@@ -0,0 +1,186 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the Constraint class
"""
import types
from abc import ABC, abstractmethod
from pyspark.sql import Column


class Constraint(ABC):
""" Constraint object - base class for predefined and custom constraints
This class is meant for internal use only.
"""
SUPPORTED_OPERATORS = ["<", ">", ">=", "!=", "==", "=", "<=", "<>"]

def __init__(self, supportsStreaming=False):
"""
Initialize the constraint object
"""
self._filterExpression = None
self._calculatedFilterExpression = False
self._supportsStreaming = supportsStreaming

@staticmethod
def _columnsFromListOrString(columns):
""" Get columns as list of columns from string of list-like
:param columns: string or list of strings representing column names
"""
if isinstance(columns, str):
return [columns]
elif isinstance(columns, (list, set, tuple, types.GeneratorType)):
return list(columns)
else:
raise ValueError("Columns must be a string or list of strings")

@staticmethod
def _generate_relation_expression(column, relation, valueExpression):
""" Generate comparison expression
:param column: Column to generate comparison against
:param relation: relation to implement
:param valueExpression: expression to compare to
:return: relation expression as variation of Pyspark SQL columns
"""
if relation == ">":
return column > valueExpression
elif relation == ">=":
return column >= valueExpression
elif relation == "<":
return column < valueExpression
elif relation == "<=":
return column <= valueExpression
elif relation in ["!=", "<>"]:
return column != valueExpression
elif relation in ["=", "=="]:
return column == valueExpression
else:
raise ValueError(f"Unsupported relation type '{relation}")

@staticmethod
def mkCombinedConstraintExpression(constraintExpressions):
""" Generate a SQL expression that combines multiple constraints using AND
:param constraintExpressions: list of Pyspark SQL Column constraint expression objects
:return: combined constraint expression as Pyspark SQL Column object (or None if no valid expressions)
"""
assert constraintExpressions is not None and isinstance(constraintExpressions, list), \
"Constraints must be a list of Pyspark SQL Column instances"

assert all(expr is None or isinstance(expr, Column) for expr in constraintExpressions), \
"Constraint expressions must be Pyspark SQL columns or None"

valid_constraint_expressions = [expr for expr in constraintExpressions if expr is not None]

if len(valid_constraint_expressions) > 0:
combined_constraint_expression = valid_constraint_expressions[0]

for additional_constraint in valid_constraint_expressions[1:]:
combined_constraint_expression = combined_constraint_expression & additional_constraint

return combined_constraint_expression
else:
return None

@abstractmethod
def prepareDataGenerator(self, dataGenerator):
""" Prepare the data generator to generate data that matches the constraint
This method may modify the data generation rules to meet the constraint
:param dataGenerator: Data generation object that will generate the dataframe
:return: modified or unmodified data generator
"""
raise NotImplementedError("Method prepareDataGenerator must be implemented in derived class")

@abstractmethod
def transformDataframe(self, dataGenerator, dataFrame):
""" Transform the dataframe to make data conform to constraint if possible
This method should not modify the dataGenerator - but may modify the dataframe
:param dataGenerator: Data generation object that generated the dataframe
:param dataFrame: generated dataframe
:return: modified or unmodified Spark dataframe
The default transformation returns the dataframe unmodified
"""
raise NotImplementedError("Method transformDataframe must be implemented in derived class")

@abstractmethod
def _generateFilterExpression(self):
""" Generate a Pyspark SQL expression that may be used for filtering"""
raise NotImplementedError("Method _generateFilterExpression must be implemented in derived class")

@property
def supportsStreaming(self):
""" Return True if the constraint supports streaming dataframes"""
return self._supportsStreaming

@property
def filterExpression(self):
""" Return the filter expression (as instance of type Column that evaluates to True or non-True)"""
if not self._calculatedFilterExpression:
self._filterExpression = self._generateFilterExpression()
self._calculatedFilterExpression = True
return self._filterExpression


class NoFilterMixin:
""" Mixin class to indicate that constraint has no filter expression
Intended to be used in implementation of the concrete constraint classes.
Use of the mixin class is optional but when used with the Constraint class and multiple inheritance,
it will provide a default implementation of the _generateFilterExpression method that satisfies
the abstract method requirement of the Constraint class.
When using mixins, place the mixin class first in the list of base classes.
"""
def _generateFilterExpression(self):
""" Generate a Pyspark SQL expression that may be used for filtering"""
return None


class NoPrepareTransformMixin:
""" Mixin class to indicate that constraint has no filter expression
Intended to be used in implementation of the concrete constraint classes.
Use of the mixin class is optional but when used with the Constraint class and multiple inheritance,
it will provide a default implementation of the `prepareDataGenerator` and `transformDataframe` methods
that satisfies the abstract method requirements of the Constraint class.
When using mixins, place the mixin class first in the list of base classes.
"""
def prepareDataGenerator(self, dataGenerator):
""" Prepare the data generator to generate data that matches the constraint
This method may modify the data generation rules to meet the constraint
:param dataGenerator: Data generation object that will generate the dataframe
:return: modified or unmodified data generator
"""
return dataGenerator

def transformDataframe(self, dataGenerator, dataFrame):
""" Transform the dataframe to make data conform to constraint if possible
This method should not modify the dataGenerator - but may modify the dataframe
:param dataGenerator: Data generation object that generated the dataframe
:param dataFrame: generated dataframe
:return: modified or unmodified Spark dataframe
The default transformation returns the dataframe unmodified
"""
return dataFrame
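To illustrate the intended extension pattern, here is a sketch of a custom rejection-only constraint built from the base class and mixin, with the mixin placed first in the base-class list as the docstrings advise; `NonNullValues` is a hypothetical example, not part of the package:

    import pyspark.sql.functions as F

    from dbldatagen.constraints import Constraint
    from dbldatagen.constraints.constraint import NoPrepareTransformMixin

    class NonNullValues(NoPrepareTransformMixin, Constraint):
        """Hypothetical constraint that rejects rows where any named column is null"""

        def __init__(self, columns):
            super().__init__(supportsStreaming=True)
            self._columns = self._columnsFromListOrString(columns)

        def _generateFilterExpression(self):
            # one filter per column, AND-ed together by the base class helper
            filters = [F.col(name).isNotNull() for name in self._columns]
            return self.mkCombinedConstraintExpression(filters)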
45 changes: 45 additions & 0 deletions dbldatagen/constraints/literal_range_constraint.py
@@ -0,0 +1,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the LiteralRange class
"""
import pyspark.sql.functions as F

from .constraint import Constraint, NoPrepareTransformMixin


class LiteralRange(NoPrepareTransformMixin, Constraint):
""" LiteralRange Constraint object - validates that column value(s) are between 2 literal values
:param columns: Name of column or list of column names
:param lowValue: Tests that columns have values greater than low value (greater or equal if `strict` is False)
:param highValue: Tests that columns have values less than high value (less or equal if `strict` is False)
:param strict: If True, excludes low and high values from range. Defaults to False
Note `lowValue` and `highValue` must be values that can be converted to a literal expression using the
`pyspark.sql.functions.lit` function
"""

def __init__(self, columns, lowValue, highValue, strict=False):
super().__init__(supportsStreaming=True)
self._columns = self._columnsFromListOrString(columns)
self._lowValue = lowValue
self._highValue = highValue
self._strict = strict

def _generateFilterExpression(self):
""" Generate a SQL filter expression that may be used for filtering"""
expressions = [F.col(colname) for colname in self._columns]
minValue = F.lit(self._lowValue)
maxValue = F.lit(self._highValue)

# build ranged comparison expressions
if self._strict:
filters = [(column_expr > minValue) & (column_expr < maxValue) for column_expr in expressions]
else:
filters = [column_expr.between(minValue, maxValue) for column_expr in expressions]

# ... and combine them using logical `and` operation
return self.mkCombinedConstraintExpression(filters)
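A short sketch contrasting the strict and non-strict forms, assuming a DataFrame `df` with a numeric `age` column:

    from dbldatagen.constraints import LiteralRange

    inclusive = LiteralRange("age", 18, 65)                # keeps rows where 18 <= age <= 65
    exclusive = LiteralRange("age", 18, 65, strict=True)   # keeps rows where 18 <  age <  65

    df_inclusive = df.where(inclusive.filterExpression)
    df_exclusive = df.where(exclusive.filterExpression)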
37 changes: 37 additions & 0 deletions dbldatagen/constraints/literal_relation_constraint.py
@@ -0,0 +1,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module defines the LiteralRelation class
"""
import pyspark.sql.functions as F

from .constraint import Constraint, NoPrepareTransformMixin


class LiteralRelation(NoPrepareTransformMixin, Constraint):
"""LiteralRelation constraint
Constrains one or more columns so that each column has the specified relationship to a constant value
:param columns: column name or list of column names
:param relation: operator to check - should be one of `<`, `>`, `=`, `>=`, `<=`, `==`, `!=`
:param value: A literal value to compare against
"""

def __init__(self, columns, relation, value):
super().__init__(supportsStreaming=True)
self._columns = self._columnsFromListOrString(columns)
self._relation = relation
self._value = value

if relation not in self.SUPPORTED_OPERATORS:
raise ValueError(f"Parameter `relation` should be one of the operators :{self.SUPPORTED_OPERATORS}")

def _generateFilterExpression(self):
expressions = [F.col(colname) for colname in self._columns]
literalValue = F.lit(self._value)
filters = [self._generate_relation_expression(col, self._relation, literalValue) for col in expressions]

return self.mkCombinedConstraintExpression(filters)
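When several columns are supplied, the relation must hold for every column, since the per-column filters are combined with logical AND. A sketch, assuming a DataFrame `df` with the two quantity columns:

    from dbldatagen.constraints import LiteralRelation

    constraint = LiteralRelation(["qty_ordered", "qty_shipped"], ">=", 0)

    # keeps rows where qty_ordered >= 0 AND qty_shipped >= 0
    df_valid = df.where(constraint.filterExpression)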
