Skip to content

Commit

Permalink
[ArrowStringArray] implement ArrowStringArray._str_contains (pandas-d…
Browse files Browse the repository at this point in the history
  • Loading branch information
simonjayhawkins authored Apr 26, 2021
1 parent 8de6276 commit 0fba740
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 58 deletions.
15 changes: 10 additions & 5 deletions asv_bench/benchmarks/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,18 @@ def time_cat(self, other_cols, sep, na_rep, na_frac):

class Contains:

params = [True, False]
param_names = ["regex"]
params = (["str", "string", "arrow_string"], [True, False])
param_names = ["dtype", "regex"]

def setup(self, dtype, regex):
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401

def setup(self, regex):
self.s = Series(tm.makeStringIndex(10 ** 5))
try:
self.s = Series(tm.makeStringIndex(10 ** 5), dtype=dtype)
except ImportError:
raise NotImplementedError

def time_contains(self, regex):
def time_contains(self, dtype, regex):
self.s.str.contains("A", regex=regex)


Expand Down
10 changes: 10 additions & 0 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,16 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))

def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex=True):
    """
    Test whether each element contains ``pat``.

    Only the literal, case-sensitive case (``regex`` falsy and ``case``
    truthy) is handled here, via ``pc.match_substring``; every other
    combination is delegated unchanged to the parent implementation.

    Parameters
    ----------
    pat : str
        Substring (or pattern, when delegated) to look for.
    case : bool, default True
        Whether matching is case sensitive.
    flags : int, default 0
        Only forwarded to the parent implementation; unused on the
        ``pc.match_substring`` path.
    na : scalar, default np.nan
        Fill value for missing results. When it is not itself missing,
        it is coerced with ``bool(na)`` before being written.
    regex : bool, default True
        Whether ``pat`` is a regular expression.

    Returns
    -------
    Result of ``BooleanDtype().__from_arrow__`` on the fast path, otherwise
    whatever the parent ``_str_contains`` returns.
    """
    # Anything that is a regex or case-insensitive goes to the fallback.
    if regex or not case:
        return super()._str_contains(pat, case, flags, na, regex)

    matched = pc.match_substring(self._data, pat)
    mask_result = BooleanDtype().__from_arrow__(matched)
    if not isna(na):
        # Replace missing entries with the caller's fill value, as a bool.
        mask_result[isna(mask_result)] = bool(na)
    return mask_result

def _str_isalnum(self):
if hasattr(pc, "utf8_is_alnum"):
result = pc.utf8_is_alnum(self._data)
Expand Down
184 changes: 131 additions & 53 deletions pandas/tests/strings/test_find_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

import pandas as pd
from pandas import (
Index,
Expand All @@ -12,79 +14,118 @@
)


def test_contains():
@pytest.fixture(
    params=[
        "object",
        "string",
        pytest.param(
            "arrow_string", marks=td.skip_if_no("pyarrow", min_version="1.0.0")
        ),
    ]
)
def any_string_dtype(request):
    """
    Fixture yielding the name of each string dtype under test.

    Values produced, one per parametrization:

    * 'object'
    * 'string'
    * 'arrow_string' (skipped unless pyarrow >= 1.0.0 is available)
    """
    # Imported only for its side effect (hence the unused-import noqa);
    # presumably this makes the "arrow_string" alias resolvable - confirm.
    from pandas.core.arrays.string_arrow import ArrowStringDtype  # noqa: F401

    dtype_name = request.param
    return dtype_name


def test_contains(any_string_dtype):
values = np.array(
["foo", np.nan, "fooommm__foo", "mmm_", "foommm[_]+bar"], dtype=np.object_
)
values = Series(values)
values = Series(values, dtype=any_string_dtype)
pat = "mmm[_]+"

result = values.str.contains(pat)
expected = Series(np.array([False, np.nan, True, True, False], dtype=np.object_))
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
expected = Series(
np.array([False, np.nan, True, True, False], dtype=np.object_),
dtype=expected_dtype,
)
tm.assert_series_equal(result, expected)

