Skip to content

Commit

Permalink
[BEAM-9477] RowCoder should be hashable and picklable (apache#11088)
Browse files Browse the repository at this point in the history
* Add (failing) test

* implement RowCoder.__hash__

* Add tests that require RowCoder to be picklable

* Fix pickling
  • Loading branch information
TheNeuralBit authored Mar 13, 2020
1 parent eb59dde commit 33ec0bb
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
18 changes: 15 additions & 3 deletions sdks/python/apache_beam/coders/row_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from apache_beam.portability.api import schema_pb2
from apache_beam.typehints.schemas import named_tuple_from_schema
from apache_beam.typehints.schemas import named_tuple_to_schema
from apache_beam.utils import proto_utils

__all__ = ["RowCoder"]

Expand Down Expand Up @@ -69,7 +70,8 @@ def to_type_hint(self):
def as_cloud_object(self, coders_context=None):
    """Unconditionally raise: RowCoder offers no cloud-object form.

    Raises:
      NotImplementedError: on every call, regardless of ``coders_context``.
    """
    unsupported = "as_cloud_object not supported for RowCoder"
    raise NotImplementedError(unsupported)

__hash__ = None # type: ignore[assignment]
def __hash__(self):
    """Hash the coder by the serialized bytes of its schema."""
    serialized_schema = self.schema.SerializeToString()
    return hash(serialized_schema)

def __eq__(self, other):
    """Two coders are equal when they share an exact type and equal schemas."""
    same_type = type(self) == type(other)
    if not same_type:
        # Exact type match is required; subclasses do not compare equal.
        return same_type
    return self.schema == other.schema
Expand All @@ -79,13 +81,18 @@ def to_runner_api_parameter(self, unused_context):

# Static factory registered through Coder.register_urn for the ROW coder URN
# (common_urns.coders.ROW.urn), with schema_pb2.Schema given as the second
# registration argument. It wraps the value it receives in a RowCoder.
@staticmethod
@Coder.register_urn(common_urns.coders.ROW.urn, schema_pb2.Schema)
# NOTE(review): two signatures appear back to back because this is a rendered
# diff whose +/- markers were lost. The first pair (`payload`) looks like the
# removed lines and the second pair (`schema`) like the added ones -- confirm
# against the commit before treating both as live code.
def from_runner_api_parameter(payload, components, unused_context):
return RowCoder(payload)
# `components` and `unused_context` are accepted but not used in the body.
def from_runner_api_parameter(schema, components, unused_context):
return RowCoder(schema)

@staticmethod
def from_type_hint(named_tuple_type, registry):
    """Build a RowCoder for a NamedTuple type.

    The type is converted with ``named_tuple_to_schema`` and the result is
    handed to ``RowCoder``. ``registry`` is accepted but not used here.
    """
    schema = named_tuple_to_schema(named_tuple_type)
    return RowCoder(schema)

@staticmethod
def from_payload(payload):
    # type: (bytes) -> RowCoder

    """Build a RowCoder from the serialized bytes of a schema_pb2.Schema."""
    schema = proto_utils.parse_Bytes(payload, schema_pb2.Schema)
    return RowCoder(schema)

@staticmethod
def coder_from_type(field_type):
type_info = field_type.WhichOneof("type_info")
Expand All @@ -106,6 +113,11 @@ def coder_from_type(field_type):
"Encountered a type that is not currently supported by RowCoder: %s" %
field_type)

def __reduce__(self):
    """Pickle hook: rebuild the coder through ``RowCoder.from_payload``."""
    # schema_pb2.Schema objects cannot be pickled, so the pickle carries the
    # schema's serialized bytes and the coder is reconstructed from them.
    schema_bytes = self.schema.SerializeToString()
    constructor_args = (schema_bytes, )
    return (RowCoder.from_payload, constructor_args)


class RowCoderImpl(StreamCoderImpl):
"""For internal use only; no backwards-compatibility guarantees."""
Expand Down
23 changes: 22 additions & 1 deletion sdks/python/apache_beam/coders/row_coder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,14 @@
import numpy as np
from past.builtins import unicode

import apache_beam as beam
from apache_beam.coders import RowCoder
from apache_beam.coders.typecoders import registry as coders_registry
from apache_beam.internal import pickler
from apache_beam.portability.api import schema_pb2
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.typehints.schemas import typing_to_runner_api

Person = typing.NamedTuple(
Expand All @@ -44,8 +49,9 @@


class RowCoderTest(unittest.TestCase):
TEST_CASE = Person("Jon Snow", 23, None, ["crow", "wildling"])
TEST_CASES = [
Person("Jon Snow", 23, None, ["crow", "wildling"]),
TEST_CASE,
Person("Daenerys Targaryen", 25, "Westeros", ["Mother of Dragons"]),
Person("Michael Bluth", 30, None, [])
]
Expand Down Expand Up @@ -165,6 +171,21 @@ def test_schema_add_column_with_null_value(self):
New(None, "baz", None),
new_coder.decode(old_coder.encode(Old(None, "baz"))))

def test_row_coder_picklable(self):
    """The coder registered for Person survives a pickler round trip."""
    # Coders occasionally get pickled; RowCoder must be able to handle that.
    original = coders_registry.get_coder(Person)
    pickled = pickler.dumps(original)
    restored = pickler.loads(pickled)

    self.assertEqual(restored, original)

def test_row_coder_in_pipeine(self):
    """Send the Person test rows through a pipeline and keep only Jon Snow.

    NOTE: "pipeine" in the method name is a typo for "pipeline"; the name is
    kept so the test id does not change.
    """
    expected = [self.TEST_CASE]
    with TestPipeline() as pipeline:
        people = pipeline | beam.Create(self.TEST_CASES)
        jon_snows = people | beam.Filter(
            lambda person: person.name == "Jon Snow")
        assert_that(jon_snows, equal_to(expected))


if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
Expand Down

0 comments on commit 33ec0bb

Please sign in to comment.