Python: Add positional deletes (apache#6775)
Fokko authored Jun 20, 2023
1 parent 717b3d7 commit 9ffb762
Showing 8 changed files with 704 additions and 199 deletions.
281 changes: 142 additions & 139 deletions python/poetry.lock

Large diffs are not rendered by default.

111 changes: 102 additions & 9 deletions python/pyiceberg/io/pyarrow.py
@@ -28,12 +28,14 @@
 import os
 from abc import ABC, abstractmethod
 from functools import lru_cache, singledispatch
+from itertools import chain
 from multiprocessing.pool import ThreadPool
 from multiprocessing.sharedctypes import Synchronized
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
+    Dict,
     Generic,
     Iterable,
     List,
@@ -46,9 +48,11 @@
 )
 from urllib.parse import urlparse
 
+import numpy as np
 import pyarrow as pa
 import pyarrow.compute as pc
 import pyarrow.dataset as ds
+from pyarrow import ChunkedArray
 from pyarrow.fs import (
     FileInfo,
     FileSystem,
@@ -85,6 +89,7 @@
     OutputFile,
     OutputStream,
 )
+from pyiceberg.manifest import DataFile, FileFormat
 from pyiceberg.schema import (
     PartnerAccessor,
     Schema,
@@ -498,6 +503,39 @@ def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
     return boolean_expression_visit(expr, _ConvertToArrowExpression())
 
 
+@lru_cache
+def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.FileFormat:
+    if file_format == FileFormat.PARQUET:
+        return ds.ParquetFileFormat(**kwargs)
+    else:
+        raise ValueError(f"Unsupported file format: {file_format}")
+
+
+def _construct_fragment(fs: FileSystem, data_file: DataFile, file_format_kwargs: Dict[str, Any] = EMPTY_DICT) -> ds.Fragment:
+    _, path = PyArrowFileIO.parse_location(data_file.file_path)
+    return _get_file_format(data_file.file_format, **file_format_kwargs).make_fragment(path, fs)
+
+
+def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedArray]:
+    delete_fragment = _construct_fragment(
+        fs, data_file, file_format_kwargs={"dictionary_columns": ("file_path",), "pre_buffer": True, "buffer_size": ONE_MEGABYTE}
+    )
+    table = ds.Scanner.from_fragment(fragment=delete_fragment).to_table()
+    table = table.unify_dictionaries()
+    return {
+        file.as_py(): table.filter(pc.field("file_path") == file).column("pos")
+        for file in table.column("file_path").chunks[0].dictionary
+    }
+
+
+def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array:
+    if len(positional_deletes) == 1:
+        all_chunks = positional_deletes[0]
+    else:
+        all_chunks = pa.chunked_array(chain(*[arr.chunks for arr in positional_deletes]))
+    return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
+
+
 def pyarrow_to_schema(schema: pa.Schema) -> Schema:
     visitor = _ConvertToIceberg()
     return visit_pyarrow(schema, visitor)
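
The helpers above turn each positional delete file into a map of data-file path to deleted row positions, and then convert those positions into the row indices to keep. A minimal, standalone sketch of that second step (not part of the diff; the row count and positions are made up):

```python
import numpy as np
import pyarrow as pa

rows = 10  # hypothetical number of rows in the data file
# positions flagged for deletion, possibly coming from several delete files
deletes = pa.chunked_array([pa.array([2, 5]), pa.array([7])])

# same np.setdiff1d call as _combine_positional_deletes: keep everything
# in [0, rows) that is not listed as deleted
keep = np.setdiff1d(np.arange(rows), deletes, assume_unique=False)
print(keep)  # [0 1 3 4 6 8 9] -> later passed to ds.Scanner.take()
```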
@@ -682,12 +720,13 @@ def primitive(self, primitive: pa.DataType) -> IcebergType:
         raise TypeError(f"Unsupported type: {primitive}")
 
 
-def _file_to_table(
+def _task_to_table(
     fs: FileSystem,
     task: FileScanTask,
     bound_row_filter: BooleanExpression,
     projected_schema: Schema,
     projected_field_ids: Set[int],
+    positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
     rows_counter: Synchronized[int],
     limit: Optional[int] = None,
@@ -721,18 +760,44 @@ def _file_to_table(
     fragment_scanner = ds.Scanner.from_fragment(
         fragment=fragment,
         schema=physical_schema,
-        filter=pyarrow_filter,
+        # This will push down the query to Arrow.
+        # But in case there are positional deletes, we have to apply them first
+        filter=pyarrow_filter if not positional_deletes else None,
         columns=[col.name for col in file_project_schema.columns],
     )
 
+    if positional_deletes:
+        # Create the mask of indices that we're interested in
+        indices = _combine_positional_deletes(positional_deletes, fragment.count_rows())
+
+        if limit:
+            if pyarrow_filter is not None:
+                # In case of the filter, we don't exactly know how many rows
+                # we need to fetch upfront, can be optimized in the future:
+                # https://github.com/apache/arrow/issues/35301
+                arrow_table = fragment_scanner.take(indices)
+                arrow_table = arrow_table.filter(pyarrow_filter)
+                arrow_table = arrow_table.slice(0, limit)
+            else:
+                arrow_table = fragment_scanner.take(indices[0:limit])
+        else:
+            arrow_table = fragment_scanner.take(indices)
+            # Apply the user filter
+            if pyarrow_filter is not None:
+                arrow_table = arrow_table.filter(pyarrow_filter)
+    else:
+        # If there are no deletes, we can just take the head
+        # and the user-filter is already applied
+        if limit:
+            arrow_table = fragment_scanner.head(limit)
+        else:
+            arrow_table = fragment_scanner.to_table()
+
     if limit:
-        arrow_table = fragment_scanner.head(limit)
         with rows_counter.get_lock():
             if rows_counter.value >= limit:
                 return None
             rows_counter.value += len(arrow_table)
-    else:
-        arrow_table = fragment_scanner.to_table()
 
     # If there is no data, we don't have to go through the schema
     if len(arrow_table) > 0:
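
Because the delete mask has to be applied before the user filter, the filter can no longer be pushed down into the Arrow scan when positional deletes are present. A hedged sketch of the resulting take -> filter -> slice order on a plain in-memory table (assumes a PyArrow version whose Table.filter accepts a compute expression, as the code above does):

```python
import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({"id": list(range(10)), "val": [x * x for x in range(10)]})
keep = [0, 1, 3, 4, 6, 8, 9]  # indices surviving the positional deletes
limit = 3                     # hypothetical scan limit

result = tbl.take(keep)                       # 1. drop deleted rows
result = result.filter(pc.field("val") > 1)   # 2. apply the user filter
result = result.slice(0, limit)               # 3. enforce the limit last
print(result.to_pydict())
```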
@@ -741,12 +806,29 @@ def _file_to_table(
     return None
 
 
+def _read_all_delete_files(fs: FileSystem, pool: ThreadPool, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
+    deletes_per_file: Dict[str, List[ChunkedArray]] = {}
+    unique_deletes = set(chain.from_iterable([task.delete_files for task in tasks]))
+    if len(unique_deletes) > 0:
+        deletes_per_files: List[Dict[str, ChunkedArray]] = pool.starmap(
+            func=_read_deletes, iterable=[(fs, delete) for delete in unique_deletes]
+        )
+        for delete in deletes_per_files:
+            for file, arr in delete.items():
+                if file in deletes_per_file:
+                    deletes_per_file[file].append(arr)
+                else:
+                    deletes_per_file[file] = [arr]
+
+    return deletes_per_file
+
+
 def project_table(
     tasks: Iterable[FileScanTask],
     table: Table,
     row_filter: BooleanExpression,
     projected_schema: Schema,
-    case_sensitive: bool,
+    case_sensitive: bool = True,
     limit: Optional[int] = None,
 ) -> pa.Table:
     """Resolves the right columns based on the identifier.
@@ -757,6 +839,7 @@ def project_table(
         row_filter (BooleanExpression): The expression for filtering rows.
         projected_schema (Schema): The output schema.
         case_sensitive (bool): Case sensitivity when looking up column names.
+        limit (Optional[int]): Limit the number of records.
 
     Raises:
         ResolveError: When an incompatible query is done.
@@ -785,15 +868,25 @@ def project_table(
     rows_counter = multiprocessing.Value("i", 0)
 
     with ThreadPool() as pool:
+        deletes_per_file = _read_all_delete_files(fs, pool, tasks)
         tables = [
            table
            for table in pool.starmap(
-                func=_file_to_table,
+                func=_task_to_table,
                iterable=[
-                    (fs, task, bound_row_filter, projected_schema, projected_field_ids, case_sensitive, rows_counter, limit)
+                    (
+                        fs,
+                        task,
+                        bound_row_filter,
+                        projected_schema,
+                        projected_field_ids,
+                        deletes_per_file.get(task.file.file_path),
+                        case_sensitive,
+                        rows_counter,
+                        limit,
+                    )
                    for task in tasks
                ],
                chunksize=None,  # we could use this to control how to materialize the generator of tasks (we should also make the expression above lazy)
            )
            if table is not None
        ]
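
`deletes_per_file` maps each data-file path to the list of delete arrays that reference it, so every task only receives the deletes for its own data file. A standalone sketch of that grouping (the paths and positions are made up):

```python
from typing import Dict, List

import pyarrow as pa

# per delete file: {data file path: deleted positions}, as returned by _read_deletes
per_delete_file = [
    {"s3://bucket/a.parquet": pa.chunked_array([pa.array([1, 3])])},
    {
        "s3://bucket/a.parquet": pa.chunked_array([pa.array([7])]),
        "s3://bucket/b.parquet": pa.chunked_array([pa.array([0])]),
    },
]

# merge them per data file, mirroring _read_all_delete_files
deletes_per_file: Dict[str, List[pa.ChunkedArray]] = {}
for mapping in per_delete_file:
    for path, positions in mapping.items():
        deletes_per_file.setdefault(path, []).append(positions)

print({path: [p.to_pylist() for p in arrs] for path, arrs in deletes_per_file.items()})
# {'s3://bucket/a.parquet': [[1, 3], [7]], 's3://bucket/b.parquet': [[0]]}
```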
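For context, a hypothetical end-to-end read (catalog name, table name, and filter are assumptions, not part of this change): once positional deletes are read and applied inside `project_table`, deleted rows simply disappear from the Arrow table a scan materializes.

```python
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")            # assumes a configured catalog
tbl = catalog.load_table("examples.events")  # hypothetical table name
# rows matched by positional deletes are filtered out transparently
arrow_table = tbl.scan(row_filter="event_ts >= '2023-01-01'").to_arrow()
print(arrow_table.num_rows)
```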
16 changes: 16 additions & 0 deletions python/pyiceberg/manifest.py
@@ -187,6 +187,12 @@ def __setattr__(self, name: str, value: Any) -> None:
     def __init__(self, *data: Any, **named_data: Any) -> None:
         super().__init__(*data, **{"struct": DATA_FILE_TYPE, **named_data})
 
+    def __hash__(self) -> int:
+        return hash(self.file_path)
+
+    def __eq__(self, other: Any) -> bool:
+        return self.file_path == other.file_path if isinstance(other, DataFile) else False
+
 
 MANIFEST_ENTRY_SCHEMA = Schema(
     NestedField(0, "status", IntegerType(), required=True),
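
`__hash__`/`__eq__` on `file_path` let delete files be collected into a set, so a delete file referenced by several scan tasks is read only once. A standalone sketch with a stand-in class (`FakeDataFile` is hypothetical, used only to illustrate the set semantics):

```python
class FakeDataFile:
    def __init__(self, file_path: str) -> None:
        self.file_path = file_path

    def __hash__(self) -> int:
        return hash(self.file_path)

    def __eq__(self, other: object) -> bool:
        return self.file_path == other.file_path if isinstance(other, FakeDataFile) else False


unique_deletes = {
    FakeDataFile("s3://bucket/deletes-1.parquet"),
    FakeDataFile("s3://bucket/deletes-1.parquet"),  # duplicate reference from another task
    FakeDataFile("s3://bucket/deletes-2.parquet"),
}
print(len(unique_deletes))  # 2 -> deletes-1.parquet is deduplicated
```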
@@ -244,6 +250,10 @@ def __init__(self, *data: Any, **named_data: Any) -> None:
     NestedField(519, "key_metadata", BinaryType(), required=False),
 )
 
+POSITIONAL_DELETE_SCHEMA = Schema(
+    NestedField(2147483546, "file_path", StringType()), NestedField(2147483545, "pos", IntegerType())
+)
+
 
 class ManifestFile(Record):
     manifest_path: str
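
`POSITIONAL_DELETE_SCHEMA` describes what a positional delete file holds: the path of the data file and the 0-based row numbers to drop from it. A standalone sketch of such content as an Arrow table (paths and positions are made up):

```python
import pyarrow as pa

positional_deletes = pa.table(
    {
        "file_path": ["s3://bucket/data/a.parquet", "s3://bucket/data/a.parquet"],
        "pos": [4, 17],  # 0-based row numbers to delete from a.parquet
    }
)
print(positional_deletes)
```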
@@ -265,6 +275,12 @@ class ManifestFile(Record):
     def __init__(self, *data: Any, **named_data: Any) -> None:
         super().__init__(*data, **{"struct": MANIFEST_FILE_SCHEMA.as_struct(), **named_data})
 
+    def has_added_files(self) -> bool:
+        return self.added_files_count is None or self.added_files_count > 0
+
+    def has_existing_files(self) -> bool:
+        return self.existing_files_count is None or self.existing_files_count > 0
+
     def fetch_manifest_entry(self, io: FileIO, discard_deleted: bool = True) -> List[ManifestEntry]:
         """
         Reads the manifest entries from the manifest file.
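`has_added_files`/`has_existing_files` treat a missing count as "unknown", so a manifest can only be skipped when the count is known to be zero. A standalone sketch of that rule:

```python
from typing import Optional


def has_added_files(added_files_count: Optional[int]) -> bool:
    # mirrors ManifestFile.has_added_files: None means "unknown", so keep the manifest
    return added_files_count is None or added_files_count > 0


print(has_added_files(None))  # True  -> unknown, cannot skip
print(has_added_files(0))     # False -> safe to skip during planning
print(has_added_files(3))     # True
```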
