Skip to content

Commit

Permalink
Python: Refactor expression hierarchy (apache#5389)
Browse files Browse the repository at this point in the history
* Convert And, Or, and Not to dataclasses.

* Refactor base expression types.
  • Loading branch information
rdblue authored Aug 1, 2022
1 parent aaa67d0 commit 5f2ce6e
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 96 deletions.
139 changes: 49 additions & 90 deletions python/pyiceberg/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,50 +32,43 @@


class BooleanExpression(ABC):
"""Represents a boolean expression tree."""
"""An expression that evaluates to a boolean"""

@abstractmethod
def __invert__(self) -> BooleanExpression:
"""Transform the Expression into its negated version."""


class Bound(Generic[T], ABC):
"""Represents a bound value expression."""
class Term(Generic[T], ABC):
"""A simple expression that evaluates to a value"""

def eval(self, struct: StructProtocol): # pylint: disable=W0613
... # pragma: no cover

class Bound(ABC):
"""Represents a bound value expression"""


class Unbound(Generic[T, B], ABC):
"""Represents an unbound expression node."""
class Unbound(Generic[B], ABC):
"""Represents an unbound value expression"""

@abstractmethod
def bind(self, schema: Schema, case_sensitive: bool) -> B:
def bind(self, schema: Schema, case_sensitive: bool = True) -> B:
... # pragma: no cover


class Term(ABC):
"""An expression that evaluates to a value."""


class BaseReference(Generic[T], Term, ABC):
"""Represents a variable reference in an expression."""


class BoundTerm(Bound[T], Term):
"""Represents a bound term."""
class BoundTerm(Term[T], Bound, ABC):
"""Represents a bound term"""

@abstractmethod
def ref(self) -> BoundReference[T]:
...


class UnboundTerm(Unbound[T, BoundTerm[T]], Term):
"""Represents an unbound term."""
@abstractmethod
def eval(self, struct: StructProtocol): # pylint: disable=W0613
... # pragma: no cover


@dataclass(frozen=True)
class BoundReference(BoundTerm[T], BaseReference[T]):
class BoundReference(BoundTerm[T]):
"""A reference bound to a field in a schema
Args:
Expand All @@ -88,6 +81,7 @@ class BoundReference(BoundTerm[T], BaseReference[T]):

def eval(self, struct: StructProtocol) -> T:
"""Returns the value at the referenced field's position in an object that abides by the StructProtocol
Args:
struct (StructProtocol): A row object that abides by the StructProtocol and returns values given a position
Returns:
Expand All @@ -99,8 +93,12 @@ def ref(self) -> BoundReference[T]:
return self


class UnboundTerm(Term[T], Unbound[BoundTerm[T]], ABC):
"""Represents an unbound term."""


@dataclass(frozen=True)
class Reference(UnboundTerm[T], BaseReference[T]):
class Reference(UnboundTerm[T]):
"""A reference not yet bound to a field in a schema
Args:
Expand All @@ -112,7 +110,7 @@ class Reference(UnboundTerm[T], BaseReference[T]):

name: str

def bind(self, schema: Schema, case_sensitive: bool) -> BoundReference[T]:
def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundReference[T]:
"""Bind the reference to an Iceberg schema
Args:
Expand All @@ -125,22 +123,24 @@ def bind(self, schema: Schema, case_sensitive: bool) -> BoundReference[T]:
Returns:
BoundReference: A reference bound to the specific field in the Iceberg schema
"""
field = schema.find_field(name_or_id=self.name, case_sensitive=case_sensitive) # pylint: disable=redefined-outer-name

field = schema.find_field(name_or_id=self.name, case_sensitive=case_sensitive)
if not field:
raise ValueError(f"Cannot find field '{self.name}' in schema: {schema}")

accessor = schema.accessor_for_field(field.field_id)

if not accessor:
raise ValueError(f"Cannot find accessor for field '{self.name}' in schema: {schema}")

return BoundReference(field=field, accessor=accessor)


@dataclass(frozen=True, init=False)
class And(BooleanExpression):
"""AND operation expression - logical conjunction"""

left: BooleanExpression
right: BooleanExpression

def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression):
if rest:
return reduce(And, (left, right, *rest))
Expand All @@ -150,35 +150,23 @@ def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: Boole
return right
elif right is AlwaysTrue():
return left
self = super().__new__(cls)
self._left = left # type: ignore
self._right = right # type: ignore
return self

@property
def left(self) -> BooleanExpression:
return self._left # type: ignore

@property
def right(self) -> BooleanExpression:
return self._right # type: ignore

def __eq__(self, other) -> bool:
return id(self) == id(other) or (isinstance(other, And) and self.left == other.left and self.right == other.right)
else:
result = super().__new__(cls)
object.__setattr__(result, "left", left)
object.__setattr__(result, "right", right)
return result

def __invert__(self) -> Or:
return Or(~self.left, ~self.right)

def __repr__(self) -> str:
return f"And({repr(self.left)}, {repr(self.right)})"

def __str__(self) -> str:
return f"And({str(self.left)}, {str(self.right)})"


@dataclass(frozen=True, init=False)
class Or(BooleanExpression):
"""OR operation expression - logical disjunction"""

left: BooleanExpression
right: BooleanExpression

def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression):
if rest:
return reduce(Or, (left, right, *rest))
Expand All @@ -188,59 +176,36 @@ def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: Boole
return right
elif right is AlwaysFalse():
return left
self = super().__new__(cls)
self._left = left # type: ignore
self._right = right # type: ignore
return self

@property
def left(self) -> BooleanExpression:
return self._left # type: ignore

@property
def right(self) -> BooleanExpression:
return self._right # type: ignore

def __eq__(self, other) -> bool:
return id(self) == id(other) or (isinstance(other, Or) and self.left == other.left and self.right == other.right)
else:
result = super().__new__(cls)
object.__setattr__(result, "left", left)
object.__setattr__(result, "right", right)
return result

def __invert__(self) -> And:
return And(~self.left, ~self.right)

def __repr__(self) -> str:
return f"Or({repr(self.left)}, {repr(self.right)})"

def __str__(self) -> str:
return f"Or({str(self.left)}, {str(self.right)})"


@dataclass(frozen=True, init=False)
class Not(BooleanExpression):
"""NOT operation expression - logical negation"""

child: BooleanExpression

def __new__(cls, child: BooleanExpression):
if child is AlwaysTrue():
return AlwaysFalse()
elif child is AlwaysFalse():
return AlwaysTrue()
elif isinstance(child, Not):
return child.child
return super().__new__(cls)

def __init__(self, child):
self.child = child

def __eq__(self, other) -> bool:
return id(self) == id(other) or (isinstance(other, Not) and self.child == other.child)
result = super().__new__(cls)
object.__setattr__(result, "child", child)
return result

def __invert__(self) -> BooleanExpression:
return self.child

def __repr__(self) -> str:
return f"Not({repr(self.child)})"

def __str__(self) -> str:
return f"Not({str(self.child)})"


@dataclass(frozen=True)
class AlwaysTrue(BooleanExpression, Singleton):
Expand All @@ -259,15 +224,15 @@ def __invert__(self) -> AlwaysTrue:


@dataclass(frozen=True)
class BoundPredicate(Bound[T], BooleanExpression):
class BoundPredicate(Generic[T], Bound, BooleanExpression):
term: BoundTerm[T]

def __invert__(self) -> BoundPredicate[T]:
raise NotImplementedError


@dataclass(frozen=True)
class UnboundPredicate(Unbound[T, BooleanExpression], BooleanExpression):
class UnboundPredicate(Generic[T], Unbound[BooleanExpression], BooleanExpression):
as_bound: ClassVar[type]
term: UnboundTerm[T]

Expand Down Expand Up @@ -661,12 +626,6 @@ def _(obj: And, visitor: BooleanExpressionVisitor[T]) -> T:
return visitor.visit_and(left_result=left_result, right_result=right_result)


@visit.register(In)
def _(obj: In, visitor: BooleanExpressionVisitor[T]) -> T:
"""Visit an In boolean expression with a concrete BooleanExpressionVisitor"""
return visitor.visit_unbound_predicate(predicate=obj)


@visit.register(UnboundPredicate)
def _(obj: UnboundPredicate, visitor: BooleanExpressionVisitor[T]) -> T:
"""Visit an In boolean expression with a concrete BooleanExpressionVisitor"""
Expand Down
12 changes: 6 additions & 6 deletions python/tests/expressions/test_expressions_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def _(obj: ExpressionB, visitor: BooleanExpressionVisitor) -> List:
[
(
base.And(ExpressionA(), ExpressionB()),
"And(ExpressionA(), ExpressionB())",
"And(left=ExpressionA(), right=ExpressionB())",
),
(
base.Or(ExpressionA(), ExpressionB()),
"Or(ExpressionA(), ExpressionB())",
"Or(left=ExpressionA(), right=ExpressionB())",
),
(base.Not(ExpressionA()), "Not(ExpressionA())"),
(base.Not(ExpressionA()), "Not(child=ExpressionA())"),
],
)
def test_reprs(op, rep):
Expand Down Expand Up @@ -208,9 +208,9 @@ def test_notnan_bind_nonfloat():
@pytest.mark.parametrize(
"op, string",
[
(base.And(ExpressionA(), ExpressionB()), "And(testexpra, testexprb)"),
(base.Or(ExpressionA(), ExpressionB()), "Or(testexpra, testexprb)"),
(base.Not(ExpressionA()), "Not(testexpra)"),
(base.And(ExpressionA(), ExpressionB()), "And(left=ExpressionA(), right=ExpressionB())"),
(base.Or(ExpressionA(), ExpressionB()), "Or(left=ExpressionA(), right=ExpressionB())"),
(base.Not(ExpressionA()), "Not(child=ExpressionA())"),
],
)
def test_strs(op, string):
Expand Down

0 comments on commit 5f2ce6e

Please sign in to comment.