1from __future__ import annotations
2
3import collections
4from datetime import datetime
5from decimal import Decimal
6import operator
7import os
8import re
9import string
10from sys import byteorder
11from typing import (
12 TYPE_CHECKING,
13 Callable,
14 ContextManager,
15 Counter,
16 Iterable,
17 cast,
18)
19
20import numpy as np
21
22from pandas._config.localization import (
23 can_set_locale,
24 get_locales,
25 set_locale,
26)
27
28from pandas._typing import (
29 Dtype,
30 Frequency,
31 NpDtype,
32)
33from pandas.compat import pa_version_under7p0
34
35from pandas.core.dtypes.common import (
36 is_float_dtype,
37 is_integer_dtype,
38 is_sequence,
39 is_signed_integer_dtype,
40 is_unsigned_integer_dtype,
41 pandas_dtype,
42)
43
44import pandas as pd
45from pandas import (
46 ArrowDtype,
47 Categorical,
48 CategoricalIndex,
49 DataFrame,
50 DatetimeIndex,
51 Index,
52 IntervalIndex,
53 MultiIndex,
54 RangeIndex,
55 Series,
56 bdate_range,
57)
58from pandas._testing._io import (
59 close,
60 network,
61 round_trip_localpath,
62 round_trip_pathlib,
63 round_trip_pickle,
64 write_to_compressed,
65)
66from pandas._testing._random import (
67 rands,
68 rands_array,
69)
70from pandas._testing._warnings import (
71 assert_produces_warning,
72 maybe_produces_warning,
73)
74from pandas._testing.asserters import (
75 assert_almost_equal,
76 assert_attr_equal,
77 assert_categorical_equal,
78 assert_class_equal,
79 assert_contains_all,
80 assert_copy,
81 assert_datetime_array_equal,
82 assert_dict_equal,
83 assert_equal,
84 assert_extension_array_equal,
85 assert_frame_equal,
86 assert_index_equal,
87 assert_indexing_slices_equivalent,
88 assert_interval_array_equal,
89 assert_is_sorted,
90 assert_is_valid_plot_return_object,
91 assert_metadata_equivalent,
92 assert_numpy_array_equal,
93 assert_period_array_equal,
94 assert_series_equal,
95 assert_sp_array_equal,
96 assert_timedelta_array_equal,
97 raise_assert_detail,
98)
99from pandas._testing.compat import (
100 get_dtype,
101 get_obj,
102)
103from pandas._testing.contexts import (
104 decompress_file,
105 ensure_clean,
106 ensure_safe_environment_variables,
107 raises_chained_assignment_error,
108 set_timezone,
109 use_numexpr,
110 with_csv_dialect,
111)
112from pandas.core.arrays import (
113 BaseMaskedArray,
114 ExtensionArray,
115 PandasArray,
116)
117from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
118from pandas.core.construction import extract_array
119
120if TYPE_CHECKING:
121 from pandas import (
122 PeriodIndex,
123 TimedeltaIndex,
124 )
125 from pandas.core.arrays import ArrowExtensionArray
126
# Default number of rows (_N) and columns (_K) used by the random
# make*Series / make*DataFrame fixtures below.
_N = 30
_K = 4

