Skip to content

Commit 30021ac

Browse files
[ArrayManager] GroupBy cython aggregations (no fallback) (pandas-dev#39885)
1 parent 2724350 commit 30021ac

File tree

7 files changed

+86
-37
lines changed

7 files changed

+86
-37
lines changed

.github/workflows/ci.yml

+1
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ jobs:
157157
pytest pandas/tests/reductions/ --array-manager
158158
pytest pandas/tests/generic/test_generic.py --array-manager
159159
pytest pandas/tests/arithmetic/ --array-manager
160+
pytest pandas/tests/groupby/aggregate/ --array-manager
160161
pytest pandas/tests/reshape/merge --array-manager
161162
162163
# indexing subset (temporary since other tests don't pass yet)

pandas/core/groupby/generic.py

+25-15
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
ArrayLike,
4242
FrameOrSeries,
4343
FrameOrSeriesUnion,
44+
Manager,
4445
)
4546
from pandas.util._decorators import (
4647
Appender,
@@ -107,7 +108,10 @@
107108
all_indexes_same,
108109
)
109110
import pandas.core.indexes.base as ibase
110-
from pandas.core.internals import BlockManager
111+
from pandas.core.internals import (
112+
ArrayManager,
113+
BlockManager,
114+
)
111115
from pandas.core.series import Series
112116
from pandas.core.util.numba_ import maybe_use_numba
113117

