From 3fee4d6e6c431892922b14dd11c391c6b14696f8 Mon Sep 17 00:00:00 2001 From: deepankar Date: Thu, 20 Apr 2017 04:10:58 +0530 Subject: [PATCH] implements intersection --- python/common/org/python/types/FrozenSet.java | 24 +++++++++++++++ tests/datatypes/test_frozenset.py | 30 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/python/common/org/python/types/FrozenSet.java b/python/common/org/python/types/FrozenSet.java index 620df1a299..ba665d5c3f 100644 --- a/python/common/org/python/types/FrozenSet.java +++ b/python/common/org/python/types/FrozenSet.java @@ -410,4 +410,28 @@ public org.python.Object union(org.python.types.Tuple others) { } return new org.python.types.FrozenSet(set); } + + @org.python.Method( + __doc__ = "Return the intersection of two sets as a new set.\n\n(i.e. all elements that are in both sets.)", + varargs = "others" + ) + public org.python.Object intersection(org.python.types.Tuple others) { + java.util.Set set = new java.util.HashSet(this.value); + for (org.python.Object other: others.value) { + try { + java.util.Set otherSet = null; + if (other instanceof org.python.types.Set) { + otherSet = ((org.python.types.Set) other).value; + } else if (other instanceof org.python.types.FrozenSet) { + otherSet = ((org.python.types.FrozenSet) other).value; + } else { + otherSet = iterToSet(other); + } + set.retainAll(otherSet); + } catch (org.python.exceptions.AttributeError e) { + throw new org.python.exceptions.TypeError("'" + other.typeName() + "' object is not iterable"); + } + } + return new org.python.types.FrozenSet(set); + } } diff --git a/tests/datatypes/test_frozenset.py b/tests/datatypes/test_frozenset.py index 5c6e6eff3d..1411fd918c 100644 --- a/tests/datatypes/test_frozenset.py +++ b/tests/datatypes/test_frozenset.py @@ -222,9 +222,13 @@ def test_union(self): y = frozenset({3, 4, 5}) z = [5, 6, 7] w = 1 + t = frozenset() print(x.union(y)) + # empty set test + print(x.union(t)) + # multiple args test print(x.union(y, z)) @@ -238,6 +242,32 @@ def test_union(self): print(err) """) + def test_intersection(self): + self.assertCodeExecution(""" + x = frozenset({1, 2, 3}) + y = frozenset({2, 3, 4}) + z = [3, 6, 7] + w = 1 + t = frozenset() + + print(x.intersection(y)) + + # empty set test + print(x.intersection(t)) + + # multiple args test + print(x.intersection(y, z)) + + # iterable test + print(x.intersection(z)) + + # not-iterable test + try: + print(x.intersection(w)) + except TypeError as err: + print(err) + """) + class UnaryFrozensetOperationTests(UnaryOperationTestCase, TranspileTestCase): data_type = 'frozenset'