
Commit b1363b4

REF: extract_array earlier in block construction (pandas-dev#40026)
1 parent: 36ff425 · commit: b1363b4

File tree

pandas/_testing/asserters.py
pandas/core/internals/blocks.py
pandas/core/internals/managers.py
pandas/tests/extension/test_numpy.py
pandas/tests/internals/test_internals.py

5 files changed: +43 −22 lines changed


pandas/_testing/asserters.py

+3 −2
@@ -18,6 +18,7 @@
     is_numeric_dtype,
     needs_i8_conversion,
 )
+from pandas.core.dtypes.dtypes import PandasDtype
 from pandas.core.dtypes.missing import array_equivalent

 import pandas as pd
@@ -630,12 +631,12 @@ def raise_assert_detail(obj, message, left, right, diff=None, index_values=None)

     if isinstance(left, np.ndarray):
         left = pprint_thing(left)
-    elif is_categorical_dtype(left):
+    elif is_categorical_dtype(left) or isinstance(left, PandasDtype):
         left = repr(left)

     if isinstance(right, np.ndarray):
         right = pprint_thing(right)
-    elif is_categorical_dtype(right):
+    elif is_categorical_dtype(right) or isinstance(right, PandasDtype):
         right = repr(right)

     msg += f"""
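
With this change, raise_assert_detail renders a PandasDtype via repr() the same way it already rendered categorical dtypes. A minimal, simplified sketch of that branch, using a hypothetical _format_side helper in place of the inline logic (imports follow the ones used in this file; not pandas code itself):

    import numpy as np
    from pandas.core.dtypes.dtypes import PandasDtype
    from pandas.io.formats.printing import pprint_thing

    def _format_side(value):
        # hypothetical helper mirroring the amended branch in raise_assert_detail
        if isinstance(value, np.ndarray):
            return pprint_thing(value)
        elif isinstance(value, PandasDtype):
            return repr(value)  # e.g. PandasDtype('int64')
        return value

    print(_format_side(PandasDtype(np.dtype("int64"))))
    print(_format_side(np.array([1, 2, 3])))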

pandas/core/internals/blocks.py

+4 −4
@@ -178,7 +178,7 @@ def _maybe_coerce_values(cls, values):

         Parameters
         ----------
-        values : np.ndarray, ExtensionArray, Index
+        values : np.ndarray or ExtensionArray

         Returns
         -------
@@ -350,7 +350,7 @@ def __getstate__(self):
     @final
     def __setstate__(self, state):
         self.mgr_locs = libinternals.BlockPlacement(state[0])
-        self.values = state[1]
+        self.values = extract_array(state[1], extract_numpy=True)
         self.ndim = self.values.ndim

     def _slice(self, slicer):
@@ -1623,7 +1623,7 @@ def _maybe_coerce_values(cls, values):

         Parameters
         ----------
-        values : Index, Series, ExtensionArray
+        values : np.ndarray or ExtensionArray

         Returns
         -------
@@ -2105,7 +2105,7 @@ def _maybe_coerce_values(cls, values):

         Parameters
         ----------
-        values : array-like
+        values : np.ndarray or ExtensionArray
             Must be convertible to datetime64/timedelta64

         Returns
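
These signatures now assume Block values are already an np.ndarray or ExtensionArray, and __setstate__ unwraps anything else with extract_array. A rough sketch of what extract_array(..., extract_numpy=True) returns for the kinds of objects older pickles might carry (the types in the comments are the expected results, not verified against this exact commit):

    import numpy as np
    import pandas as pd
    from pandas.core.construction import extract_array

    dti = pd.date_range("2021-01-01", periods=3, tz="UTC")
    print(type(extract_array(dti, extract_numpy=True)))  # DatetimeArray, not DatetimeIndex
    ser = pd.Series([1.0, 2.0, 3.0])
    print(type(extract_array(ser, extract_numpy=True)))  # numpy.ndarray
    arr = pd.arrays.PandasArray(np.array([1, 2, 3]))
    print(type(extract_array(arr, extract_numpy=True)))  # numpy.ndarray (PandasArray unwrapped)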

pandas/core/internals/managers.py

+21 −15
@@ -45,7 +45,6 @@
 from pandas.core.dtypes.dtypes import ExtensionDtype
 from pandas.core.dtypes.generic import (
     ABCDataFrame,
-    ABCPandasArray,
     ABCSeries,
 )
 from pandas.core.dtypes.missing import (
@@ -316,6 +315,8 @@ def __getstate__(self):
     def __setstate__(self, state):
         def unpickle_block(values, mgr_locs, ndim: int):
             # TODO(EA2D): ndim would be unnecessary with 2D EAs
+            # older pickles may store e.g. DatetimeIndex instead of DatetimeArray
+            values = extract_array(values, extract_numpy=True)
             return make_block(values, placement=mgr_locs, ndim=ndim)

         if isinstance(state, tuple) and len(state) >= 4 and "0.14.1" in state[3]:
@@ -1212,6 +1213,7 @@ def insert(self, loc: int, item: Hashable, value, allow_duplicates: bool = False
             # TODO(EA2D): special case not needed with 2D EAs
             value = ensure_block_shape(value, ndim=2)

+        # TODO: type value as ArrayLike
         block = make_block(values=value, ndim=self.ndim, placement=slice(loc, loc + 1))

         for blkno, count in _fast_count_smallints(self.blknos[loc:]):
@@ -1673,16 +1675,20 @@ def create_block_manager_from_blocks(blocks, axes: List[Index]) -> BlockManager:
         raise construction_error(tot_items, blocks[0].shape[1:], axes, e)


+# We define this here so we can override it in tests.extension.test_numpy
+def _extract_array(obj):
+    return extract_array(obj, extract_numpy=True)
+
+
 def create_block_manager_from_arrays(
     arrays, names: Index, axes: List[Index]
 ) -> BlockManager:
     assert isinstance(names, Index)
     assert isinstance(axes, list)
     assert all(isinstance(x, Index) for x in axes)

-    # ensure we dont have any PandasArrays when we call get_block_type
-    # Note: just calling extract_array breaks tests that patch PandasArray._typ.
-    arrays = [x if not isinstance(x, ABCPandasArray) else x.to_numpy() for x in arrays]
+    arrays = [_extract_array(x) for x in arrays]
+
     try:
         blocks = _form_blocks(arrays, names, axes)
         mgr = BlockManager(blocks, axes)
@@ -1692,7 +1698,12 @@ def create_block_manager_from_arrays(
         raise construction_error(len(arrays), arrays[0].shape, axes, e)


-def construction_error(tot_items, block_shape, axes, e=None):
+def construction_error(
+    tot_items: int,
+    block_shape: Shape,
+    axes: List[Index],
+    e: Optional[ValueError] = None,
+):
     """ raise a helpful message about our construction """
     passed = tuple(map(int, [tot_items] + list(block_shape)))
     # Correcting the user facing error message during dataframe construction
@@ -1716,7 +1727,9 @@ def construction_error(tot_items, block_shape, axes, e=None):
 # -----------------------------------------------------------------------


-def _form_blocks(arrays, names: Index, axes: List[Index]) -> List[Block]:
+def _form_blocks(
+    arrays: List[ArrayLike], names: Index, axes: List[Index]
+) -> List[Block]:
     # put "leftover" items in float bucket, where else?
     # generalize?
     items_dict: DefaultDict[str, List] = defaultdict(list)
@@ -1836,21 +1849,14 @@ def _multi_blockify(tuples, dtype: Optional[Dtype] = None):

 def _stack_arrays(tuples, dtype: np.dtype):

-    # fml
-    def _asarray_compat(x):
-        if isinstance(x, ABCSeries):
-            return x._values
-        else:
-            return np.asarray(x)
-
     placement, arrays = zip(*tuples)

     first = arrays[0]
     shape = (len(arrays),) + first.shape

     stacked = np.empty(shape, dtype=dtype)
     for i, arr in enumerate(arrays):
-        stacked[i] = _asarray_compat(arr)
+        stacked[i] = arr

     return stacked, placement

@@ -1874,7 +1880,7 @@ def _interleaved_dtype(blocks: Sequence[Block]) -> Optional[DtypeObj]:
     return find_common_type([b.dtype for b in blocks])


-def _consolidate(blocks):
+def _consolidate(blocks: Tuple[Block, ...]) -> List[Block]:
     """
     Merge blocks having same dtype, exclude non-consolidating blocks
     """

pandas/tests/extension/test_numpy.py

+14
@@ -20,13 +20,26 @@
     ExtensionDtype,
     PandasDtype,
 )
+from pandas.core.dtypes.generic import ABCPandasArray

 import pandas as pd
 import pandas._testing as tm
 from pandas.core.arrays.numpy_ import PandasArray
+from pandas.core.internals import managers
 from pandas.tests.extension import base


+def _extract_array_patched(obj):
+    if isinstance(obj, (pd.Index, pd.Series)):
+        obj = obj._values
+    if isinstance(obj, ABCPandasArray):
+        # TODO for reasons unclear, we get here in a couple of tests
+        # with PandasArray._typ *not* patched
+        obj = obj.to_numpy()
+
+    return obj
+
+
 @pytest.fixture(params=["float", "object"])
 def dtype(request):
     return PandasDtype(np.dtype(request.param))
@@ -51,6 +64,7 @@ def allow_in_pandas(monkeypatch):
     """
     with monkeypatch.context() as m:
         m.setattr(PandasArray, "_typ", "extension")
+        m.setattr(managers, "_extract_array", _extract_array_patched)
         yield
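
The allow_in_pandas fixture leans on monkeypatch.context(): both setattr patches are undone as soon as the with-block exits. A stripped-down sketch of that scoping, assuming PandasArray's default _typ is "npy_extension" and using a lambda as a stand-in for _extract_array_patched:

    from pandas.core.arrays.numpy_ import PandasArray
    from pandas.core.internals import managers

    def test_patches_are_scoped(monkeypatch):
        with monkeypatch.context() as m:
            m.setattr(PandasArray, "_typ", "extension")
            m.setattr(managers, "_extract_array", lambda obj: obj)  # illustrative stand-in
            assert PandasArray._typ == "extension"
        # outside the context the original attributes are restored
        assert PandasArray._typ == "npy_extension"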

pandas/tests/internals/test_internals.py

+1 −1
@@ -123,7 +123,7 @@ def create_block(typestr, placement, item_shape=None, num_offset=0):
         assert m is not None, f"incompatible typestr -> {typestr}"
         tz = m.groups()[0]
         assert num_items == 1, "must have only 1 num items for a tz-aware"
-        values = DatetimeIndex(np.arange(N) * 1e9, tz=tz)
+        values = DatetimeIndex(np.arange(N) * 1e9, tz=tz)._data
     elif typestr in ("timedelta", "td", "m8[ns]"):
         values = (mat * 1).astype("m8[ns]")
     elif typestr in ("category",):
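
The test helper now hands the block a DatetimeArray rather than a tz-aware DatetimeIndex, matching the narrowed _maybe_coerce_values signatures above; ._data on a tz-aware DatetimeIndex is its backing DatetimeArray. A quick sketch of that relationship:

    import pandas as pd
    from pandas.core.arrays import DatetimeArray

    dti = pd.date_range("2021-01-01", periods=3, tz="US/Pacific")
    assert isinstance(dti, pd.DatetimeIndex)
    assert isinstance(dti._data, DatetimeArray)  # what create_block now passes to the block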
