Skip to content

Commit

Permalink
Python: Refactor Avro read path to use a partner visitor (apache#6506)
Browse files Browse the repository at this point in the history
  • Loading branch information
rdblue authored Jan 2, 2023
1 parent adecd8a commit cf00f6a
Show file tree
Hide file tree
Showing 9 changed files with 371 additions and 204 deletions.
32 changes: 22 additions & 10 deletions python/pyiceberg/avro/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,21 @@
from dataclasses import dataclass
from io import SEEK_SET, BufferedReader
from types import TracebackType
from typing import Optional, Type
from typing import (
Callable,
Dict,
Optional,
Type,
)

from pyiceberg.avro.codecs import KNOWN_CODECS, Codec
from pyiceberg.avro.decoder import BinaryDecoder
from pyiceberg.avro.reader import ConstructReader, Reader
from pyiceberg.avro.resolver import resolve
from pyiceberg.avro.reader import Reader
from pyiceberg.avro.resolver import construct_reader, resolve
from pyiceberg.io import InputFile, InputStream
from pyiceberg.io.memory import MemoryInputStream
from pyiceberg.schema import Schema, visit
from pyiceberg.typedef import Record
from pyiceberg.schema import Schema
from pyiceberg.typedef import EMPTY_DICT, Record, StructProtocol
from pyiceberg.types import (
FixedType,
MapType,
Expand Down Expand Up @@ -112,6 +117,7 @@ def __next__(self) -> Record:
class AvroFile:
input_file: InputFile
read_schema: Optional[Schema]
read_types: Dict[int, Callable[[Schema], StructProtocol]]
input_stream: InputStream
header: AvroFileHeader
schema: Schema
Expand All @@ -120,9 +126,15 @@ class AvroFile:
decoder: BinaryDecoder
block: Optional[Block] = None

def __init__(
    self,
    input_file: InputFile,
    read_schema: Optional[Schema] = None,
    read_types: Dict[int, Callable[[Schema], StructProtocol]] = EMPTY_DICT,
) -> None:
    """Store the inputs needed to read an Avro file; nothing is opened or read here.

    Args:
        input_file: The file to read.
        read_schema: Schema to read with. When it is falsy, ``__enter__``
            replaces it with the schema found in the file header.
        read_types: Mapping from an int key to a callable that takes a Schema and
            returns a StructProtocol. It is handed to ``resolve`` unchanged in
            ``__enter__``. NOTE(review): the keys are presumably field ids -
            confirm against ``resolve``.
    """
    self.input_file = input_file
    self.read_schema = read_schema
    self.read_types = read_types

def __enter__(self) -> AvroFile:
"""
Expand All @@ -137,9 +149,9 @@ def __enter__(self) -> AvroFile:
self.header = self._read_header()
self.schema = self.header.get_schema()
if not self.read_schema:
self.reader = visit(self.schema, ConstructReader())
else:
self.reader = resolve(self.schema, self.read_schema)
self.read_schema = self.schema

self.reader = resolve(self.schema, self.read_schema, self.read_types)

return self

Expand Down Expand Up @@ -184,6 +196,6 @@ def __next__(self) -> Record:

def _read_header(self) -> AvroFileHeader:
    """Decode the file header and return it as an AvroFileHeader.

    The input stream is rewound to offset 0 first. A reader built for
    META_SCHEMA then decodes one record from ``self.decoder``, and its
    positions 0, 1 and 2 become ``magic``, ``meta`` and ``sync``.
    """
    # NOTE(review): assumes self.decoder reads from self.input_stream, so the
    # seek below positions the decoder at the start of the file - confirm.
    self.input_stream.seek(0, SEEK_SET)
    reader = construct_reader(META_SCHEMA)
    _header = reader.read(self.decoder)
    return AvroFileHeader(magic=_header.get(0), meta=_header.get(1), sync=_header.get(2))
115 changes: 25 additions & 90 deletions python/pyiceberg/avro/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,33 +37,11 @@
List,
Optional,
Tuple,
Union,
)
from uuid import UUID

from pyiceberg.avro.decoder import BinaryDecoder
from pyiceberg.schema import Schema, SchemaVisitorPerPrimitiveType
from pyiceberg.typedef import Record, StructProtocol
from pyiceberg.types import (
BinaryType,
BooleanType,
DateType,
DecimalType,
DoubleType,
FixedType,
FloatType,
IntegerType,
ListType,
LongType,
MapType,
NestedField,
StringType,
StructType,
TimestampType,
TimestamptzType,
TimeType,
UUIDType,
)
from pyiceberg.utils.singleton import Singleton


Expand Down Expand Up @@ -260,25 +238,43 @@ def skip(self, decoder: BinaryDecoder) -> None:
return self.option.skip(decoder)


class StructProtocolReader(Reader):
    """Reader that decodes a struct into an object obtained from a factory.

    Attributes are set in ``__init__``:
      * ``create_struct`` - zero-argument callable returning the StructProtocol
        that ``read`` fills in.
      * ``fields`` - ``(position, reader)`` pairs, processed in order. A position
        of ``None`` means the value is skipped on the decoder instead of stored.
    """

    create_struct: Callable[[], StructProtocol]
    fields: Tuple[Tuple[Optional[int], Reader], ...]

    def __init__(self, fields: Tuple[Tuple[Optional[int], Reader], ...], create_struct: Callable[[], StructProtocol]):
        self.create_struct = create_struct
        self.fields = fields

    def create_or_reuse(self, reuse: Optional[StructProtocol]) -> StructProtocol:
        """Return ``reuse`` when one is supplied, otherwise a struct from ``create_struct``."""
        # Compare against None explicitly: a supplied struct that evaluates as
        # falsy must still be reused rather than silently replaced.
        if reuse is not None:
            return reuse
        return self.create_struct()

    def read(self, decoder: BinaryDecoder) -> Any:
        """Decode every field in order and return the populated struct."""
        struct = self.create_or_reuse(None)

        for pos, field in self.fields:
            if pos is not None:
                struct.set(pos, field.read(decoder))  # later: pass reuse in here
            else:
                # No target position: consume the value without storing it
                field.skip(decoder)

        return struct

    def skip(self, decoder: BinaryDecoder) -> None:
        """Skip every field of the struct on the decoder, storing nothing."""
        for _, field in self.fields:
            field.skip(decoder)


class StructReader(StructProtocolReader):
    """StructProtocolReader whose structs are created with ``Record.of(len(fields))``."""

    fields: Tuple[Tuple[Optional[int], Reader], ...]

    def __init__(self, fields: Tuple[Tuple[Optional[int], Reader], ...]):
        def new_record() -> StructProtocol:
            # len(fields) is evaluated each time a struct is requested
            return Record.of(len(fields))

        super().__init__(fields, new_record)


@dataclass(frozen=True)
class ListReader(Reader):
element: Reader
Expand Down Expand Up @@ -325,64 +321,3 @@ def skip() -> None:
self.value.skip(decoder)

_skip_map_array(decoder, skip)


class ConstructReader(SchemaVisitorPerPrimitiveType[Reader]):
    """Schema visitor that returns, for each visited type, the Reader used to decode it."""

    # --- container types -------------------------------------------------

    def schema(self, schema: Schema, struct_result: Reader) -> Reader:
        """Return the reader already built for the schema's struct."""
        return struct_result

    def struct(self, struct: StructType, field_results: List[Reader]) -> Reader:
        """Pair each field reader with its index and wrap them in a StructReader."""
        positioned_readers = tuple(enumerate(field_results))
        return StructReader(positioned_readers)

    def field(self, field: NestedField, field_result: Reader) -> Reader:
        """Wrap the field reader in an OptionReader unless the field is required."""
        if field.required:
            return field_result
        return OptionReader(field_result)

    def list(self, list_type: ListType, element_result: Reader) -> Reader:
        """Build a ListReader; non-required elements are read through an OptionReader."""
        if list_type.element_required:
            element_reader = element_result
        else:
            element_reader = OptionReader(element_result)
        return ListReader(element_reader)

    def map(self, map_type: MapType, key_result: Reader, value_result: Reader) -> Reader:
        """Build a MapReader; non-required values are read through an OptionReader."""
        if map_type.value_required:
            value_reader = value_result
        else:
            value_reader = OptionReader(value_result)
        return MapReader(key_result, value_reader)

    # --- primitive types -------------------------------------------------

    def visit_fixed(self, fixed_type: FixedType) -> Reader:
        """Return a FixedReader constructed with ``len(fixed_type)``."""
        length = len(fixed_type)
        return FixedReader(length)

    def visit_decimal(self, decimal_type: DecimalType) -> Reader:
        """Return a DecimalReader carrying the type's precision and scale."""
        precision, scale = decimal_type.precision, decimal_type.scale
        return DecimalReader(precision, scale)

    def visit_boolean(self, boolean_type: BooleanType) -> Reader:
        """Return a BooleanReader."""
        return BooleanReader()

    def visit_integer(self, integer_type: IntegerType) -> Reader:
        """Return an IntegerReader."""
        return IntegerReader()

    def visit_long(self, long_type: LongType) -> Reader:
        """Return an IntegerReader (the same reader class used for integers)."""
        return IntegerReader()

    def visit_float(self, float_type: FloatType) -> Reader:
        """Return a FloatReader."""
        return FloatReader()

    def visit_double(self, double_type: DoubleType) -> Reader:
        """Return a DoubleReader."""
        return DoubleReader()

    def visit_date(self, date_type: DateType) -> Reader:
        """Return a DateReader."""
        return DateReader()

    def visit_time(self, time_type: TimeType) -> Reader:
        """Return a TimeReader."""
        return TimeReader()

    def visit_timestamp(self, timestamp_type: TimestampType) -> Reader:
        """Return a TimestampReader."""
        return TimestampReader()

    def visit_timestampz(self, timestamptz_type: TimestamptzType) -> Reader:
        """Return a TimestamptzReader."""
        return TimestamptzReader()

    def visit_string(self, string_type: StringType) -> Reader:
        """Return a StringReader."""
        return StringReader()

    def visit_uuid(self, uuid_type: UUIDType) -> Reader:
        """Return a UUIDReader."""
        return UUIDReader()

    # NOTE(review): ``binary_ype`` looks like a typo for ``binary_type``; the name
    # is kept as-is so the method signature stays unchanged.
    def visit_binary(self, binary_ype: BinaryType) -> Reader:
        """Return a BinaryReader."""
        return BinaryReader()
Loading

0 comments on commit cf00f6a

Please sign in to comment.