from __future__ import annotations

from decimal import Decimal
import operator
import os
from sys import byteorder
from typing import (
    TYPE_CHECKING,
    Callable,
    ContextManager,
    cast,
)
import warnings

import numpy as np

from pandas._config.localization import (
    can_set_locale,
    get_locales,
    set_locale,
)

from pandas.compat import pa_version_under10p1

from pandas.core.dtypes.common import is_string_dtype

import pandas as pd
from pandas import (
    ArrowDtype,
    DataFrame,
    Index,
    MultiIndex,
    RangeIndex,
    Series,
)
from pandas._testing._io import (
    round_trip_localpath,
    round_trip_pathlib,
    round_trip_pickle,
    write_to_compressed,
)
from pandas._testing._warnings import (
    assert_produces_warning,
    maybe_produces_warning,
)
from pandas._testing.asserters import (
    assert_almost_equal,
    assert_attr_equal,
    assert_categorical_equal,
    assert_class_equal,
    assert_contains_all,
    assert_copy,
    assert_datetime_array_equal,
    assert_dict_equal,
    assert_equal,
    assert_extension_array_equal,
    assert_frame_equal,
    assert_index_equal,
    assert_indexing_slices_equivalent,
    assert_interval_array_equal,
    assert_is_sorted,
    assert_is_valid_plot_return_object,
    assert_metadata_equivalent,
    assert_numpy_array_equal,
    assert_period_array_equal,
    assert_series_equal,
    assert_sp_array_equal,
    assert_timedelta_array_equal,
    raise_assert_detail,
)
from pandas._testing.compat import (
    get_dtype,
    get_obj,
)
from pandas._testing.contexts import (
    assert_cow_warning,
    decompress_file,
    ensure_clean,
    raises_chained_assignment_error,
    set_timezone,
    use_numexpr,
    with_csv_dialect,
)
from pandas.core.arrays import (
    BaseMaskedArray,
    ExtensionArray,
    NumpyExtensionArray,
)
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.construction import extract_array

if TYPE_CHECKING:
    from pandas._typing import (
        Dtype,
        NpDtype,
    )

    from pandas.core.arrays import ArrowExtensionArray

UNSIGNED_INT_NUMPY_DTYPES: list[NpDtype] = ["uint8", "uint16", "uint32", "uint64"]
UNSIGNED_INT_EA_DTYPES: list[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"]
SIGNED_INT_NUMPY_DTYPES: list[NpDtype] = [int, "int8", "int16", "int32", "int64"]
SIGNED_INT_EA_DTYPES: list[Dtype] = ["Int8", "Int16", "Int32", "Int64"]
ALL_INT_NUMPY_DTYPES = UNSIGNED_INT_NUMPY_DTYPES + SIGNED_INT_NUMPY_DTYPES
ALL_INT_EA_DTYPES = UNSIGNED_INT_EA_DTYPES + SIGNED_INT_EA_DTYPES
ALL_INT_DTYPES: list[Dtype] = [*ALL_INT_NUMPY_DTYPES, *ALL_INT_EA_DTYPES]

FLOAT_NUMPY_DTYPES: list[NpDtype] = [float, "float32", "float64"]
FLOAT_EA_DTYPES: list[Dtype] = ["Float32", "Float64"]
ALL_FLOAT_DTYPES: list[Dtype] = [*FLOAT_NUMPY_DTYPES, *FLOAT_EA_DTYPES]

COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
STRING_DTYPES: list[Dtype] = [str, "str", "U"]
COMPLEX_FLOAT_DTYPES: list[Dtype] = [*COMPLEX_DTYPES, *FLOAT_NUMPY_DTYPES]

DATETIME64_DTYPES: list[Dtype] = ["datetime64[ns]", "M8[ns]"]
TIMEDELTA64_DTYPES: list[Dtype] = ["timedelta64[ns]", "m8[ns]"]

BOOL_DTYPES: list[Dtype] = [bool, "bool"]
BYTES_DTYPES: list[Dtype] = [bytes, "bytes"]
OBJECT_DTYPES: list[Dtype] = [object, "object"]

ALL_REAL_NUMPY_DTYPES = FLOAT_NUMPY_DTYPES + ALL_INT_NUMPY_DTYPES
ALL_REAL_EXTENSION_DTYPES = FLOAT_EA_DTYPES + ALL_INT_EA_DTYPES
ALL_REAL_DTYPES: list[Dtype] = [*ALL_REAL_NUMPY_DTYPES, *ALL_REAL_EXTENSION_DTYPES]
ALL_NUMERIC_DTYPES: list[Dtype] = [*ALL_REAL_DTYPES, *COMPLEX_DTYPES]

ALL_NUMPY_DTYPES = (
    ALL_REAL_NUMPY_DTYPES
    + COMPLEX_DTYPES
    + STRING_DTYPES
    + DATETIME64_DTYPES
    + TIMEDELTA64_DTYPES
    + BOOL_DTYPES
    + OBJECT_DTYPES
    + BYTES_DTYPES
)
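
# Illustrative sketch (not part of this module's API): these dtype lists are
# typically consumed by pytest fixtures/parametrizations in the test suite,
# e.g. (test name and assertion are hypothetical):
#
#     import pytest
#
#     @pytest.mark.parametrize("dtype", ALL_INT_NUMPY_DTYPES)
#     def test_construction(dtype):
#         assert Series([1, 2, 3], dtype=dtype).dtype == dtype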

NARROW_NP_DTYPES = [
    np.float16,
    np.float32,
    np.int8,
    np.int16,
    np.int32,
    np.uint8,
    np.uint16,
    np.uint32,
]

PYTHON_DATA_TYPES = [
    str,
    int,
    float,
    complex,
    list,
    tuple,
    range,
    dict,
    set,
    frozenset,
    bool,
    bytes,
    bytearray,
    memoryview,
]

ENDIAN = {"little": "<", "big": ">"}[byteorder]

NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")]
NP_NAT_OBJECTS = [
    cls("NaT", unit)
    for cls in [np.datetime64, np.timedelta64]
    for unit in [
        "Y",
        "M",
        "W",
        "D",
        "h",
        "m",
        "s",
        "ms",
        "us",
        "ns",
        "ps",
        "fs",
        "as",
    ]
]
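
# Illustrative note (assumed typical usage): NULL_OBJECTS and NP_NAT_OBJECTS are
# meant for parametrizing missing-value behavior, e.g.
#
#     NP_NAT_OBJECTS[0]   -> np.datetime64("NaT", "Y")
#     NP_NAT_OBJECTS[-1]  -> np.timedelta64("NaT", "as")
#
# so a test can iterate over every NaT-like object for every supported unit.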

if not pa_version_under10p1:
    import pyarrow as pa

    UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
    SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
    ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES
    ALL_INT_PYARROW_DTYPES_STR_REPR = [
        str(ArrowDtype(typ)) for typ in ALL_INT_PYARROW_DTYPES
    ]

    # pa.float16 doesn't seem supported
    # https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
    FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
    FLOAT_PYARROW_DTYPES_STR_REPR = [
        str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
    ]
    DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)]
    STRING_PYARROW_DTYPES = [pa.string()]
    BINARY_PYARROW_DTYPES = [pa.binary()]

    TIME_PYARROW_DTYPES = [
        pa.time32("s"),
        pa.time32("ms"),
        pa.time64("us"),
        pa.time64("ns"),
    ]
    DATE_PYARROW_DTYPES = [pa.date32(), pa.date64()]
    DATETIME_PYARROW_DTYPES = [
        pa.timestamp(unit=unit, tz=tz)
        for unit in ["s", "ms", "us", "ns"]
        for tz in [None, "UTC", "US/Pacific", "US/Eastern"]
    ]
    TIMEDELTA_PYARROW_DTYPES = [pa.duration(unit) for unit in ["s", "ms", "us", "ns"]]

    BOOL_PYARROW_DTYPES = [pa.bool_()]

    # TODO: Add container like pyarrow types:
    #  https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions
    ALL_PYARROW_DTYPES = (
        ALL_INT_PYARROW_DTYPES
        + FLOAT_PYARROW_DTYPES
        + DECIMAL_PYARROW_DTYPES
        + STRING_PYARROW_DTYPES
        + BINARY_PYARROW_DTYPES
        + TIME_PYARROW_DTYPES
        + DATE_PYARROW_DTYPES
        + DATETIME_PYARROW_DTYPES
        + TIMEDELTA_PYARROW_DTYPES
        + BOOL_PYARROW_DTYPES
    )
    ALL_REAL_PYARROW_DTYPES_STR_REPR = (
        ALL_INT_PYARROW_DTYPES_STR_REPR + FLOAT_PYARROW_DTYPES_STR_REPR
    )
else:
    FLOAT_PYARROW_DTYPES_STR_REPR = []
    ALL_INT_PYARROW_DTYPES_STR_REPR = []
    ALL_PYARROW_DTYPES = []
    ALL_REAL_PYARROW_DTYPES_STR_REPR = []

ALL_REAL_NULLABLE_DTYPES = (
    FLOAT_NUMPY_DTYPES + ALL_REAL_EXTENSION_DTYPES + ALL_REAL_PYARROW_DTYPES_STR_REPR
)

arithmetic_dunder_methods = [
    "__add__",
    "__radd__",
    "__sub__",
    "__rsub__",
    "__mul__",
    "__rmul__",
    "__floordiv__",
    "__rfloordiv__",
    "__truediv__",
    "__rtruediv__",
    "__pow__",
    "__rpow__",
    "__mod__",
    "__rmod__",
]

comparison_dunder_methods = ["__eq__", "__ne__", "__le__", "__lt__", "__ge__", "__gt__"]


# -----------------------------------------------------------------------------
# Comparators


def box_expected(expected, box_cls, transpose: bool = True):
    """
    Helper function to wrap the expected output of a test in a given box_cls.

    Parameters
    ----------
    expected : np.ndarray, Index, Series
    box_cls : {pd.array, Index, Series, DataFrame, np.ndarray, to_array}
    transpose : bool, default True
        Only used when box_cls is DataFrame; tile to a two-row frame.

    Returns
    -------
    expected, wrapped in an instance of box_cls
    """
    if box_cls is pd.array:
        if isinstance(expected, RangeIndex):
            # pd.array would return an IntegerArray
            expected = NumpyExtensionArray(np.asarray(expected._values))
        else:
            expected = pd.array(expected, copy=False)
    elif box_cls is Index:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning)
            expected = Index(expected)
    elif box_cls is Series:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning)
            expected = Series(expected)
    elif box_cls is DataFrame:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning)
            expected = Series(expected).to_frame()
        if transpose:
            # For vector operations, the DataFrame needs to be a single row,
            # not a single column, so that it can operate against non-DataFrame
            # vectors of the same length. Tile to two rows to avoid the
            # single-row special cases in datetime arithmetic.
            expected = expected.T
            expected = pd.concat([expected] * 2, ignore_index=True)
    elif box_cls is np.ndarray or box_cls is np.array:
        expected = np.array(expected)
    elif box_cls is to_array:
        expected = to_array(expected)
    else:
        raise NotImplementedError(box_cls)
    return expected
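
# Illustrative usage sketch (assumed inputs; not executed at import time):
#
#     box_expected(np.array([1, 2, 3]), Index)   -> Index([1, 2, 3])
#     box_expected(np.array([1, 2, 3]), Series)  -> Series([1, 2, 3])
#     box_expected(np.array([1, 2, 3]), DataFrame)
#         -> a 2x3 DataFrame with the values tiled across two rows
#            (because of the transpose/concat logic above).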


def to_array(obj):
    """
    Similar to pd.array, but does not cast numpy dtypes to nullable dtypes.
    """
    # Unlike pd.array, keep numpy-dtyped data as plain ndarrays instead of
    # wrapping it in a nullable extension array.
    dtype = getattr(obj, "dtype", None)

    if dtype is None:
        return np.asarray(obj)

    return extract_array(obj, extract_numpy=True)
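
# Illustrative contrast with pd.array (assumed behavior for a default build):
#
#     pd.array([1, 2, 3])         -> IntegerArray (nullable Int64)
#     to_array([1, 2, 3])         -> np.array([1, 2, 3])
#     to_array(Series([1, 2, 3])) -> the Series' underlying numpy ndarray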


class SubclassedSeries(Series):
    _metadata = ["testattr", "name"]

    @property
    def _constructor(self):
        # For testing, these properties return a generic callable rather than
        # the actual class. Here that is equivalent, but it ensures we don't
        # rely on the property returning a class.
        # See https://github.com/pandas-dev/pandas/pull/46018 and
        # https://github.com/pandas-dev/pandas/issues/32638 and linked issues
        return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs)

    @property
    def _constructor_expanddim(self):
        return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs)


class SubclassedDataFrame(DataFrame):
    _metadata = ["testattr"]

    @property
    def _constructor(self):
        return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs)

    @property
    def _constructor_sliced(self):
        return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs)
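
# Illustrative behavior sketch (hypothetical asserts, not run at import time):
#
#     sdf = SubclassedDataFrame({"a": [1, 2]})
#     assert type(sdf["a"]) is SubclassedSeries          # via _constructor_sliced
#     assert type(SubclassedSeries([1, 2]).to_frame()) is SubclassedDataFrame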


def convert_rows_list_to_csv_str(rows_list: list[str]) -> str:
    """
    Convert a list of CSV rows to a single CSV-formatted string for the current OS.

    This function is used to build the expected value of the to_csv() method.

    Parameters
    ----------
    rows_list : List[str]
        Each element represents one row of the CSV.

    Returns
    -------
    str
        Expected output of to_csv() on the current OS.
    """
    sep = os.linesep
    return sep.join(rows_list) + sep
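
# Illustrative example (result shown for a POSIX system where os.linesep == "\n"):
#
#     convert_rows_list_to_csv_str(["a,b", "1,2"])  ->  "a,b\n1,2\n"
#
# On Windows the separator would be "\r\n" instead.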


def external_error_raised(expected_exception: type[Exception]) -> ContextManager:
    """
    Helper function for pytest.raises when the error message is generated by an
    external (non-pandas) library, so no ``match`` pattern is checked.

    Parameters
    ----------
    expected_exception : type[Exception]
        Expected exception type to be raised.

    Returns
    -------
    ContextManager
        Regular `pytest.raises` context manager with `match` equal to `None`.
    """
    import pytest

    return pytest.raises(expected_exception, match=None)
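
# Illustrative usage sketch (hypothetical test body):
#
#     with external_error_raised(ValueError):
#         raise ValueError("message produced by a third-party library")
#
# The message text is deliberately not asserted because pandas does not control it.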


cython_table = pd.core.common._cython_table.items()


def get_cython_table_params(ndframe, func_names_and_expected):
    """
    Combine a frame with the functions from com._cython_table that match the
    given names, plus the expected results.

    Parameters
    ----------
    ndframe : DataFrame or Series
    func_names_and_expected : Sequence of (func_name, expected) pairs
        func_name is the name of an NDFrame method (e.g. 'sum', 'prod');
        expected is the corresponding expected return value.

    Returns
    -------
    list
        List of three-item tuples (ndframe, function, expected result).
    """
    results = []
    for func_name, expected in func_names_and_expected:
        results.append((ndframe, func_name, expected))
        results += [
            (ndframe, func, expected)
            for func, name in cython_table
            if name == func_name
        ]
    return results
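
# Illustrative expansion sketch (the exact callables depend on the contents of
# pd.core.common._cython_table; np.sum is among those registered as "sum"):
#
#     get_cython_table_params(DataFrame(), [("sum", 0)])
#         -> [(DataFrame(), "sum", 0),
#             (DataFrame(), <each callable mapped to "sum">, 0), ...]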


def get_op_from_name(op_name: str) -> Callable:
    """
    The operator function for a given op name.

    Parameters
    ----------
    op_name : str
        The op name, in the form "add" or "__add__".

    Returns
    -------
    function
        A function performing the operation.
    """
    short_opname = op_name.strip("_")
    try:
        op = getattr(operator, short_opname)
    except AttributeError:
        # Assume it is the reverse operator
        rop = getattr(operator, short_opname[1:])
        op = lambda x, y: rop(y, x)

    return op
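
# Illustrative examples (standard operator semantics assumed):
#
#     get_op_from_name("add")(2, 3)       -> 5
#     get_op_from_name("__sub__")(5, 3)   -> 2
#     get_op_from_name("__rsub__")(5, 3)  -> -2   (reversed operands: 3 - 5)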


# -----------------------------------------------------------------------------
# Indexing test helpers


def getitem(x):
    return x


def setitem(x):
    return x


def loc(x):
    return x.loc


def iloc(x):
    return x.iloc


def at(x):
    return x.at


def iat(x):
    return x.iat
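
# Illustrative use in parametrized indexing tests (hypothetical test sketch):
#
#     ser = Series([10, 20, 30])
#     for indexer in [getitem, loc, iloc]:
#         assert indexer(ser)[0] == 10
#
# i.e. each helper returns the object (or one of its indexing attributes) so the
# same ``[key]`` syntax can be applied uniformly across indexer flavors.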


# -----------------------------------------------------------------------------

_UNITS = ["s", "ms", "us", "ns"]


def get_finest_unit(left: str, right: str):
    """
    Find the finer (higher-resolution) of two datetime64 units.
    """
    if _UNITS.index(left) >= _UNITS.index(right):
        return left
    return right
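
# Illustrative results ("ns" is the finest unit in _UNITS):
#
#     get_finest_unit("s", "ms")  -> "ms"
#     get_finest_unit("ns", "us") -> "ns"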


def shares_memory(left, right) -> bool:
    """
    Pandas-compat for np.shares_memory.
    """
    if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
        return np.shares_memory(left, right)
    elif isinstance(left, np.ndarray):
        # Call with reversed args to get to unpacking logic below.
        return shares_memory(right, left)

    if isinstance(left, RangeIndex):
        return False
    if isinstance(left, MultiIndex):
        return shares_memory(left._codes, right)
    if isinstance(left, (Index, Series)):
        return shares_memory(left._values, right)

    if isinstance(left, NDArrayBackedExtensionArray):
        return shares_memory(left._ndarray, right)
    if isinstance(left, pd.core.arrays.SparseArray):
        return shares_memory(left.sp_values, right)
    if isinstance(left, pd.core.arrays.IntervalArray):
        return shares_memory(left._left, right) or shares_memory(left._right, right)

    if (
        isinstance(left, ExtensionArray)
        and is_string_dtype(left.dtype)
        and left.dtype.storage in ("pyarrow", "pyarrow_numpy")  # type: ignore[attr-defined]
    ):
        # https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
        left = cast("ArrowExtensionArray", left)
        if (
            isinstance(right, ExtensionArray)
            and is_string_dtype(right.dtype)
            and right.dtype.storage in ("pyarrow", "pyarrow_numpy")  # type: ignore[attr-defined]
        ):
            right = cast("ArrowExtensionArray", right)
            left_pa_data = left._pa_array
            right_pa_data = right._pa_array
            left_buf1 = left_pa_data.chunk(0).buffers()[1]
            right_buf1 = right_pa_data.chunk(0).buffers()[1]
            return left_buf1 == right_buf1

    if isinstance(left, BaseMaskedArray) and isinstance(right, BaseMaskedArray):
        # By convention, we'll say these share memory if they share *either*
        # the _data or the _mask
        return np.shares_memory(left._data, right._data) or np.shares_memory(
            left._mask, right._mask
        )

    if isinstance(left, DataFrame) and len(left._mgr.arrays) == 1:
        arr = left._mgr.arrays[0]
        return shares_memory(arr, right)

    raise NotImplementedError(type(left), type(right))
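
# Illustrative behavior sketch (assumed defaults; copies do not share data):
#
#     ser = Series([1, 2, 3])
#     shares_memory(ser, ser._values)  -> True
#     shares_memory(ser, ser.copy())   -> False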


__all__ = [
    "ALL_INT_EA_DTYPES",
    "ALL_INT_NUMPY_DTYPES",
    "ALL_NUMPY_DTYPES",
    "ALL_REAL_NUMPY_DTYPES",
    "assert_almost_equal",
    "assert_attr_equal",
    "assert_categorical_equal",
    "assert_class_equal",
    "assert_contains_all",
    "assert_copy",
    "assert_datetime_array_equal",
    "assert_dict_equal",
    "assert_equal",
    "assert_extension_array_equal",
    "assert_frame_equal",
    "assert_index_equal",
    "assert_indexing_slices_equivalent",
    "assert_interval_array_equal",
    "assert_is_sorted",
    "assert_is_valid_plot_return_object",
    "assert_metadata_equivalent",
    "assert_numpy_array_equal",
    "assert_period_array_equal",
    "assert_produces_warning",
    "assert_series_equal",
    "assert_sp_array_equal",
    "assert_timedelta_array_equal",
    "assert_cow_warning",
    "at",
    "BOOL_DTYPES",
    "box_expected",
    "BYTES_DTYPES",
    "can_set_locale",
    "COMPLEX_DTYPES",
    "convert_rows_list_to_csv_str",
    "DATETIME64_DTYPES",
    "decompress_file",
    "ENDIAN",
    "ensure_clean",
    "external_error_raised",
    "FLOAT_EA_DTYPES",
    "FLOAT_NUMPY_DTYPES",
    "get_cython_table_params",
    "get_dtype",
    "getitem",
    "get_locales",
    "get_finest_unit",
    "get_obj",
    "get_op_from_name",
    "iat",
    "iloc",
    "loc",
    "maybe_produces_warning",
    "NARROW_NP_DTYPES",
    "NP_NAT_OBJECTS",
    "NULL_OBJECTS",
    "OBJECT_DTYPES",
    "raise_assert_detail",
    "raises_chained_assignment_error",
    "round_trip_localpath",
    "round_trip_pathlib",
    "round_trip_pickle",
    "setitem",
    "set_locale",
    "set_timezone",
    "shares_memory",
    "SIGNED_INT_EA_DTYPES",
    "SIGNED_INT_NUMPY_DTYPES",
    "STRING_DTYPES",
    "SubclassedDataFrame",
    "SubclassedSeries",
    "TIMEDELTA64_DTYPES",
    "to_array",
    "UNSIGNED_INT_EA_DTYPES",
    "UNSIGNED_INT_NUMPY_DTYPES",
    "use_numexpr",
    "with_csv_dialect",
    "write_to_compressed",
]