1from __future__ import annotations
2
3from typing import (
4 TYPE_CHECKING,
5 NamedTuple,
6)
7
8from pandas.core.dtypes.common import is_1d_only_ea_dtype
9
10if TYPE_CHECKING:
11 from collections.abc import Iterator
12
13 from pandas._libs.internals import BlockPlacement
14 from pandas._typing import ArrayLike
15
16 from pandas.core.internals.blocks import Block
17 from pandas.core.internals.managers import BlockManager
18
19
class BlockPairInfo(NamedTuple):
    """
    One aligned (left block, right block) pair produced by ``_iter_block_pairs``.

    ``lvals`` and ``rvals`` have already been brought to matching shapes so
    an array operation can be applied to them directly.
    """

    # Values from the left block, sliced/squeezed to line up with ``rvals``.
    lvals: ArrayLike
    # Values from the right block, squeezed to 1D when paired with 1D ``lvals``.
    rvals: ArrayLike
    # ``mgr_locs`` of the left block this pair was taken from.
    locs: BlockPlacement
    # True when the left block's values are 1D (ndim == 1), presumably an EA.
    left_ea: bool
    # True when the right block's values are 1D (ndim == 1), presumably an EA.
    right_ea: bool
    # The right-hand block itself; used to wrap the op result into new blocks.
    rblk: Block
27
28
def _iter_block_pairs(
    left: BlockManager, right: BlockManager
) -> Iterator[BlockPairInfo]:
    """
    Yield shape-aligned value pairs for every left block and each matching
    slice of the right manager.

    The caller is responsible for having already verified that the parent
    DataFrames are indexed identically (``rframe._indexed_same(lframe)``).
    """
    for lblk in left.blocks:
        placement = lblk.mgr_locs
        # A 1D values array on the left marks an extension-array block.
        is_left_ea = lblk.values.ndim == 1

        # Invariants (not checked, for performance): when is_left_ea, the
        # placement has length 1 and the right side produces exactly one
        # block whose first dimension is 1.
        for rblk in right._slice_take_blocks_ax0(placement.indexer, only_slice=True):
            is_right_ea = rblk.values.ndim == 1
            lvals, rvals = _get_same_shape_values(lblk, rblk, is_left_ea, is_right_ea)
            yield BlockPairInfo(lvals, rvals, placement, is_left_ea, is_right_ea, rblk)
55
56
def operate_blockwise(
    left: BlockManager, right: BlockManager, array_op
) -> BlockManager:
    """
    Apply ``array_op`` block by block and assemble the results into a new
    manager of the same type as ``right``, sharing ``right.axes``.

    The caller is responsible for having already verified that the parent
    DataFrames are indexed identically (``rframe._indexed_same(lframe)``).
    """
    result_blocks: list[Block] = []
    for pair in _iter_block_pairs(left, right):
        res_values = array_op(pair.lvals, pair.rvals)

        # A 1D left operand against a 2D right block gives a 1D result; give
        # it back the single-row 2D shape, unless its dtype is 1D-only.
        needs_row_shape = (
            pair.left_ea
            and not pair.right_ea
            and hasattr(res_values, "reshape")
            and not is_1d_only_ea_dtype(res_values.dtype)
        )
        if needs_row_shape:
            res_values = res_values.reshape(1, -1)

        new_blocks = pair.rblk._split_op_result(res_values)
        # Invariants (not checked, for performance): a single new block when
        # either side is an EA; otherwise res_values.shape == lvals.shape.

        # Translate placements from rblk-relative to the original frame.
        _reset_block_mgr_locs(new_blocks, pair.locs)
        result_blocks.extend(new_blocks)

    # Invariant (not checked, for performance): the union of all result
    # mgr_locs is exactly range(len(left.items)) with no duplicates.
    return type(right)(tuple(result_blocks), axes=right.axes, verify_integrity=False)
94
95
def _reset_block_mgr_locs(nbs: list[Block], locs) -> None:
    """
    Rewrite each block's ``mgr_locs`` so it refers to positions in the
    original DataFrame.

    The current ``mgr_locs.indexer`` of every block is used to index into
    ``locs``; the result becomes the block's new ``mgr_locs``. Blocks are
    modified in place and nothing is returned.
    """
    for new_block in nbs:
        new_block.mgr_locs = locs[new_block.mgr_locs.indexer]
        # Invariants (not checked, for performance): the new placement has
        # length new_block.shape[0] and all its entries lie within locs.
106
107
def _get_same_shape_values(
    lblk: Block, rblk: Block, left_ea: bool, right_ea: bool
) -> tuple[ArrayLike, ArrayLike]:
    """
    Return ``(lvals, rvals)`` from the two blocks with matching shapes.

    A 2D left side is restricted to the rows given by ``rblk.mgr_locs``.
    When exactly one side is 1D, the 2D side (which must have a single row)
    is squeezed down to 1D.

    Raises
    ------
    AssertionError
        If ``rblk.mgr_locs`` is not slice-like or the shapes do not line up.
    """
    left_values = lblk.values
    right_values = rblk.values
    placement = rblk.mgr_locs

    # Indexing into the left values must be slice-like.
    assert placement.is_slice_like, placement

    # TODO(EA2D): with 2D EAs only the "neither side is 1D" case would be needed
    if left_ea and right_ea:
        assert left_values.shape == right_values.shape, (
            left_values.shape,
            right_values.shape,
        )
    elif left_ea:
        # 1D left, 2D right: drop the right side's single leading row axis.
        assert right_values.shape[0] == 1, right_values.shape
        # error: No overload variant of "__getitem__" of "ExtensionArray" matches
        # argument type "Tuple[int, slice]"
        right_values = right_values[0, :]  # type: ignore[call-overload]
    else:
        # Left is 2D: keep only the rows corresponding to rblk.
        # error: No overload variant of "__getitem__" of "ExtensionArray" matches
        # argument type "Tuple[Union[ndarray, slice], slice]"
        left_values = left_values[placement.indexer, :]  # type: ignore[call-overload]
        if right_ea:
            # 2D left, 1D right: the selected slab must be one row; squeeze it.
            assert left_values.shape[0] == 1, left_values.shape
            left_values = left_values[0, :]
        else:
            assert left_values.shape == right_values.shape, (
                left_values.shape,
                right_values.shape,
            )

    return left_values, right_values
144
145
def blockwise_all(left: BlockManager, right: BlockManager, op) -> bool:
    """
    Blockwise `all` reduction.

    Returns True only if ``op(lvals, rvals)`` is truthy for every aligned
    block pair; stops evaluating at the first falsy result.
    """
    return all(
        op(pair.lvals, pair.rvals) for pair in _iter_block_pairs(left, right)
    )