# Dtype groupings used to parametrize tests.  The *_NUMPY_* lists hold raw
# numpy dtypes; the *_EA_* lists hold the pandas nullable extension dtypes.
UNSIGNED_INT_NUMPY_DTYPES: list[NpDtype] = ["uint8", "uint16", "uint32", "uint64"]
UNSIGNED_INT_EA_DTYPES: list[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"]
SIGNED_INT_NUMPY_DTYPES: list[NpDtype] = [int, "int8", "int16", "int32", "int64"]
SIGNED_INT_EA_DTYPES: list[Dtype] = ["Int8", "Int16", "Int32", "Int64"]
ALL_INT_NUMPY_DTYPES = UNSIGNED_INT_NUMPY_DTYPES + SIGNED_INT_NUMPY_DTYPES
ALL_INT_EA_DTYPES = UNSIGNED_INT_EA_DTYPES + SIGNED_INT_EA_DTYPES
ALL_INT_DTYPES: list[Dtype] = [*ALL_INT_NUMPY_DTYPES, *ALL_INT_EA_DTYPES]

FLOAT_NUMPY_DTYPES: list[NpDtype] = [float, "float32", "float64"]
FLOAT_EA_DTYPES: list[Dtype] = ["Float32", "Float64"]
ALL_FLOAT_DTYPES: list[Dtype] = [*FLOAT_NUMPY_DTYPES, *FLOAT_EA_DTYPES]

COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
STRING_DTYPES: list[Dtype] = [str, "str", "U"]

DATETIME64_DTYPES: list[Dtype] = ["datetime64[ns]", "M8[ns]"]
TIMEDELTA64_DTYPES: list[Dtype] = ["timedelta64[ns]", "m8[ns]"]

BOOL_DTYPES: list[Dtype] = [bool, "bool"]
BYTES_DTYPES: list[Dtype] = [bytes, "bytes"]
OBJECT_DTYPES: list[Dtype] = [object, "object"]

ALL_REAL_NUMPY_DTYPES = FLOAT_NUMPY_DTYPES + ALL_INT_NUMPY_DTYPES
ALL_REAL_EXTENSION_DTYPES = FLOAT_EA_DTYPES + ALL_INT_EA_DTYPES
ALL_REAL_DTYPES: list[Dtype] = [*ALL_REAL_NUMPY_DTYPES, *ALL_REAL_EXTENSION_DTYPES]
ALL_NUMERIC_DTYPES: list[Dtype] = [*ALL_REAL_DTYPES, *COMPLEX_DTYPES]

ALL_NUMPY_DTYPES = (
    ALL_REAL_NUMPY_DTYPES
    + COMPLEX_DTYPES
    + STRING_DTYPES
    + DATETIME64_DTYPES
    + TIMEDELTA64_DTYPES
    + BOOL_DTYPES
    + OBJECT_DTYPES
    + BYTES_DTYPES
)

# numpy dtypes narrower than the pandas defaults (int64 / float64)
NARROW_NP_DTYPES = [
    np.float16,
    np.float32,
    np.int8,
    np.int16,
    np.int32,
    np.uint8,
    np.uint16,
    np.uint32,
]

# "<" or ">" depending on this machine's native byte order
ENDIAN = {"little": "<", "big": ">"}[byteorder]

# every scalar missing-value sentinel used across pandas
NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")]
# numpy NaT for every datetime64/timedelta64 resolution
NP_NAT_OBJECTS = [
    cls("NaT", unit)
    for cls in [np.datetime64, np.timedelta64]
    for unit in [
        "Y",
        "M",
        "W",
        "D",
        "h",
        "m",
        "s",
        "ms",
        "us",
        "ns",
        "ps",
        "fs",
        "as",
    ]
]
201
if not pa_version_under7p0:
    # pyarrow is available: build the pyarrow dtype groupings used to
    # parametrize ArrowDtype-backed tests
    import pyarrow as pa

    UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
    SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
    ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES
    # string forms, e.g. "int64[pyarrow]", for parametrizing by dtype name
    ALL_INT_PYARROW_DTYPES_STR_REPR = [
        str(ArrowDtype(typ)) for typ in ALL_INT_PYARROW_DTYPES
    ]

    # pa.float16 doesn't seem supported
    # https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
    FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
    FLOAT_PYARROW_DTYPES_STR_REPR = [
        str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
    ]
    DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)]
    STRING_PYARROW_DTYPES = [pa.string()]
    BINARY_PYARROW_DTYPES = [pa.binary()]

    TIME_PYARROW_DTYPES = [
        pa.time32("s"),
        pa.time32("ms"),
        pa.time64("us"),
        pa.time64("ns"),
    ]
    DATE_PYARROW_DTYPES = [pa.date32(), pa.date64()]
    # every timestamp unit crossed with naive/UTC/US-zone time zones
    DATETIME_PYARROW_DTYPES = [
        pa.timestamp(unit=unit, tz=tz)
        for unit in ["s", "ms", "us", "ns"]
        for tz in [None, "UTC", "US/Pacific", "US/Eastern"]
    ]
    TIMEDELTA_PYARROW_DTYPES = [pa.duration(unit) for unit in ["s", "ms", "us", "ns"]]

    BOOL_PYARROW_DTYPES = [pa.bool_()]

    # TODO: Add container like pyarrow types:
    # https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions
    ALL_PYARROW_DTYPES = (
        ALL_INT_PYARROW_DTYPES
        + FLOAT_PYARROW_DTYPES
        + DECIMAL_PYARROW_DTYPES
        + STRING_PYARROW_DTYPES
        + BINARY_PYARROW_DTYPES
        + TIME_PYARROW_DTYPES
        + DATE_PYARROW_DTYPES
        + DATETIME_PYARROW_DTYPES
        + TIMEDELTA_PYARROW_DTYPES
        + BOOL_PYARROW_DTYPES
    )
else:
    # pyarrow not installed: expose empty lists so parametrized tests that
    # import these names still collect cleanly
    FLOAT_PYARROW_DTYPES_STR_REPR = []
    ALL_INT_PYARROW_DTYPES_STR_REPR = []
    ALL_PYARROW_DTYPES = []


# matches only the empty string
EMPTY_STRING_PATTERN = re.compile("^$")
259
260
def reset_display_options() -> None:
    """Restore every ``display.*`` pandas option to its default value."""
    # silent=True suppresses warnings for deprecated display options
    pd.reset_option("^display.", silent=True)
266
267
268# -----------------------------------------------------------------------------
269# Comparators
270
271
def equalContents(arr1, arr2) -> bool:
    """Return True when arr1 and arr2 hold the same set of unique elements."""
    left = frozenset(arr1)
    right = frozenset(arr2)
    return left == right