@@ -1074,20 +1078,22 @@ def _iterate_slices(self) -> Iterable[Series]:
10741078
def _cython_agg_general(
10751079
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
10761080
) -> DataFrame:
1077-
agg_mgr = self._cython_agg_blocks(
1081+
agg_mgr = self._cython_agg_manager(
10781082
how, alt=alt, numeric_only=numeric_only, min_count=min_count
10791083
)
10801084
return self._wrap_agged_manager(agg_mgr)
10811085

1082-
def _cython_agg_blocks(
1086+
def _cython_agg_manager(
10831087
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
1084-
) -> BlockManager:
1088+
) -> Manager:
10851089

1086-
data: BlockManager = self._get_data_to_aggregate()
1090+
data: Manager = self._get_data_to_aggregate()
10871091

10881092
if numeric_only:
10891093
data = data.get_numeric_data(copy=False)
10901094

1095+
using_array_manager = isinstance(data, ArrayManager)
1096+
10911097
def cast_agg_result(result, values: ArrayLike, how: str) -> ArrayLike:
10921098
# see if we can cast the values to the desired dtype
10931099
# this may not be the original dtype
@@ -1101,7 +1107,11 @@ def cast_agg_result(result, values: ArrayLike, how: str) -> ArrayLike:
11011107
result = type(values)._from_sequence(result.ravel(), dtype=values.dtype)
11021108
# Note this will have result.dtype == dtype from above
11031109

1104-
elif isinstance(result, np.ndarray) and result.ndim == 1:
1110+
elif (
1111+
not using_array_manager
1112+
and isinstance(result, np.ndarray)
1113+
and result.ndim == 1
1114+
):
11051115
# We went through a SeriesGroupByPath and need to reshape
11061116
# GH#32223 includes case with IntegerArray values
11071117
result = result.reshape(1, -1)
@@ -1153,11 +1163,11 @@ def py_fallback(bvalues: ArrayLike) -> ArrayLike:
11531163
result = mgr.blocks[0].values
11541164
return result
11551165

1156-
def blk_func(bvalues: ArrayLike) -> ArrayLike:
1166+
def array_func(values: ArrayLike) -> ArrayLike:
11571167

11581168
try:
11591169
result = self.grouper._cython_operation(
1160-
"aggregate", bvalues, how, axis=1, min_count=min_count
1170+
"aggregate", values, how, axis=1, min_count=min_count
11611171
)
11621172
except NotImplementedError:
11631173
# generally if we have numeric_only=False
@@ -1170,14 +1180,14 @@ def blk_func(bvalues: ArrayLike) -> ArrayLike:
11701180
assert how == "ohlc"
11711181
raise
11721182

1173-
result = py_fallback(bvalues)
1183+
result = py_fallback(values)
11741184

1175-
return cast_agg_result(result, bvalues, how)
1185+
return cast_agg_result(result, values, how)
11761186

11771187
# TypeError -> we may have an exception in trying to aggregate
11781188
# continue and exclude the block
11791189
# NotImplementedError -> "ohlc" with wrong dtype
1180-
new_mgr = data.grouped_reduce(blk_func, ignore_failures=True)
1190+
new_mgr = data.grouped_reduce(array_func, ignore_failures=True)
11811191

11821192
if not len(new_mgr):
11831193
raise DataError("No numeric types to aggregate")
@@ -1670,7 +1680,7 @@ def _wrap_frame_output(self, result, obj: DataFrame) -> DataFrame:
16701680
else:
16711681
return self.obj._constructor(result, index=obj.index, columns=result_index)
16721682

1673-
def _get_data_to_aggregate(self) -> BlockManager:
1683+
def _get_data_to_aggregate(self) -> Manager:
16741684
obj = self._obj_with_exclusions
16751685
if self.axis == 1:
16761686
return obj.T._mgr
@@ -1755,17 +1765,17 @@ def _wrap_transformed_output(
17551765

17561766
return result
17571767

1758-
def _wrap_agged_manager(self, mgr: BlockManager) -> DataFrame:
1768+
def _wrap_agged_manager(self, mgr: Manager) -> DataFrame:
17591769
if not self.as_index:
17601770
index = np.arange(mgr.shape[1])
1761-
mgr.axes[1] = ibase.Index(index)
1771+
mgr.set_axis(1, ibase.Index(index), verify_integrity=False)
17621772
result = self.obj._constructor(mgr)
17631773

17641774
self._insert_inaxis_grouper_inplace(result)
17651775
result = result._consolidate()
17661776
else:
17671777
index = self.grouper.result_index
1768-
mgr.axes[1] = index
1778+
mgr.set_axis(1, index, verify_integrity=False)
17691779
result = self.obj._constructor(mgr)
17701780

17711781
if self.axis == 1:

pandas/core/internals/array_manager.py

+36-10
Original file line numberDiff line numberDiff line change
@@ -150,18 +150,20 @@ def _normalize_axis(axis):
150150
axis = 1 if axis == 0 else 0
151151
return axis
152152

153-
# TODO can be shared
154-
def set_axis(self, axis: int, new_labels: Index) -> None:
153+
def set_axis(
154+
self, axis: int, new_labels: Index, verify_integrity: bool = True
155+
) -> None:
155156
# Caller is responsible for ensuring we have an Index object.
156157
axis = self._normalize_axis(axis)
157-
old_len = len(self._axes[axis])
158-
new_len = len(new_labels)
158+
if verify_integrity:
159+
old_len = len(self._axes[axis])
160+
new_len = len(new_labels)
159161

160-
if new_len != old_len:
161-
raise ValueError(
162-
f"Length mismatch: Expected axis has {old_len} elements, new "
163-
f"values have {new_len} elements"
164-
)
162+
if new_len != old_len:
163+
raise ValueError(
164+
f"Length mismatch: Expected axis has {old_len} elements, new "
165+
f"values have {new_len} elements"
166+
)
165167

166168
self._axes[axis] = new_labels
167169

@@ -254,6 +256,30 @@ def reduce(
254256
new_mgr = type(self)(result_arrays, [index, columns])
255257
return new_mgr, indexer
256258

259+
def grouped_reduce(self: T, func: Callable, ignore_failures: bool = False) -> T:
260+
"""
261+
Apply grouped reduction function columnwise, returning a new ArrayManager.
262+
263+
Parameters
264+
----------
265+
func : grouped reduction function
266+
ignore_failures : bool, default False
267+
Whether to drop columns where func raises TypeError.
268+
269+
Returns
270+
-------
271+
ArrayManager
272+
"""
273+
# TODO ignore_failures
274+
result_arrays = [func(arr) for arr in self.arrays]
275+
276+
if len(result_arrays) == 0:
277+
index = Index([None]) # placeholder
278+
else:
279+
index = Index(range(result_arrays[0].shape[0]))
280+
281+
return type(self)(result_arrays, [index, self.items])
282+
257283
def operate_blockwise(self, other: ArrayManager, array_op) -> ArrayManager:
258284
"""
259285
Apply array_op blockwise with another (aligned) BlockManager.
@@ -369,7 +395,7 @@ def apply_with_block(self: T, f, align_keys=None, **kwargs) -> T:
369395
if hasattr(arr, "tz") and arr.tz is None: # type: ignore[union-attr]
370396
# DatetimeArray needs to be converted to ndarray for DatetimeBlock
371397
arr = arr._data # type: ignore[union-attr]
372-
elif arr.dtype.kind == "m":
398+
elif arr.dtype.kind == "m" and not isinstance(arr, np.ndarray):
373399
# TimedeltaArray needs to be converted to ndarray for TimedeltaBlock
374400
arr = arr._data # type: ignore[union-attr]
375401
if isinstance(arr, np.ndarray):

pandas/core/internals/managers.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -234,16 +234,19 @@ def shape(self) -> Shape:
234234
def ndim(self) -> int:
235235
return len(self.axes)
236236

237-
def set_axis(self, axis: int, new_labels: Index) -> None:
237+
def set_axis(
238+
self, axis: int, new_labels: Index, verify_integrity: bool = True
239+
) -> None:
238240
# Caller is responsible for ensuring we have an Index object.
239-
old_len = len(self.axes[axis])
240-
new_len = len(new_labels)
241+
if verify_integrity:
242+
old_len = len(self.axes[axis])
243+
new_len = len(new_labels)
241244

242-
if new_len != old_len:
243-
raise ValueError(
244-
f"Length mismatch: Expected axis has {old_len} elements, new "
245-
f"values have {new_len} elements"
246-
)
245+
if new_len != old_len:
246+
raise ValueError(
247+
f"Length mismatch: Expected axis has {old_len} elements, new "
248+
f"values have {new_len} elements"
249+
)
247250

248251
self.axes[axis] = new_labels
249252

@@ -282,16 +285,15 @@ def get_dtypes(self):
282285
return algos.take_nd(dtypes, self.blknos, allow_fill=False)
283286

284287
@property
285-
def arrays(self):
288+
def arrays(self) -> List[ArrayLike]:
286289
"""
287290
Quick access to the backing arrays of the Blocks.
288291
289292
Only for compatibility with ArrayManager for testing convenience.
290293
Not to be used in actual code, and return value is not the same as the
291294
ArrayManager method (list of 1D arrays vs iterator of 2D ndarrays / 1D EAs).
292295
"""
293-
for blk in self.blocks:
294-
yield blk.values
296+
return [blk.values for blk in self.blocks]
295297

296298
def __getstate__(self):
297299
block_values = [b.values for b in self.blocks]

pandas/tests/groupby/aggregate/test_aggregate.py

+7
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111

1212
from pandas.errors import PerformanceWarning
13+
import pandas.util._test_decorators as td
1314

1415
from pandas.core.dtypes.common import is_integer_dtype
1516

@@ -45,6 +46,7 @@ def test_agg_regression1(tsframe):
4546
tm.assert_frame_equal(result, expected)
4647

4748

49+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) quantile/describe
4850
def test_agg_must_agg(df):
4951
grouped = df.groupby("A")["C"]
5052

@@ -134,6 +136,7 @@ def test_groupby_aggregation_multi_level_column():
134136
tm.assert_frame_equal(result, expected)
135137

136138

139+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) non-cython agg
137140
def test_agg_apply_corner(ts, tsframe):
138141
# nothing to group, all NA
139142
grouped = ts.groupby(ts * np.nan)
@@ -212,6 +215,7 @@ def test_aggregate_str_func(tsframe, groupbyfunc):
212215
tm.assert_frame_equal(result, expected)
213216

214217

218+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) non-cython agg
215219
def test_agg_str_with_kwarg_axis_1_raises(df, reduction_func):
216220
gb = df.groupby(level=0)
217221
if reduction_func in ("idxmax", "idxmin"):
@@ -491,6 +495,7 @@ def test_agg_index_has_complex_internals(index):
491495
tm.assert_frame_equal(result, expected)
492496

493497

498+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) agg py_fallback
494499
def test_agg_split_block():
495500
# https://github.com/pandas-dev/pandas/issues/31522
496501
df = DataFrame(
@@ -508,6 +513,7 @@ def test_agg_split_block():
508513
tm.assert_frame_equal(result, expected)
509514

510515

516+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) agg py_fallback
511517
def test_agg_split_object_part_datetime():
512518
# https://github.com/pandas-dev/pandas/pull/31616
513519
df = DataFrame(
@@ -1199,6 +1205,7 @@ def test_aggregate_datetime_objects():
11991205
tm.assert_series_equal(result, expected)
12001206

12011207

1208+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) agg py_fallback
12021209
def test_aggregate_numeric_object_dtype():
12031210
# https://github.com/pandas-dev/pandas/issues/39329
12041211
# simplified case: multiple object columns where one is all-NaN

pandas/tests/groupby/aggregate/test_cython.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def test_read_only_buffer_source_agg(agg):
281281
"species": ["setosa", "setosa", "setosa", "setosa", "setosa"],
282282
}
283283
)
284-
df._mgr.blocks[0].values.flags.writeable = False
284+
df._mgr.arrays[0].flags.writeable = False
285285

286286
result = df.groupby(["species"]).agg({"sepal_length": agg})
287287
expected = df.copy().groupby(["species"]).agg({"sepal_length": agg})

pandas/tests/groupby/aggregate/test_other.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import numpy as np
99
import pytest
1010

11+
import pandas.util._test_decorators as td
12+
1113
import pandas as pd
1214
from pandas import (
1315
DataFrame,
@@ -412,6 +414,7 @@ def __call__(self, x):
412414
tm.assert_frame_equal(result, expected)
413415

414416

417+
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) columns with ndarrays
415418
def test_agg_over_numpy_arrays():
416419
# GH 3788
417420
df = DataFrame(

0 commit comments

Comments (0)