Skip to content

Commit

Permalink
Python: Add expression evaluator (apache#6127)
Browse files Browse the repository at this point in the history
  • Loading branch information
rdblue authored Nov 20, 2022
1 parent 7b5c64c commit 194b45f
Show file tree
Hide file tree
Showing 8 changed files with 284 additions and 30 deletions.
2 changes: 1 addition & 1 deletion python/pyiceberg/avro/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
from uuid import UUID

from pyiceberg.avro.decoder import BinaryDecoder
from pyiceberg.files import StructProtocol
from pyiceberg.schema import Schema, SchemaVisitor
from pyiceberg.typedef import StructProtocol
from pyiceberg.types import (
BinaryType,
BooleanType,
Expand Down
9 changes: 6 additions & 3 deletions python/pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from functools import reduce
from functools import cached_property, reduce
from typing import (
Any,
Generic,
Expand All @@ -29,9 +29,8 @@
)

from pyiceberg.expressions.literals import Literal, literal
from pyiceberg.files import StructProtocol
from pyiceberg.schema import Accessor, Schema
from pyiceberg.typedef import L
from pyiceberg.typedef import L, StructProtocol
from pyiceberg.types import DoubleType, FloatType, NestedField
from pyiceberg.utils.singleton import Singleton

Expand Down Expand Up @@ -459,6 +458,10 @@ def __init__(self, term: BoundTerm[L], literals: Set[Literal[L]]):
super().__init__(term) # type: ignore
self.literals = _to_literal_set(literals) # pylint: disable=W0621

@cached_property
def value_set(self) -> Set[L]:
    """Return the raw values wrapped by this predicate's literals, as a set.

    The set is built on first access and then cached on the instance by
    ``cached_property``.
    """
    return {wrapped.value for wrapped in self.literals}

def __str__(self):
    """Render the predicate as ``ClassName(term, {literal, ...})``."""
    # Sort to make it deterministic
    rendered_literals = ", ".join(sorted(str(item) for item in self.literals))
    return f"{self.__class__.__name__!s}({self.term!s}, {{{rendered_literals}}})"
Expand Down
88 changes: 80 additions & 8 deletions python/pyiceberg/expressions/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pyiceberg.manifest import ManifestFile, PartitionFieldSummary
from pyiceberg.schema import Schema
from pyiceberg.table import PartitionSpec
from pyiceberg.typedef import StructProtocol
from pyiceberg.types import (
DoubleType,
FloatType,
Expand Down Expand Up @@ -240,11 +241,11 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi

class BoundBooleanExpressionVisitor(BooleanExpressionVisitor[T], ABC):
@abstractmethod
def visit_in(self, term: BoundTerm[L], literals: Set[Literal[L]]) -> T:
def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> T:
"""Visit a bound In predicate"""

@abstractmethod
def visit_not_in(self, term: BoundTerm[L], literals: Set[Literal[L]]) -> T:
def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> T:
"""Visit a bound NotIn predicate"""

@abstractmethod
Expand Down Expand Up @@ -331,12 +332,12 @@ def visit_bound_predicate(expr: BoundPredicate[L], _: BooleanExpressionVisitor[T

@visit_bound_predicate.register(BoundIn)
def _(expr: BoundIn[L], visitor: BoundBooleanExpressionVisitor[T]) -> T:
return visitor.visit_in(term=expr.term, literals=expr.literals)
return visitor.visit_in(term=expr.term, literals=expr.value_set)


@visit_bound_predicate.register(BoundNotIn)
def _(expr: BoundNotIn[L], visitor: BoundBooleanExpressionVisitor[T]) -> T:
return visitor.visit_not_in(term=expr.term, literals=expr.literals)
return visitor.visit_not_in(term=expr.term, literals=expr.value_set)


@visit_bound_predicate.register(BoundIsNaN)
Expand Down Expand Up @@ -419,6 +420,77 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
return predicate


def expression_evaluator(
    schema: Schema, unbound: BooleanExpression, case_sensitive: bool = True
) -> Callable[[StructProtocol], bool]:
    """Build a callable that evaluates an expression against a struct.

    Args:
        schema: Schema the expression is bound against.
        unbound: The boolean expression to evaluate.
        case_sensitive: Forwarded to the binding step; defaults to True.

    Returns:
        The ``eval`` method of a newly created ``_ExpressionEvaluator``. It takes a
        struct and returns whether the expression holds for it.
    """
    evaluator = _ExpressionEvaluator(schema, unbound, case_sensitive)
    return evaluator.eval


class _ExpressionEvaluator(BoundBooleanExpressionVisitor[bool]):
    """Evaluates a boolean expression against one struct at a time.

    The expression is bound to the schema once, at construction. Each call to
    ``eval`` then visits the bound expression using the values of the given struct.
    """

    bound: BooleanExpression
    # Struct currently being evaluated; assigned by eval() before the visit starts.
    struct: StructProtocol

    def __init__(self, schema: Schema, unbound: BooleanExpression, case_sensitive: bool = True):
        """Bind ``unbound`` to ``schema`` and keep the bound expression for later evaluation."""
        self.bound = bind(schema, unbound, case_sensitive)

    def eval(self, struct: StructProtocol) -> bool:
        """Return whether the bound expression holds for ``struct``."""
        # The visit_* callbacks read the struct from the instance, so it must be
        # stored before the visit begins.
        self.struct = struct
        return visit(self.bound, self)

    def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool:
        """Return True when the term's value is one of ``literals``."""
        return term.eval(self.struct) in literals

    def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool:
        """Return True when the term's value is not one of ``literals``."""
        return term.eval(self.struct) not in literals

    def visit_is_nan(self, term: BoundTerm[L]) -> bool:
        """Return True when the term's value is NaN."""
        val = term.eval(self.struct)
        # NaN is the value that does not compare equal to itself.
        return val != val

    def visit_not_nan(self, term: BoundTerm[L]) -> bool:
        """Return True when the term's value is not NaN."""
        val = term.eval(self.struct)
        return val == val

    def visit_is_null(self, term: BoundTerm[L]) -> bool:
        """Return True when the term's value is None."""
        return term.eval(self.struct) is None

    def visit_not_null(self, term: BoundTerm[L]) -> bool:
        """Return True when the term's value is not None."""
        return term.eval(self.struct) is not None

    def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        """Return True when the term's value equals the literal's value."""
        return term.eval(self.struct) == literal.value

    def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        """Return True when the term's value differs from the literal's value."""
        return term.eval(self.struct) != literal.value

    def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        """Return True when the term's value is non-null and >= the literal's value."""
        value = term.eval(self.struct)
        # Ordering None against a value raises TypeError; a null never matches.
        return value is not None and value >= literal.value

    def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        """Return True when the term's value is non-null and > the literal's value."""
        value = term.eval(self.struct)
        # Ordering None against a value raises TypeError; a null never matches.
        return value is not None and value > literal.value

    def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        """Return True when the term's value is non-null and < the literal's value."""
        value = term.eval(self.struct)
        # Ordering None against a value raises TypeError; a null never matches.
        return value is not None and value < literal.value

    def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        """Return True when the term's value is non-null and <= the literal's value."""
        value = term.eval(self.struct)
        # Ordering None against a value raises TypeError; a null never matches.
        return value is not None and value <= literal.value

    def visit_true(self) -> bool:
        """Always-true expression."""
        return True

    def visit_false(self) -> bool:
        """Always-false expression."""
        return False

    def visit_not(self, child_result: bool) -> bool:
        """Negate the already-evaluated child result."""
        return not child_result

    def visit_and(self, left_result: bool, right_result: bool) -> bool:
        """Combine two already-evaluated child results with logical AND."""
        return left_result and right_result

    def visit_or(self, left_result: bool, right_result: bool) -> bool:
        """Combine two already-evaluated child results with logical OR."""
        return left_result or right_result


ROWS_MIGHT_MATCH = True
ROWS_CANNOT_MATCH = False
IN_PREDICATE_LIMIT = 200
Expand All @@ -445,7 +517,7 @@ def eval(self, manifest: ManifestFile) -> bool:
# No partition information
return ROWS_MIGHT_MATCH

def visit_in(self, term: BoundTerm[L], literals: Set[Literal[L]]) -> bool:
def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool:
pos = term.ref().accessor.position
field = self.partition_fields[pos]

Expand All @@ -457,17 +529,17 @@ def visit_in(self, term: BoundTerm[L], literals: Set[Literal[L]]) -> bool:

lower = _from_byte_buffer(term.ref().field.field_type, field.lower_bound)

if all(lower > val.value for val in literals):
if all(lower > val for val in literals):
return ROWS_CANNOT_MATCH

if field.upper_bound is not None:
upper = _from_byte_buffer(term.ref().field.field_type, field.upper_bound)
if all(upper < val.value for val in literals):
if all(upper < val for val in literals):
return ROWS_CANNOT_MATCH

return ROWS_MIGHT_MATCH

def visit_not_in(self, term: BoundTerm[L], literals: Set[Literal[L]]) -> bool:
def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool:
# because the bounds are not necessarily a min or max value, this cannot be answered using
# them. notIn(col, {X, ...}) with (X, Y) doesn't guarantee that X is a value in col.
return ROWS_MIGHT_MATCH
Expand Down
15 changes: 0 additions & 15 deletions python/pyiceberg/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from abc import abstractmethod
from enum import Enum, auto
from typing import Any, Protocol, runtime_checkable


class FileContentType(Enum):
Expand All @@ -34,16 +32,3 @@ class FileFormat(Enum):
PARQUET = auto()
AVRO = auto()
METADATA = auto()


@runtime_checkable
class StructProtocol(Protocol):  # pragma: no cover
    """A generic protocol used by accessors to get and set at positions of an object"""

    @abstractmethod
    def get(self, pos: int) -> Any:
        """Return the value held at position ``pos``."""

    @abstractmethod
    def set(self, pos: int, value: Any) -> None:
        """Store ``value`` at position ``pos``."""
2 changes: 1 addition & 1 deletion python/pyiceberg/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from pydantic import Field, PrivateAttr

from pyiceberg.files import StructProtocol
from pyiceberg.typedef import StructProtocol
from pyiceberg.types import (
IcebergType,
ListType,
Expand Down
16 changes: 16 additions & 0 deletions python/pyiceberg/typedef.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from abc import abstractmethod
from decimal import Decimal
from typing import (
Any,
Dict,
Protocol,
Tuple,
TypeVar,
Union,
runtime_checkable,
)
from uuid import UUID

Expand All @@ -41,3 +44,16 @@ def update(self, *args: Any, **kwargs: Any) -> None:

# Represents the literal value
L = TypeVar("L", str, bool, int, float, bytes, UUID, Decimal, covariant=True)


@runtime_checkable
class StructProtocol(Protocol):  # pragma: no cover
    """A generic protocol used by accessors to get and set at positions of an object"""

    @abstractmethod
    def get(self, pos: int) -> Any:
        """Return the value held at position ``pos``."""

    @abstractmethod
    def set(self, pos: int, value: Any) -> None:
        """Store ``value`` at position ``pos``."""
Loading

0 comments on commit 194b45f

Please sign in to comment.