277
278
def box_expected(expected, box_cls, transpose: bool = True):
    """
    Helper function to wrap the expected output of a test in a given box_class.

    Parameters
    ----------
    expected : np.ndarray, Index, Series
    box_cls : {Index, Series, DataFrame}
    transpose : bool, default True
        DataFrame only: transpose to a single-row frame, then tile to two
        rows so datetime arithmetic does not hit single-row special cases.

    Returns
    -------
    subclass of box_cls
    """
    if box_cls is pd.array:
        if isinstance(expected, RangeIndex):
            # pd.array would return an IntegerArray
            return PandasArray(np.asarray(expected._values))
        return pd.array(expected, copy=False)
    if box_cls is Index:
        return Index(expected)
    if box_cls is Series:
        return Series(expected)
    if box_cls is DataFrame:
        frame = Series(expected).to_frame()
        if transpose:
            # for vector operations, we need a DataFrame to be a single-row,
            # not a single-column, in order to operate against non-DataFrame
            # vectors of the same length. But convert to two rows to avoid
            # single-row special cases in datetime arithmetic
            frame = pd.concat([frame.T] * 2, ignore_index=True)
        return frame
    if box_cls is np.ndarray or box_cls is np.array:
        return np.array(expected)
    if box_cls is to_array:
        return to_array(expected)
    raise NotImplementedError(box_cls)
318
319
def to_array(obj):
    """
    Similar to pd.array, but does not cast numpy dtypes to nullable dtypes.
    """
    # temporary implementation until we get pd.array in place
    if getattr(obj, "dtype", None) is None:
        # plain python sequence with no dtype: fall back to numpy
        return np.asarray(obj)
    return extract_array(obj, extract_numpy=True)
331
332
333# -----------------------------------------------------------------------------
334# Others
335
336
def getCols(k) -> str:
    """Return the first k uppercase ASCII letters, used as column labels."""
    letters = string.ascii_uppercase
    return letters[:k]
339
340
341# make index
def makeStringIndex(k: int = 10, name=None) -> Index:
    """Return an Index of k random 10-character strings."""
    values = rands_array(nchars=10, size=k)
    return Index(values, name=name)
344
345
def makeCategoricalIndex(
    k: int = 10, n: int = 3, name=None, **kwargs
) -> CategoricalIndex:
    """make a length k CategoricalIndex of n distinct random-string categories"""
    # categories: n random 4-char strings, sampled without replacement
    x = rands_array(nchars=4, size=n, replace=False)
    # codes cycle 0..n-1 so each category appears roughly k/n times
    return CategoricalIndex(
        Categorical.from_codes(np.arange(k) % n, categories=x), name=name, **kwargs
    )
354
355
def makeIntervalIndex(k: int = 10, name=None, **kwargs) -> IntervalIndex:
    """make a length k IntervalIndex with breaks evenly spaced over [0, 100]"""
    breaks = np.linspace(0, 100, num=k + 1)
    return IntervalIndex.from_breaks(breaks, name=name, **kwargs)
360
361
def makeBoolIndex(k: int = 10, name=None) -> Index:
    """Return a length-k boolean Index: [True] for k == 1, else [False, True, False, ...]."""
    if k == 1:
        # the k >= 2 construction below cannot produce a single element
        return Index([True], name=name)
    values = [False, True] + [False] * (k - 2)
    return Index(values, name=name)
368
369
370def makeNumericIndex(k: int = 10, *, name=None, dtype: Dtype | None) -> Index:
371 dtype = pandas_dtype(dtype)
372 assert isinstance(dtype, np.dtype)
373
374 if is_integer_dtype(dtype):
375 values = np.arange(k, dtype=dtype)
376 if is_unsigned_integer_dtype(dtype):
377 values += 2 ** (dtype.itemsize * 8 - 1)
378 elif is_float_dtype(dtype):
379 values = np.random.random_sample(k) - np.random.random_sample(1)
380 values.sort()
381 values = values * (10 ** np.random.randint(0, 9))
382 else:
383 raise NotImplementedError(f"wrong dtype {dtype}")
384
385 return Index(values, dtype=dtype, name=name)
386
387
def makeIntIndex(k: int = 10, *, name=None, dtype: Dtype = "int64") -> Index:
    """Return a length-k signed-integer Index counting 0..k-1."""
    dtype = pandas_dtype(dtype)
    if is_signed_integer_dtype(dtype):
        return makeNumericIndex(k, name=name, dtype=dtype)
    raise TypeError(f"Wrong dtype {dtype}")
393
394
def makeUIntIndex(k: int = 10, *, name=None, dtype: Dtype = "uint64") -> Index:
    """Return a length-k unsigned-integer Index."""
    dtype = pandas_dtype(dtype)
    if is_unsigned_integer_dtype(dtype):
        return makeNumericIndex(k, name=name, dtype=dtype)
    raise TypeError(f"Wrong dtype {dtype}")
400
401
def makeRangeIndex(k: int = 10, name=None, **kwargs) -> RangeIndex:
    """Return RangeIndex(0, k) with unit step."""
    return RangeIndex(start=0, stop=k, step=1, name=name, **kwargs)
404
405
def makeFloatIndex(k: int = 10, *, name=None, dtype: Dtype = "float64") -> Index:
    """Return a length-k float Index of sorted random values."""
    dtype = pandas_dtype(dtype)
    if is_float_dtype(dtype):
        return makeNumericIndex(k, name=name, dtype=dtype)
    raise TypeError(f"Wrong dtype {dtype}")
411
412
def makeDateIndex(
    k: int = 10, freq: Frequency = "B", name=None, **kwargs
) -> DatetimeIndex:
    """Return a length-k DatetimeIndex starting 2000-01-01 at the given freq."""
    start = datetime(2000, 1, 1)
    dr = bdate_range(start, periods=k, freq=freq, name=name)
    return DatetimeIndex(dr, name=name, **kwargs)