result = values.str.contains(pat, regex=False)
expected = Series(np.array([False, np.nan, False, False, True], dtype=np.object_))
expected = Series(
np.array([False, np.nan, False, False, True], dtype=np.object_),
dtype=expected_dtype,
)
tm.assert_series_equal(result, expected)

values = Series(np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=object))
values = Series(
np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=object),
dtype=any_string_dtype,
)
result = values.str.contains(pat)
expected = Series(np.array([False, False, True, True]))
assert result.dtype == np.bool_
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
tm.assert_series_equal(result, expected)

# case insensitive using regex
values = Series(np.array(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dtype=object))
values = Series(
np.array(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dtype=object),
dtype=any_string_dtype,
)
result = values.str.contains("FOO|mmm", case=False)
expected = Series(np.array([True, False, True, True]))
expected = Series(np.array([True, False, True, True]), dtype=expected_dtype)
tm.assert_series_equal(result, expected)

# case insensitive without regex
result = Series(values).str.contains("foo", regex=False, case=False)
expected = Series(np.array([True, False, True, False]))
result = values.str.contains("foo", regex=False, case=False)
expected = Series(np.array([True, False, True, False]), dtype=expected_dtype)
tm.assert_series_equal(result, expected)

# mixed
# unicode
values = Series(
np.array(["foo", np.nan, "fooommm__foo", "mmm_"], dtype=np.object_),
dtype=any_string_dtype,
)
pat = "mmm[_]+"

result = values.str.contains(pat)
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
expected = Series(
np.array([False, np.nan, True, True], dtype=np.object_), dtype=expected_dtype
)
tm.assert_series_equal(result, expected)

result = values.str.contains(pat, na=False)
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
tm.assert_series_equal(result, expected)

values = Series(
np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=np.object_),
dtype=any_string_dtype,
)
result = values.str.contains(pat)
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
tm.assert_series_equal(result, expected)


def test_contains_object_mixed():
mixed = Series(
np.array(
["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0],
dtype=object,
)
)
rs = mixed.str.contains("o")
xp = Series(
result = mixed.str.contains("o")
expected = Series(
np.array(
[False, np.nan, False, np.nan, np.nan, True, np.nan, np.nan, np.nan],
dtype=np.object_,
)
)
tm.assert_series_equal(rs, xp)

rs = mixed.str.contains("o")
xp = Series([False, np.nan, False, np.nan, np.nan, True, np.nan, np.nan, np.nan])
assert isinstance(rs, Series)
tm.assert_series_equal(rs, xp)

# unicode
values = Series(np.array(["foo", np.nan, "fooommm__foo", "mmm_"], dtype=np.object_))
pat = "mmm[_]+"

result = values.str.contains(pat)
expected = Series(np.array([False, np.nan, True, True], dtype=np.object_))
tm.assert_series_equal(result, expected)

result = values.str.contains(pat, na=False)
expected = Series(np.array([False, False, True, True]))
tm.assert_series_equal(result, expected)

values = Series(np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=np.object_))
result = values.str.contains(pat)
expected = Series(np.array([False, False, True, True]))
assert result.dtype == np.bool_
tm.assert_series_equal(result, expected)


def test_contains_for_object_category():
def test_contains_na_kwarg_for_object_category():
# gh 22158

# na for category
Expand All @@ -108,6 +149,29 @@ def test_contains_for_object_category():
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(
    "na, expected",
    [
        (None, pd.NA),
        (True, True),
        (False, False),
        (0, False),
        (3, True),
        (np.nan, pd.NA),
    ],
)
@pytest.mark.parametrize("regex", [True, False])
def test_contains_na_kwarg_for_nullable_string_dtype(
    nullable_string_dtype, na, expected, regex
):
    """
    Check how ``str.contains`` fills the missing entry for nullable dtypes.

    The last element of the input is missing. Per the parametrization, a
    missing ``na`` (None / np.nan) leaves it as ``pd.NA``, while any other
    ``na`` shows up as its truthiness (e.g. 0 -> False, 3 -> True). The
    result is always expected to have "boolean" dtype, for both ``regex``
    settings.
    """
    # https://github.com/pandas-dev/pandas/pull/41025#issuecomment-824062416

    ser = Series(["a", "b", "c", "a", np.nan], dtype=nullable_string_dtype)
    actual = ser.str.contains("a", na=na, regex=regex)
    # Only the trailing (missing) slot varies with the ``na`` argument.
    wanted = Series([True, False, False, True, expected], dtype="boolean")
    tm.assert_series_equal(actual, wanted)


@pytest.mark.parametrize("dtype", [None, "category"])
@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA])
@pytest.mark.parametrize("na", [True, False])
Expand Down Expand Up @@ -508,59 +572,73 @@ def _check(result, expected):
tm.assert_series_equal(result, expected)


