From 0fba740cd557d62c86e84432c66411317f03b200 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Mon, 26 Apr 2021 13:15:39 +0100 Subject: [PATCH] [ArrowStringArray] implement ArrowStringArray._str_contains (#41025) --- asv_bench/benchmarks/strings.py | 15 +- pandas/core/arrays/string_arrow.py | 10 ++ pandas/tests/strings/test_find_replace.py | 184 +++++++++++++++------- 3 files changed, 151 insertions(+), 58 deletions(-) diff --git a/asv_bench/benchmarks/strings.py b/asv_bench/benchmarks/strings.py index 5d9b1c135d7ae..45a9053954569 100644 --- a/asv_bench/benchmarks/strings.py +++ b/asv_bench/benchmarks/strings.py @@ -213,13 +213,18 @@ def time_cat(self, other_cols, sep, na_rep, na_frac): class Contains: - params = [True, False] - param_names = ["regex"] + params = (["str", "string", "arrow_string"], [True, False]) + param_names = ["dtype", "regex"] + + def setup(self, dtype, regex): + from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401 - def setup(self, regex): - self.s = Series(tm.makeStringIndex(10 ** 5)) + try: + self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype) + except ImportError: + raise NotImplementedError - def time_contains(self, regex): + def time_contains(self, dtype, regex): self.s.str.contains("A", regex=regex) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 55cb350d3d27c..b7a0e70180ae4 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -759,6 +759,16 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None): # -> We don't know the result type. E.g. `.get` can return anything. return lib.map_infer_mask(arr, f, mask.view("uint8")) + def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex=True): + if not regex and case: + result = pc.match_substring(self._data, pat) + result = BooleanDtype().__from_arrow__(result) + if not isna(na): + result[isna(result)] = bool(na) + return result + else: + return super()._str_contains(pat, case, flags, na, regex) + def _str_isalnum(self): if hasattr(pc, "utf8_is_alnum"): result = pc.utf8_is_alnum(self._data) diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py index ab95b2071ae10..d801d3457027f 100644 --- a/pandas/tests/strings/test_find_replace.py +++ b/pandas/tests/strings/test_find_replace.py @@ -4,6 +4,8 @@ import numpy as np import pytest +import pandas.util._test_decorators as td + import pandas as pd from pandas import ( Index, @@ -12,79 +14,118 @@ ) -def test_contains(): +@pytest.fixture( + params=[ + "object", + "string", + pytest.param( + "arrow_string", marks=td.skip_if_no("pyarrow", min_version="1.0.0") + ), + ] +) +def any_string_dtype(request): + """ + Parametrized fixture for string dtypes. + * 'object' + * 'string' + * 'arrow_string' + """ + from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401 + + return request.param + + +def test_contains(any_string_dtype): values = np.array( ["foo", np.nan, "fooommm__foo", "mmm_", "foommm[_]+bar"], dtype=np.object_ ) - values = Series(values) + values = Series(values, dtype=any_string_dtype) pat = "mmm[_]+" result = values.str.contains(pat) - expected = Series(np.array([False, np.nan, True, True, False], dtype=np.object_)) + expected_dtype = "object" if any_string_dtype == "object" else "boolean" + expected = Series( + np.array([False, np.nan, True, True, False], dtype=np.object_), + dtype=expected_dtype, + ) tm.assert_series_equal(result, expected) result = values.str.contains(pat, regex=False) - expected = Series(np.array([False, np.nan, False, False, True], dtype=np.object_)) + expected = Series( + np.array([False, np.nan, False, False, True], dtype=np.object_), + dtype=expected_dtype, + ) tm.assert_series_equal(result, expected) - values = Series(np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=object)) + values = Series( + np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=object), + dtype=any_string_dtype, + ) result = values.str.contains(pat) - expected = Series(np.array([False, False, True, True])) - assert result.dtype == np.bool_ + expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean" + expected = Series(np.array([False, False, True, True]), dtype=expected_dtype) tm.assert_series_equal(result, expected) # case insensitive using regex - values = Series(np.array(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dtype=object)) + values = Series( + np.array(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dtype=object), + dtype=any_string_dtype, + ) result = values.str.contains("FOO|mmm", case=False) - expected = Series(np.array([True, False, True, True])) + expected = Series(np.array([True, False, True, True]), dtype=expected_dtype) tm.assert_series_equal(result, expected) # case insensitive without regex - result = Series(values).str.contains("foo", regex=False, case=False) - expected = Series(np.array([True, False, True, False])) + result = values.str.contains("foo", regex=False, case=False) + expected = Series(np.array([True, False, True, False]), dtype=expected_dtype) tm.assert_series_equal(result, expected) - # mixed + # unicode + values = Series( + np.array(["foo", np.nan, "fooommm__foo", "mmm_"], dtype=np.object_), + dtype=any_string_dtype, + ) + pat = "mmm[_]+" + + result = values.str.contains(pat) + expected_dtype = "object" if any_string_dtype == "object" else "boolean" + expected = Series( + np.array([False, np.nan, True, True], dtype=np.object_), dtype=expected_dtype + ) + tm.assert_series_equal(result, expected) + + result = values.str.contains(pat, na=False) + expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean" + expected = Series(np.array([False, False, True, True]), dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + values = Series( + np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=np.object_), + dtype=any_string_dtype, + ) + result = values.str.contains(pat) + expected = Series(np.array([False, False, True, True]), dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + +def test_contains_object_mixed(): mixed = Series( np.array( ["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0], dtype=object, ) ) - rs = mixed.str.contains("o") - xp = Series( + result = mixed.str.contains("o") + expected = Series( np.array( [False, np.nan, False, np.nan, np.nan, True, np.nan, np.nan, np.nan], dtype=np.object_, ) ) - tm.assert_series_equal(rs, xp) - - rs = mixed.str.contains("o") - xp = Series([False, np.nan, False, np.nan, np.nan, True, np.nan, np.nan, np.nan]) - assert isinstance(rs, Series) - tm.assert_series_equal(rs, xp) - - # unicode - values = Series(np.array(["foo", np.nan, "fooommm__foo", "mmm_"], dtype=np.object_)) - pat = "mmm[_]+" - - result = values.str.contains(pat) - expected = Series(np.array([False, np.nan, True, True], dtype=np.object_)) - tm.assert_series_equal(result, expected) - - result = values.str.contains(pat, na=False) - expected = Series(np.array([False, False, True, True])) - tm.assert_series_equal(result, expected) - - values = Series(np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=np.object_)) - result = values.str.contains(pat) - expected = Series(np.array([False, False, True, True])) - assert result.dtype == np.bool_ tm.assert_series_equal(result, expected) -def test_contains_for_object_category(): +def test_contains_na_kwarg_for_object_category(): # gh 22158 # na for category @@ -108,6 +149,29 @@ def test_contains_for_object_category(): tm.assert_series_equal(result, expected) +@pytest.mark.parametrize( + "na, expected", + [ + (None, pd.NA), + (True, True), + (False, False), + (0, False), + (3, True), + (np.nan, pd.NA), + ], +) +@pytest.mark.parametrize("regex", [True, False]) +def test_contains_na_kwarg_for_nullable_string_dtype( + nullable_string_dtype, na, expected, regex +): + # https://github.com/pandas-dev/pandas/pull/41025#issuecomment-824062416 + + values = Series(["a", "b", "c", "a", np.nan], dtype=nullable_string_dtype) + result = values.str.contains("a", na=na, regex=regex) + expected = Series([True, False, False, True, expected], dtype="boolean") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("dtype", [None, "category"]) @pytest.mark.parametrize("null_value", [None, np.nan, pd.NA]) @pytest.mark.parametrize("na", [True, False]) @@ -508,59 +572,73 @@ def _check(result, expected): tm.assert_series_equal(result, expected) -def test_contains_moar(): +def test_contains_moar(any_string_dtype): # PR #1179 - s = Series(["A", "B", "C", "Aaba", "Baca", "", np.nan, "CABA", "dog", "cat"]) + s = Series( + ["A", "B", "C", "Aaba", "Baca", "", np.nan, "CABA", "dog", "cat"], + dtype=any_string_dtype, + ) result = s.str.contains("a") + expected_dtype = "object" if any_string_dtype == "object" else "boolean" expected = Series( - [False, False, False, True, True, False, np.nan, False, False, True] + [False, False, False, True, True, False, np.nan, False, False, True], + dtype=expected_dtype, ) tm.assert_series_equal(result, expected) result = s.str.contains("a", case=False) expected = Series( - [True, False, False, True, True, False, np.nan, True, False, True] + [True, False, False, True, True, False, np.nan, True, False, True], + dtype=expected_dtype, ) tm.assert_series_equal(result, expected) result = s.str.contains("Aa") expected = Series( - [False, False, False, True, False, False, np.nan, False, False, False] + [False, False, False, True, False, False, np.nan, False, False, False], + dtype=expected_dtype, ) tm.assert_series_equal(result, expected) result = s.str.contains("ba") expected = Series( - [False, False, False, True, False, False, np.nan, False, False, False] + [False, False, False, True, False, False, np.nan, False, False, False], + dtype=expected_dtype, ) tm.assert_series_equal(result, expected) result = s.str.contains("ba", case=False) expected = Series( - [False, False, False, True, True, False, np.nan, True, False, False] + [False, False, False, True, True, False, np.nan, True, False, False], + dtype=expected_dtype, ) tm.assert_series_equal(result, expected) -def test_contains_nan(): +def test_contains_nan(any_string_dtype): # PR #14171 - s = Series([np.nan, np.nan, np.nan], dtype=np.object_) + s = Series([np.nan, np.nan, np.nan], dtype=any_string_dtype) result = s.str.contains("foo", na=False) - expected = Series([False, False, False], dtype=np.bool_) + expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean" + expected = Series([False, False, False], dtype=expected_dtype) tm.assert_series_equal(result, expected) result = s.str.contains("foo", na=True) - expected = Series([True, True, True], dtype=np.bool_) + expected = Series([True, True, True], dtype=expected_dtype) tm.assert_series_equal(result, expected) result = s.str.contains("foo", na="foo") - expected = Series(["foo", "foo", "foo"], dtype=np.object_) + if any_string_dtype == "object": + expected = Series(["foo", "foo", "foo"], dtype=np.object_) + else: + expected = Series([True, True, True], dtype="boolean") tm.assert_series_equal(result, expected) result = s.str.contains("foo") - expected = Series([np.nan, np.nan, np.nan], dtype=np.object_) + expected_dtype = "object" if any_string_dtype == "object" else "boolean" + expected = Series([np.nan, np.nan, np.nan], dtype=expected_dtype) tm.assert_series_equal(result, expected) @@ -609,14 +687,14 @@ def test_replace_moar(): tm.assert_series_equal(result, expected) -def test_match_findall_flags(): +def test_flags_kwarg(any_string_dtype): data = { "Dave": "dave@google.com", "Steve": "steve@gmail.com", "Rob": "rob@gmail.com", "Wes": np.nan, } - data = Series(data) + data = Series(data, dtype=any_string_dtype) pat = r"([A-Z0-9._%+-]+)@([A-Z0-9.-]+)\.([A-Z]{2,4})"