419
420
def makeTimedeltaIndex(
    k: int = 10, freq: Frequency = "D", name=None, **kwargs
) -> TimedeltaIndex:
    """Return a length-k TimedeltaIndex starting at '1 day'."""
    tdi = pd.timedelta_range(start="1 day", periods=k, freq=freq, name=name, **kwargs)
    return tdi
425
426
def makePeriodIndex(k: int = 10, name=None, **kwargs) -> PeriodIndex:
    """Return a length-k business-day PeriodIndex starting 2000-01-01."""
    start = datetime(2000, 1, 1)
    return pd.period_range(start=start, periods=k, freq="B", name=name, **kwargs)
430
431
def makeMultiIndex(k: int = 10, names=None, **kwargs):
    """Return a length-k, 2-level MultiIndex: ("foo"/"bar") x range."""
    half = (k // 2) + 1
    mi = MultiIndex.from_product([("foo", "bar"), range(half)], names=names, **kwargs)
    assert len(mi) >= k  # GH#38795
    return mi[:k]
438
439
def index_subclass_makers_generator():
    """Yield the make* factories that produce Index subclasses."""
    yield makeDateIndex
    yield makePeriodIndex
    yield makeTimedeltaIndex
    yield makeRangeIndex
    yield makeIntervalIndex
    yield makeCategoricalIndex
    yield makeMultiIndex
451
452
def all_timeseries_index_generator(k: int = 10) -> Iterable[Index]:
    """
    Generator which can be iterated over to get instances of all the classes
    which represent time-series.

    Parameters
    ----------
    k: length of each of the index instances
    """
    for maker in (makeDateIndex, makePeriodIndex, makeTimedeltaIndex):
        yield maker(k=k)
469
470
471# make series
def make_rand_series(name=None, dtype=np.float64) -> Series:
    """Return a Series of _N random normals over a random-string index."""
    index = makeStringIndex(_N)
    with np.errstate(invalid="ignore"):
        # errstate: silence warnings from casting NaN-capable data
        values = np.random.randn(_N).astype(dtype, copy=False)
    return Series(values, index=index, name=name)
478
479
def makeFloatSeries(name=None) -> Series:
    """Return a random Series (float64 via make_rand_series default)."""
    series = make_rand_series(name=name)
    return series
482
483
def makeStringSeries(name=None) -> Series:
    """Return a random float Series with a random-string index."""
    series = make_rand_series(name=name)
    return series
486
487
def makeObjectSeries(name=None) -> Series:
    """Return an object-dtype Series of random strings with a string index."""
    values = Index(makeStringIndex(_N), dtype=object)
    index = makeStringIndex(_N)
    return Series(values, index=index, name=name)
493
494
def getSeriesData() -> dict[str, Series]:
    """Return _K random-normal Series (keyed "A".."D") sharing one index."""
    index = makeStringIndex(_N)
    data = {}
    for col in getCols(_K):
        data[col] = Series(np.random.randn(_N), index=index)
    return data
498
499
def makeTimeSeries(nper=None, freq: Frequency = "B", name=None) -> Series:
    """Return a Series of nper (default _N) random normals over a DatetimeIndex."""
    nper = _N if nper is None else nper
    index = makeDateIndex(nper, freq=freq)
    return Series(np.random.randn(nper), index=index, name=name)
506
507
def makePeriodSeries(nper=None, name=None) -> Series:
    """Return a Series of nper (default _N) random normals over a PeriodIndex."""
    nper = _N if nper is None else nper
    index = makePeriodIndex(nper)
    return Series(np.random.randn(nper), index=index, name=name)
512
513
def getTimeSeriesData(nper=None, freq: Frequency = "B") -> dict[str, Series]:
    """Return _K random time series keyed by column letters."""
    return {col: makeTimeSeries(nper, freq) for col in getCols(_K)}
516
517
def getPeriodData(nper=None) -> dict[str, Series]:
    """Return _K random period-indexed Series keyed by column letters."""
    return {col: makePeriodSeries(nper) for col in getCols(_K)}
520
521
522# make frame
def makeTimeDataFrame(nper=None, freq: Frequency = "B") -> DataFrame:
    """Return a DataFrame of _K random time series columns."""
    return DataFrame(getTimeSeriesData(nper, freq))
526
527
def makeDataFrame() -> DataFrame:
    """Return an _N x _K DataFrame of random normals with a string index."""
    return DataFrame(getSeriesData())
531
532
def getMixedTypeDict():
    """Return (index, data) for a small frame with float/str/datetime columns."""
    index = Index(["a", "b", "c", "d", "e"])
    data = {
        "A": [0.0, 1.0, 2.0, 3.0, 4.0],
        "B": [0.0, 1.0, 0.0, 1.0, 0.0],
        "C": ["foo1", "foo2", "foo3", "foo4", "foo5"],
        "D": bdate_range("1/1/2009", periods=5),
    }
    return index, data
544
545
def makeMixedDataFrame() -> DataFrame:
    """Return the mixed-dtype fixture frame built from getMixedTypeDict data."""
    _, data = getMixedTypeDict()
    return DataFrame(data)
548
549
def makePeriodFrame(nper=None) -> DataFrame:
    """Return a DataFrame of _K random period-indexed columns."""
    return DataFrame(getPeriodData(nper))
553
554
def makeCustomIndex(
    nentries,
    nlevels,
    prefix: str = "#",
    names: bool | str | list[str] | None = False,
    ndupe_l=None,
    idx_type=None,
) -> Index:
    """
    Create an index/multindex with given dimensions, levels, names, etc'

    nentries - number of entries in index
    nlevels - number of levels (> 1 produces multindex)
    prefix - a string prefix for labels
    names - (Optional), bool or list of strings. if True will use default
       names, if false will use no names, if a list is given, the name of
       each level in the index will be taken from the list.
    ndupe_l - (Optional), list of ints, the number of rows for which the
       label will repeated at the corresponding level, you can specify just
       the first few, the rest will use the default ndupe_l of 1.
       len(ndupe_l) <= nlevels.
    idx_type - "i"/"f"/"s"/"dt"/"p"/"td".
       If idx_type is not None, `idx_nlevels` must be 1.
       "i"/"f" creates an integer/float index,
       "s" creates a string
       "dt" create a datetime index.
       "td" create a timedelta index.

       if unspecified, string labels will be generated.
    """
    if ndupe_l is None:
        ndupe_l = [1] * nlevels
    assert is_sequence(ndupe_l) and len(ndupe_l) <= nlevels
    # use "==", not "is": integer identity only holds via CPython's
    # small-int cache and is not a correctness guarantee
    assert names is None or names is False or names is True or len(names) == nlevels
    assert idx_type is None or (
        idx_type in ("i", "f", "s", "u", "dt", "p", "td") and nlevels == 1
    )

    if names is True:
        # build default names
        names = [prefix + str(i) for i in range(nlevels)]
    if names is False:
        # pass None to index constructor for no name
        names = None

    # make singleton case uniform
    if isinstance(names, str) and nlevels == 1:
        names = [names]

    # specific 1D index type requested?
    idx_func_dict: dict[str, Callable[..., Index]] = {
        "i": makeIntIndex,
        "f": makeFloatIndex,
        "s": makeStringIndex,
        "dt": makeDateIndex,
        "td": makeTimedeltaIndex,
        "p": makePeriodIndex,
    }
    idx_func = idx_func_dict.get(idx_type)
    if idx_func:
        idx = idx_func(nentries)
        # but we need to fill in the name
        if names:
            idx.name = names[0]
        return idx
    elif idx_type is not None:
        raise ValueError(
            f"{repr(idx_type)} is not a legal value for `idx_type`, "
            "use 'i'/'f'/'s'/'dt'/'p'/'td'."
        )

    if len(ndupe_l) < nlevels:
        # pad unspecified levels with the default multiplicity of 1
        ndupe_l.extend([1] * (nlevels - len(ndupe_l)))
    assert len(ndupe_l) == nlevels

    assert all(x > 0 for x in ndupe_l)

    def keyfunc(x):
        # sort labels by their numeric components so "..._g10" sorts
        # after "..._g9" (loop-invariant, so defined once)
        numeric_tuple = re.sub(r"[^\d_]_?", "", x).split("_")
        return [int(num) for num in numeric_tuple]

    list_of_lists = []
    for i in range(nlevels):
        # build a list of lists to create the index from
        div_factor = nentries // ndupe_l[i] + 1

        cnt: Counter[str] = collections.Counter()
        for j in range(div_factor):
            label = f"{prefix}_l{i}_g{j}"
            cnt[label] = ndupe_l[i]
        # cute Counter trick: elements() repeats each label ndupe_l[i] times
        result = sorted(cnt.elements(), key=keyfunc)[:nentries]
        list_of_lists.append(result)

    tuples = list(zip(*list_of_lists))

    # convert tuples to index
    if nentries == 1:
        # we have a single level of tuples, i.e. a regular Index
        name = None if names is None else names[0]
        index = Index(tuples[0], name=name)
    elif nlevels == 1:
        name = None if names is None else names[0]
        index = Index((x[0] for x in tuples), name=name)
    else:
        index = MultiIndex.from_tuples(tuples, names=names)
    return index
665
666
def makeCustomDataframe(
    nrows,
    ncols,
    c_idx_names: bool | list[str] = True,
    r_idx_names: bool | list[str] = True,
    c_idx_nlevels: int = 1,
    r_idx_nlevels: int = 1,
    data_gen_f=None,
    c_ndupe_l=None,
    r_ndupe_l=None,
    dtype=None,
    c_idx_type=None,
    r_idx_type=None,
) -> DataFrame:
    """
    Create a DataFrame using supplied parameters.

    Parameters
    ----------
    nrows, ncols - number of data rows/cols
    c_idx_names, r_idx_names - False/True/list of strings, yields No names ,
        default names or uses the provided names for the levels of the
        corresponding index. You can provide a single string when
        c_idx_nlevels ==1.
    c_idx_nlevels - number of levels in columns index. > 1 will yield MultiIndex
    r_idx_nlevels - number of levels in rows index. > 1 will yield MultiIndex
    data_gen_f - a function f(row,col) which return the data value
        at that position, the default generator used yields values of the form
        "RxCy" based on position.
    c_ndupe_l, r_ndupe_l - list of integers, determines the number
        of duplicates for each label at a given level of the corresponding
        index. The default `None` value produces a multiplicity of 1 across
        all levels, i.e. a unique index. Will accept a partial list of length
        N < idx_nlevels, for just the first N levels. If ndupe doesn't divide
        nrows/ncol, the last label might have lower multiplicity.
    dtype - passed to the DataFrame constructor as is, in case you wish to
        have more control in conjunction with a custom `data_gen_f`
    r_idx_type, c_idx_type - "i"/"f"/"s"/"dt"/"td".
        If idx_type is not None, `idx_nlevels` must be 1.
        "i"/"f" creates an integer/float index,
        "s" creates a string index
        "dt" create a datetime index.
        "td" create a timedelta index.

        if unspecified, string labels will be generated.

    Examples
    --------
    # 5 row, 3 columns, default names on both, single index on both axis
    >> makeCustomDataframe(5,3)

    # make the data a random int between 1 and 100
    >> mkdf(5,3,data_gen_f=lambda r,c:randint(1,100))

    # 2-level multiindex on rows with each label duplicated
    # twice on first level, default names on both axis, single
    # index on both axis
    >> a=makeCustomDataframe(5,3,r_idx_nlevels=2,r_ndupe_l=[2])

    # DatetimeIndex on row, index with unicode labels on columns
    # no names on either axis
    >> a=makeCustomDataframe(5,3,c_idx_names=False,r_idx_names=False,
                             r_idx_type="dt",c_idx_type="u")

    # 4-level multindex on rows with names provided, 2-level multindex
    # on columns with default labels and default names.
    >> a=makeCustomDataframe(5,3,r_idx_nlevels=4,
                             r_idx_names=["FEE","FIH","FOH","FUM"],
                             c_idx_nlevels=2)

    >> a=mkdf(5,3,r_idx_nlevels=2,c_idx_nlevels=4)
    """
    assert c_idx_nlevels > 0
    assert r_idx_nlevels > 0
    # a typed (non-string) index only makes sense for a single level
    assert r_idx_type is None or (
        r_idx_type in ("i", "f", "s", "dt", "p", "td") and r_idx_nlevels == 1
    )
    assert c_idx_type is None or (
        c_idx_type in ("i", "f", "s", "dt", "p", "td") and c_idx_nlevels == 1
    )

    # column labels get a "C" prefix, row labels an "R" prefix, so default
    # string labels are distinguishable
    columns = makeCustomIndex(
        ncols,
        nlevels=c_idx_nlevels,
        prefix="C",
        names=c_idx_names,
        ndupe_l=c_ndupe_l,
        idx_type=c_idx_type,
    )
    index = makeCustomIndex(
        nrows,
        nlevels=r_idx_nlevels,
        prefix="R",
        names=r_idx_names,
        ndupe_l=r_ndupe_l,
        idx_type=r_idx_type,
    )

    # by default, generate data based on location
    if data_gen_f is None:
        data_gen_f = lambda r, c: f"R{r}C{c}"

    data = [[data_gen_f(r, c) for c in range(ncols)] for r in range(nrows)]

    return DataFrame(data, index, columns, dtype=dtype)
772
773
def _create_missing_idx(nrows, ncols, density: float, random_state=None):
    """Return (row_idx, col_idx) lists marking ~(1 - density) random cells."""
    if random_state is None:
        random_state = np.random
    else:
        random_state = np.random.RandomState(random_state)

    # below is cribbed from scipy.sparse
    size = round((1 - density) * nrows * ncols)
    # generate a few more to ensure unique values
    min_rows = 5
    fac = 1.02
    extra_size = min(size + min_rows, fac * size)

    def _gen_unique_rand(rng, _extra_size):
        # draw uniform [0, 1), scale to flat positions, dedupe, keep `size`
        ind = rng.rand(int(_extra_size))
        return np.unique(np.floor(ind * nrows * ncols))[:size]

    ind = _gen_unique_rand(random_state, extra_size)
    while ind.size < size:
        # not enough unique positions yet: enlarge the sample and retry
        extra_size *= 1.05
        ind = _gen_unique_rand(random_state, extra_size)

    # unravel the flat positions column-major: j = column, i = row
    j = np.floor(ind * 1.0 / nrows).astype(int)
    i = (ind - j * nrows).astype(int)
    return i.tolist(), j.tolist()
799
800
def makeMissingDataframe(density: float = 0.9, random_state=None) -> DataFrame:
    """Return a random DataFrame with ~(1 - density) of its cells set to NaN."""
    df = makeDataFrame()
    rows, cols = _create_missing_idx(
        *df.shape, density=density, random_state=random_state
    )
    df.iloc[rows, cols] = np.nan
    return df
806
807
class SubclassedSeries(Series):
    """Series subclass used to test metadata propagation through pandas ops."""

    # attributes that pandas propagates to results via __finalize__
    _metadata = ["testattr", "name"]

    @property
    def _constructor(self):
        # For testing, those properties return a generic callable, and not
        # the actual class. In this case that is equivalent, but it is to
        # ensure we don't rely on the property returning a class
        # See https://github.com/pandas-dev/pandas/pull/46018 and
        # https://github.com/pandas-dev/pandas/issues/32638 and linked issues
        return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs)

    @property
    def _constructor_expanddim(self):
        # dimension-expanding ops (e.g. to_frame) produce the DataFrame subclass
        return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs)
