Skip to content

Commit

Permalink
Make util.objects.datatype classes not iterable
Browse files Browse the repository at this point in the history
Datatypes are typically used as structs not collections. This makes it tricky to assume iterating over them is a reasonable operation. In fact, often it is not.

This patch makes datatype classes not iterable. It does so by overriding `__iter__` and the methods from namedtuple's template that expect `self` to be iterable.

Additional adjustments:

- ensure `super.__eq__` value is propagated even if it is NotImplemented
- add is_iterable=True to FileContent, since it is used as an iterable

Testing Done:
Added tests that checked the behavior, then iterated on failures from the initial CI run. Current CI away in PR.

Bugs closed: 3790

Reviewed at https://rbcommons.com/s/twitter/r/4163/

closes pantsbuild#3790
  • Loading branch information
baroquebobcat committed Sep 6, 2016
1 parent dcdb55f commit 517fc38
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/python/pants/engine/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def parse_address_family(address_mapper, path, build_files_content):
if not build_files_content.dependencies:
raise ResolveError('Directory "{}" does not contain build files.'.format(path))
address_maps = []
for filepath, filecontent in build_files_content.dependencies:
address_maps.append(AddressMap.parse(filepath,
filecontent,
for filecontent_product in build_files_content.dependencies:
address_maps.append(AddressMap.parse(filecontent_product.path,
filecontent_product.content,
address_mapper.symbol_table_cls,
address_mapper.parser_cls,
address_mapper.exclude_patterns))
Expand Down
35 changes: 32 additions & 3 deletions src/python/pants/util/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,47 @@
from __future__ import (absolute_import, division, generators, nested_scopes, print_function,
unicode_literals, with_statement)

from collections import namedtuple
from collections import OrderedDict, namedtuple


def datatype(*args, **kwargs):
"""A wrapper for `namedtuple` that accounts for the type of the object in equality."""
"""A wrapper for `namedtuple` that accounts for the type of the object in equality.
"""
class DataType(namedtuple(*args, **kwargs)):
__slots__ = ()

def __eq__(self, other):
if self is other:
return True

# Compare types and fields.
return type(other) == type(self) and super(DataType, self).__eq__(other)
if type(self) != type(other):
return False
# Explicitly return super.__eq__'s value in case super returns NotImplemented
return super(DataType, self).__eq__(other)

def __ne__(self, other):
return not (self == other)

# NB: As datatype is not iterable, we need to override both __iter__ and all of the
# namedtuple methods that expect self to be iterable.
def __iter__(self):
raise TypeError("'{}' object is not iterable".format(type(self).__name__))

def _asdict(self):
'''Return a new OrderedDict which maps field names to their values'''
return OrderedDict(zip(self._fields, super(DataType, self).__iter__()))

def _replace(_self, **kwds):
'''Return a new datatype object replacing specified fields with new values'''
result = _self._make(map(kwds.pop, _self._fields, super(DataType, _self).__iter__()))
if kwds:
raise ValueError('Got unexpected field names: %r' % kwds.keys())
return result

def __getnewargs__(self):
'''Return self as a plain tuple. Used by copy and pickle.'''
return tuple(super(DataType, self).__iter__())

return DataType
4 changes: 2 additions & 2 deletions tests/python/pants_test/engine/examples/planners.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def extract_scala_imports(source_files_content):
"""A toy example of dependency inference. Would usually be a compiler plugin."""
packages = set()
import_re = re.compile(r'^import ([^;]*);?$')
for _, content in source_files_content.dependencies:
for line in content.splitlines():
for filecontent in source_files_content.dependencies:
for line in filecontent.content.splitlines():
match = import_re.search(line)
if match:
packages.add(match.group(1).rsplit('.', 1)[0])
Expand Down
10 changes: 10 additions & 0 deletions tests/python/pants_test/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ python_tests(
]
)

python_tests(
name = 'objects',
sources = ['test_objects.py'],
coverage = ['pants.util.objects'],
dependencies = [
'3rdparty/python:mock',
'src/python/pants/util:objects',
]
)

python_tests(
name = 'osutil',
sources = ['test_osutil.py'],
Expand Down
122 changes: 122 additions & 0 deletions tests/python/pants_test/util/test_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# coding=utf-8
# Copyright 2016 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from __future__ import (absolute_import, division, generators, nested_scopes, print_function,
unicode_literals, with_statement)

import copy
import pickle
import unittest

from pants.util.objects import datatype


class ExportedDatatype(datatype('ExportedDatatype', ['val'])):
pass


class AbsClass(object):
pass


class ReturnsNotImplemented(object):
def __eq__(self, other):
return NotImplemented


class DatatypeTest(unittest.TestCase):

def test_eq_with_not_implemented_super(self):
class DatatypeSuperNotImpl(datatype('Foo', ['val']), ReturnsNotImplemented, tuple):
pass

self.assertNotEqual(DatatypeSuperNotImpl(1), DatatypeSuperNotImpl(1))

def test_type_included_in_eq(self):
foo = datatype('Foo', ['val'])
bar = datatype('Bar', ['val'])

self.assertFalse(foo(1) == bar(1))
self.assertTrue(foo(1) != bar(1))

def test_subclasses_not_equal(self):
foo = datatype('Foo', ['val'])
class Bar(foo):
pass

self.assertFalse(foo(1) == Bar(1))
self.assertTrue(foo(1) != Bar(1))

def test_repr(self):
bar = datatype('Bar', ['val', 'zal'])
self.assertEqual('Bar(val=1, zal=1)', repr(bar(1, 1)))

class Foo(datatype('F', ['val']), AbsClass):
pass

# Maybe this should be 'Foo(val=1)'?
self.assertEqual('F(val=1)', repr(Foo(1)))

def test_not_iterable(self):
bar = datatype('Bar', ['val'])
with self.assertRaises(TypeError):
for x in bar(1):
pass

def test_deep_copy(self):
# deep copy calls into __getnewargs__, which namedtuple defines as implicitly using __iter__.

bar = datatype('Bar', ['val'])

self.assertEqual(bar(1), copy.deepcopy(bar(1)))

def test_atrs(self):
bar = datatype('Bar', ['val'])
self.assertEqual(1, bar(1).val)

def test_as_dict(self):
bar = datatype('Bar', ['val'])

self.assertEqual({'val': 1}, bar(1)._asdict())

def test_replace_non_iterable(self):
bar = datatype('Bar', ['val', 'zal'])

self.assertEqual(bar(1, 3), bar(1, 2)._replace(zal=3))

def test_properties_not_assignable(self):
bar = datatype('Bar', ['val'])
bar_inst = bar(1)
with self.assertRaises(AttributeError):
bar_inst.val = 2

def test_invalid_field_name(self):
with self.assertRaises(ValueError):
datatype('Bar', ['0isntanallowedfirstchar'])

def test_subclass_pickleable(self):
before = ExportedDatatype(1)
dumps = pickle.dumps(before, protocol=2)
after = pickle.loads(dumps)
self.assertEqual(before, after)

def test_mixed_argument_types(self):
bar = datatype('Bar', ['val', 'zal'])
self.assertEqual(bar(1, 2), bar(val=1, zal=2))
self.assertEqual(bar(1, 2), bar(zal=2, val=1))

def test_double_passed_arg(self):
bar = datatype('Bar', ['val', 'zal'])
with self.assertRaises(TypeError):
bar(1, val=1)

def test_too_many_args(self):
bar = datatype('Bar', ['val', 'zal'])
with self.assertRaises(TypeError):
bar(1, 1, 1)

def test_unexpect_kwarg(self):
bar = datatype('Bar', ['val'])
with self.assertRaises(TypeError):
bar(other=1)

0 comments on commit 517fc38

Please sign in to comment.