def test_contains_moar():
def test_contains_moar(any_string_dtype):
# PR #1179
s = Series(["A", "B", "C", "Aaba", "Baca", "", np.nan, "CABA", "dog", "cat"])
s = Series(
["A", "B", "C", "Aaba", "Baca", "", np.nan, "CABA", "dog", "cat"],
dtype=any_string_dtype,
)

result = s.str.contains("a")
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
expected = Series(
[False, False, False, True, True, False, np.nan, False, False, True]
[False, False, False, True, True, False, np.nan, False, False, True],
dtype=expected_dtype,
)
tm.assert_series_equal(result, expected)

result = s.str.contains("a", case=False)
expected = Series(
[True, False, False, True, True, False, np.nan, True, False, True]
[True, False, False, True, True, False, np.nan, True, False, True],
dtype=expected_dtype,
)
tm.assert_series_equal(result, expected)

result = s.str.contains("Aa")
expected = Series(
[False, False, False, True, False, False, np.nan, False, False, False]
[False, False, False, True, False, False, np.nan, False, False, False],
dtype=expected_dtype,
)
tm.assert_series_equal(result, expected)

result = s.str.contains("ba")
expected = Series(
[False, False, False, True, False, False, np.nan, False, False, False]
[False, False, False, True, False, False, np.nan, False, False, False],
dtype=expected_dtype,
)
tm.assert_series_equal(result, expected)

result = s.str.contains("ba", case=False)
expected = Series(
[False, False, False, True, True, False, np.nan, True, False, False]
[False, False, False, True, True, False, np.nan, True, False, False],
dtype=expected_dtype,
)
tm.assert_series_equal(result, expected)


def test_contains_nan():
def test_contains_nan(any_string_dtype):
# PR #14171
s = Series([np.nan, np.nan, np.nan], dtype=np.object_)
s = Series([np.nan, np.nan, np.nan], dtype=any_string_dtype)

result = s.str.contains("foo", na=False)
expected = Series([False, False, False], dtype=np.bool_)
expected_dtype = np.bool_ if any_string_dtype == "object" else "boolean"
expected = Series([False, False, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

result = s.str.contains("foo", na=True)
expected = Series([True, True, True], dtype=np.bool_)
expected = Series([True, True, True], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

result = s.str.contains("foo", na="foo")
expected = Series(["foo", "foo", "foo"], dtype=np.object_)
if any_string_dtype == "object":
expected = Series(["foo", "foo", "foo"], dtype=np.object_)
else:
expected = Series([True, True, True], dtype="boolean")
tm.assert_series_equal(result, expected)

result = s.str.contains("foo")
expected = Series([np.nan, np.nan, np.nan], dtype=np.object_)
expected_dtype = "object" if any_string_dtype == "object" else "boolean"
expected = Series([np.nan, np.nan, np.nan], dtype=expected_dtype)
tm.assert_series_equal(result, expected)


Expand Down Expand Up @@ -609,14 +687,14 @@ def test_replace_moar():
tm.assert_series_equal(result, expected)


def test_match_findall_flags():
def test_flags_kwarg(any_string_dtype):
data = {
"Dave": "dave@google.com",
"Steve": "steve@gmail.com",
"Rob": "rob@gmail.com",
"Wes": np.nan,
}
data = Series(data)
data = Series(data, dtype=any_string_dtype)

pat = r"([A-Z0-9._%+-]+)@([A-Z0-9.-]+)\.([A-Z]{2,4})"

Expand Down

0 comments on commit 0fba740

Please sign in to comment.