823
824
class SubclassedDataFrame(DataFrame):
    """DataFrame subclass used to test metadata propagation through pandas ops."""

    # attribute that pandas propagates to results via __finalize__
    _metadata = ["testattr"]

    @property
    def _constructor(self):
        # returns a callable rather than the class itself; see the note on
        # SubclassedSeries._constructor
        return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs)

    @property
    def _constructor_sliced(self):
        # dimension-reducing ops (e.g. column selection) produce the Series subclass
        return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs)
835
836
class SubclassedCategorical(Categorical):
    """Categorical subclass used to test subclass preservation."""

    @property
    def _constructor(self):
        return SubclassedCategorical
841
842
843def _make_skipna_wrapper(alternative, skipna_alternative=None):
844 """
845 Create a function for calling on an array.
846
847 Parameters
848 ----------
849 alternative : function
850 The function to be called on the array with no NaNs.
851 Only used when 'skipna_alternative' is None.
852 skipna_alternative : function
853 The function to be called on the original array
854
855 Returns
856 -------
857 function
858 """
859 if skipna_alternative:
860
861 def skipna_wrapper(x):
862 return skipna_alternative(x.values)
863
864 else:
865
866 def skipna_wrapper(x):
867 nona = x.dropna()
868 if len(nona) == 0:
869 return np.nan
870 return alternative(nona)
871
872 return skipna_wrapper
873
874
def convert_rows_list_to_csv_str(rows_list: list[str]) -> str:
    """
    Convert list of CSV rows to single CSV-formatted string for current OS.

    This method is used for creating expected value of to_csv() method.

    Parameters
    ----------
    rows_list : List[str]
        Each element represents the row of csv.

    Returns
    -------
    str
        Expected output of to_csv() in current OS.
    """
    # join with the platform line terminator; to_csv also terminates the
    # final row, hence the trailing separator
    line_sep = os.linesep
    return line_sep.join(rows_list) + line_sep
893
894
def external_error_raised(expected_exception: type[Exception]) -> ContextManager:
    """
    Helper function to mark pytest.raises that have an external error message.

    Parameters
    ----------
    expected_exception : Exception
        Expected error to raise.

    Returns
    -------
    Callable
        Regular `pytest.raises` function with `match` equal to `None`.
    """
    # imported lazily so this module does not require pytest at import time
    import pytest

    ctx = pytest.raises(expected_exception, match=None)
    return ctx
912
913
# (callable, name) pairs from pandas' internal cython lookup table, mapping
# builtin/numpy callables (e.g. sum, np.max) to the equivalent method names;
# consumed by get_cython_table_params below
cython_table = pd.core.common._cython_table.items()
915
916
def get_cython_table_params(ndframe, func_names_and_expected):
    """
    Combine frame, functions from com._cython_table
    keys and expected result.

    Parameters
    ----------
    ndframe : DataFrame or Series
    func_names_and_expected : Sequence of two items
        The first item is a name of a NDFrame method ('sum', 'prod') etc.
        The second item is the expected return value.

    Returns
    -------
    list
        List of three items (DataFrame, function, expected result)
    """
    results = []
    for method_name, expected in func_names_and_expected:
        results.append((ndframe, method_name, expected))
        for func, name in cython_table:
            # also emit any raw callables that map to this method name
            if name == method_name:
                results.append((ndframe, func, expected))
    return results
943
944
def get_op_from_name(op_name: str) -> Callable:
    """
    The operator function for a given op name.

    Parameters
    ----------
    op_name : str
        The op name, in form of "add" or "__add__".

    Returns
    -------
    function
        A function performing the operation.
    """
    stripped = op_name.strip("_")
    if hasattr(operator, stripped):
        return getattr(operator, stripped)

    # assume a reflected op such as "radd": strip the leading "r", look up
    # the base operator, and swap the argument order
    rop = getattr(operator, stripped[1:])

    def flipped(x, y):
        return rop(y, x)

    return flipped
968
969
970# -----------------------------------------------------------------------------
971# Indexing test helpers
972
973
def getitem(x):
    # indexing-helper: identity, i.e. index via obj[...] directly
    return x
976
977
def setitem(x):
    # indexing-helper: identity, i.e. assign via obj[...] directly
    return x
980
981
def loc(x):
    # indexing-helper: label-based indexer
    return x.loc
984
985
def iloc(x):
    # indexing-helper: positional indexer
    return x.iloc
988
989
def at(x):
    # indexing-helper: label-based scalar accessor
    return x.at
992
993
def iat(x):
    # indexing-helper: positional scalar accessor
    return x.iat
996
997
998# -----------------------------------------------------------------------------
999
1000
def shares_memory(left, right) -> bool:
    """
    Pandas-compat for np.shares_memory.

    Recursively unwraps pandas containers down to their backing ndarrays
    (or pyarrow buffers) and compares those.
    """
    if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
        return np.shares_memory(left, right)
    elif isinstance(left, np.ndarray):
        # Call with reversed args to get to unpacking logic below.
        return shares_memory(right, left)

    if isinstance(left, RangeIndex):
        # RangeIndex is lazily materialized and holds no buffer to share
        return False
    if isinstance(left, MultiIndex):
        return shares_memory(left._codes, right)
    if isinstance(left, (Index, Series)):
        # recurse on the backing array
        return shares_memory(left._values, right)

    if isinstance(left, NDArrayBackedExtensionArray):
        return shares_memory(left._ndarray, right)
    if isinstance(left, pd.core.arrays.SparseArray):
        return shares_memory(left.sp_values, right)
    if isinstance(left, pd.core.arrays.IntervalArray):
        # shared memory on either endpoint array counts
        return shares_memory(left._left, right) or shares_memory(left._right, right)

    if isinstance(left, ExtensionArray) and left.dtype == "string[pyarrow]":
        # https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
        left = cast("ArrowExtensionArray", left)
        if isinstance(right, ExtensionArray) and right.dtype == "string[pyarrow]":
            right = cast("ArrowExtensionArray", right)
            left_pa_data = left._data
            right_pa_data = right._data
            # compare the data buffers (index 1) of the first chunk
            left_buf1 = left_pa_data.chunk(0).buffers()[1]
            right_buf1 = right_pa_data.chunk(0).buffers()[1]
            return left_buf1 == right_buf1

    if isinstance(left, BaseMaskedArray) and isinstance(right, BaseMaskedArray):
        # By convention, we'll say these share memory if they share *either*
        # the _data or the _mask
        return np.shares_memory(left._data, right._data) or np.shares_memory(
            left._mask, right._mask
        )

    if isinstance(left, DataFrame) and len(left._mgr.arrays) == 1:
        # single-block frame: compare its lone backing array
        arr = left._mgr.arrays[0]
        return shares_memory(arr, right)

    raise NotImplementedError(type(left), type(right))
1048
1049
# Explicit public API of pandas._testing (kept roughly alphabetical,
# case-insensitively)
__all__ = [
    "ALL_INT_EA_DTYPES",
    "ALL_INT_NUMPY_DTYPES",
    "ALL_NUMPY_DTYPES",
    "ALL_REAL_NUMPY_DTYPES",
    "all_timeseries_index_generator",
    "assert_almost_equal",
    "assert_attr_equal",
    "assert_categorical_equal",
    "assert_class_equal",
    "assert_contains_all",
    "assert_copy",
    "assert_datetime_array_equal",
    "assert_dict_equal",
    "assert_equal",
    "assert_extension_array_equal",
    "assert_frame_equal",
    "assert_index_equal",
    "assert_indexing_slices_equivalent",
    "assert_interval_array_equal",
    "assert_is_sorted",
    "assert_is_valid_plot_return_object",
    "assert_metadata_equivalent",
    "assert_numpy_array_equal",
    "assert_period_array_equal",
    "assert_produces_warning",
    "assert_series_equal",
    "assert_sp_array_equal",
    "assert_timedelta_array_equal",
    "at",
    "BOOL_DTYPES",
    "box_expected",
    "BYTES_DTYPES",
    "can_set_locale",
    "close",
    "COMPLEX_DTYPES",
    "convert_rows_list_to_csv_str",
    "DATETIME64_DTYPES",
    "decompress_file",
    "EMPTY_STRING_PATTERN",
    "ENDIAN",
    "ensure_clean",
    "ensure_safe_environment_variables",
    "equalContents",
    "external_error_raised",
    "FLOAT_EA_DTYPES",
    "FLOAT_NUMPY_DTYPES",
    "getCols",
    "get_cython_table_params",
    "get_dtype",
    "getitem",
    "get_locales",
    "getMixedTypeDict",
    "get_obj",
    "get_op_from_name",
    "getPeriodData",
    "getSeriesData",
    "getTimeSeriesData",
    "iat",
    "iloc",
    "index_subclass_makers_generator",
    "loc",
    "makeBoolIndex",
    "makeCategoricalIndex",
    "makeCustomDataframe",
    "makeCustomIndex",
    "makeDataFrame",
    "makeDateIndex",
    "makeFloatIndex",
    "makeFloatSeries",
    "makeIntervalIndex",
    "makeIntIndex",
    "makeMissingDataframe",
    "makeMixedDataFrame",
    "makeMultiIndex",
    "makeNumericIndex",
    "makeObjectSeries",
    "makePeriodFrame",
    "makePeriodIndex",
    "makePeriodSeries",
    "make_rand_series",
    "makeRangeIndex",
    "makeStringIndex",
    "makeStringSeries",
    "makeTimeDataFrame",
    "makeTimedeltaIndex",
    "makeTimeSeries",
    "makeUIntIndex",
    "maybe_produces_warning",
    "NARROW_NP_DTYPES",
    "network",
    "NP_NAT_OBJECTS",
    "NULL_OBJECTS",
    "OBJECT_DTYPES",
    "raise_assert_detail",
    "rands",
    "reset_display_options",
    "raises_chained_assignment_error",
    "round_trip_localpath",
    "round_trip_pathlib",
    "round_trip_pickle",
    "setitem",
    "set_locale",
    "set_timezone",
    "shares_memory",
    "SIGNED_INT_EA_DTYPES",
    "SIGNED_INT_NUMPY_DTYPES",
    "STRING_DTYPES",
    "SubclassedCategorical",
    "SubclassedDataFrame",
    "SubclassedSeries",
    "TIMEDELTA64_DTYPES",
    "to_array",
    "UNSIGNED_INT_EA_DTYPES",
    "UNSIGNED_INT_NUMPY_DTYPES",
    "use_numexpr",
    "with_csv_dialect",
    "write_to_compressed",
]