1from __future__ import annotations
2
3import operator
4from typing import (
5 Literal,
6 cast,
7)
8
9import numpy as np
10
11from pandas._libs.missing import is_matching_na
12from pandas._libs.sparse import SparseIndex
13import pandas._libs.testing as _testing
14from pandas._libs.tslibs.np_datetime import compare_mismatched_resolutions
15
16from pandas.core.dtypes.common import (
17 is_bool,
18 is_categorical_dtype,
19 is_extension_array_dtype,
20 is_integer_dtype,
21 is_interval_dtype,
22 is_number,
23 is_numeric_dtype,
24 needs_i8_conversion,
25)
26from pandas.core.dtypes.dtypes import (
27 CategoricalDtype,
28 DatetimeTZDtype,
29 PandasDtype,
30)
31from pandas.core.dtypes.missing import array_equivalent
32
33import pandas as pd
34from pandas import (
35 Categorical,
36 DataFrame,
37 DatetimeIndex,
38 Index,
39 IntervalIndex,
40 MultiIndex,
41 PeriodIndex,
42 RangeIndex,
43 Series,
44 TimedeltaIndex,
45)
46from pandas.core.algorithms import take_nd
47from pandas.core.arrays import (
48 DatetimeArray,
49 ExtensionArray,
50 IntervalArray,
51 PeriodArray,
52 TimedeltaArray,
53)
54from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin
55from pandas.core.arrays.string_ import StringDtype
56from pandas.core.indexes.api import safe_sort_index
57
58from pandas.io.formats.printing import pprint_thing
59
60
def assert_almost_equal(
    left,
    right,
    check_dtype: bool | Literal["equiv"] = "equiv",
    rtol: float = 1.0e-5,
    atol: float = 1.0e-8,
    **kwargs,
) -> None:
    """
    Check that the left and right objects are approximately equal.

    By approximately equal, we refer to objects that are numbers or that
    contain numbers which may be equivalent to specific levels of precision.
    Index, Series and DataFrame inputs are delegated to the corresponding
    ``assert_*_equal`` helper with ``check_exact=False``; every other input
    (scalars, sequences, numpy arrays) is handed to the low-level
    ``_testing.assert_almost_equal``.

    Parameters
    ----------
    left : object
    right : object
    check_dtype : bool or {'equiv'}, default 'equiv'
        Check dtype if both a and b are the same type. If 'equiv' is passed in,
        then `RangeIndex` and `Index` with int64 dtype are also considered
        equivalent when doing type checking.
    rtol : float, default 1e-5
        Relative tolerance.

        .. versionadded:: 1.1.0
    atol : float, default 1e-8
        Absolute tolerance.

        .. versionadded:: 1.1.0
    **kwargs
        Forwarded unchanged to whichever comparison function is used.
    """
    if isinstance(left, Index):
        assert_index_equal(
            left,
            right,
            exact=check_dtype,
            check_exact=False,
            rtol=rtol,
            atol=atol,
            **kwargs,
        )
        return

    if isinstance(left, Series):
        assert_series_equal(
            left,
            right,
            check_dtype=check_dtype,
            check_exact=False,
            rtol=rtol,
            atol=atol,
            **kwargs,
        )
        return

    if isinstance(left, DataFrame):
        assert_frame_equal(
            left,
            right,
            check_dtype=check_dtype,
            check_exact=False,
            rtol=rtol,
            atol=atol,
            **kwargs,
        )
        return

    # Generic objects (scalars, sequences, arrays).
    if check_dtype:
        # Python and numpy scalar classes (float vs np.float64, bool vs
        # np.bool_) are deliberately not compared against each other.
        same_scalar_kind = (is_number(left) and is_number(right)) or (
            is_bool(left) and is_bool(right)
        )
        if not same_scalar_kind:
            involves_ndarray = any(
                isinstance(side, np.ndarray) for side in (left, right)
            )
            label = "numpy array" if involves_ndarray else "Input"
            assert_class_equal(left, right, obj=label)

    # bool("equiv") is True, so 'equiv' becomes a strict dtype check here.
    _testing.assert_almost_equal(
        left, right, check_dtype=bool(check_dtype), rtol=rtol, atol=atol, **kwargs
    )
145
146
147def _check_isinstance(left, right, cls):
148 """
149 Helper method for our assert_* methods that ensures that
150 the two objects being compared have the right type before
151 proceeding with the comparison.
152
153 Parameters
154 ----------
155 left : The first object being compared.
156 right : The second object being compared.
157 cls : The class type to check against.
158
159 Raises
160 ------
161 AssertionError : Either `left` or `right` is not an instance of `cls`.
162 """
163 cls_name = cls.__name__
164
165 if not isinstance(left, cls):
166 raise AssertionError(
167 f"{cls_name} Expected type {cls}, found {type(left)} instead"
168 )
169 if not isinstance(right, cls):
170 raise AssertionError(
171 f"{cls_name} Expected type {cls}, found {type(right)} instead"
172 )
173
174
def assert_dict_equal(left, right, compare_keys: bool = True) -> None:
    """
    Assert that two dictionaries are equal.

    Parameters
    ----------
    left, right : dict
        The dictionaries to compare; both must be ``dict`` instances.
    compare_keys : bool, default True
        Passed through as ``compare_keys`` to the low-level
        ``_testing.assert_dict_equal``.

    Raises
    ------
    AssertionError
        If either argument is not a dict, or the low-level comparison fails.
    """
    # Validate the container type before delegating to the Cython check.
    _check_isinstance(left, right, cls=dict)
    _testing.assert_dict_equal(left, right, compare_keys=compare_keys)
178
179
def assert_index_equal(
    left: Index,
    right: Index,
    exact: bool | str = "equiv",
    check_names: bool = True,
    check_exact: bool = True,
    check_categorical: bool = True,
    check_order: bool = True,
    rtol: float = 1.0e-5,
    atol: float = 1.0e-8,
    obj: str = "Index",
) -> None:
    """
    Check that left and right Index are equal.

    Parameters
    ----------
    left : Index
    right : Index
    exact : bool or {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical. If 'equiv', then RangeIndex can be substituted for
        Index with an int64 dtype as well.
    check_names : bool, default True
        Whether to check the names attribute.
    check_exact : bool, default True
        Whether to compare number exactly.
    check_categorical : bool, default True
        Whether to compare internal Categorical exactly. For a MultiIndex the
        setting is applied to each level as well.
    check_order : bool, default True
        Whether to compare the order of index entries as well as their values.
        If True, both indexes must contain the same elements, in the same order.
        If False, both indexes must contain the same elements, but in any order.

        .. versionadded:: 1.2.0
    rtol : float, default 1e-5
        Relative tolerance. Only used when check_exact is False.

        .. versionadded:: 1.1.0
    atol : float, default 1e-8
        Absolute tolerance. Only used when check_exact is False.

        .. versionadded:: 1.1.0
    obj : str, default 'Index'
        Specify object name being compared, internally used to show appropriate
        assertion message.

    Raises
    ------
    AssertionError
        If the indexes differ in any of the checked aspects.

    Examples
    --------
    >>> from pandas import testing as tm
    >>> a = pd.Index([1, 2, 3])
    >>> b = pd.Index([1, 2, 3])
    >>> tm.assert_index_equal(a, b)
    """
    __tracebackhide__ = True

    def _check_types(left, right, obj: str = "Index") -> None:
        # Class / inferred_type / dtype checks are skipped when exact is falsy.
        if not exact:
            return

        assert_class_equal(left, right, exact=exact, obj=obj)
        assert_attr_equal("inferred_type", left, right, obj=obj)

        # Skip exact dtype checking when `check_categorical` is False
        if is_categorical_dtype(left.dtype) and is_categorical_dtype(right.dtype):
            if check_categorical:
                assert_attr_equal("dtype", left, right, obj=obj)
                assert_index_equal(left.categories, right.categories, exact=exact)
            return

        assert_attr_equal("dtype", left, right, obj=obj)

    def _get_ilevel_values(index, level):
        # Rebuild one MultiIndex level from levels/codes, keeping the level's
        # own Index type (get_level_values can change the dtype).
        unique = index.levels[level]
        level_codes = index.codes[level]
        filled = take_nd(unique._values, level_codes, fill_value=unique._na_value)
        return unique._shallow_copy(filled, name=index.names[level])

    # instance validation
    _check_isinstance(left, right, Index)

    # class / dtype comparison
    _check_types(left, right, obj=obj)

    # level comparison
    if left.nlevels != right.nlevels:
        msg1 = f"{obj} levels are different"
        msg2 = f"{left.nlevels}, {left}"
        msg3 = f"{right.nlevels}, {right}"
        raise_assert_detail(obj, msg1, msg2, msg3)

    # length comparison
    if len(left) != len(right):
        msg1 = f"{obj} length are different"
        msg2 = f"{len(left)}, {left}"
        msg3 = f"{len(right)}, {right}"
        raise_assert_detail(obj, msg1, msg2, msg3)

    # If order doesn't matter then sort the index entries
    if not check_order:
        left = safe_sort_index(left)
        right = safe_sort_index(right)

    # MultiIndex special comparison for little-friendly error messages
    if left.nlevels > 1:
        left = cast(MultiIndex, left)
        right = cast(MultiIndex, right)

        for level in range(left.nlevels):
            # cannot use get_level_values here because it can change dtype
            llevel = _get_ilevel_values(left, level)
            rlevel = _get_ilevel_values(right, level)

            lobj = f"MultiIndex level [{level}]"
            assert_index_equal(
                llevel,
                rlevel,
                exact=exact,
                check_names=check_names,
                check_exact=check_exact,
                # Forward the caller's choice so categorical levels honour
                # check_categorical=False like a flat CategoricalIndex does.
                check_categorical=check_categorical,
                rtol=rtol,
                atol=atol,
                obj=lobj,
            )
            # get_level_values may change dtype
            _check_types(left.levels[level], right.levels[level], obj=obj)

    # skip exact index checking when `check_categorical` is False
    if check_exact and check_categorical:
        if not left.equals(right):
            mismatch = left._values != right._values

            if is_extension_array_dtype(mismatch):
                mismatch = cast("ExtensionArray", mismatch).fillna(True)

            diff = np.sum(mismatch.astype(int)) * 100.0 / len(left)
            msg = f"{obj} values are different ({np.round(diff, 5)} %)"
            raise_assert_detail(obj, msg, left, right)
    else:
        # if we have "equiv", this becomes True
        exact_bool = bool(exact)
        _testing.assert_almost_equal(
            left.values,
            right.values,
            rtol=rtol,
            atol=atol,
            check_dtype=exact_bool,
            obj=obj,
            lobj=left,
            robj=right,
        )

    # metadata comparison
    if check_names:
        assert_attr_equal("names", left, right, obj=obj)
    if isinstance(left, PeriodIndex) or isinstance(right, PeriodIndex):
        assert_attr_equal("freq", left, right, obj=obj)
    if isinstance(left, IntervalIndex) or isinstance(right, IntervalIndex):
        assert_interval_array_equal(left._values, right._values)

    if check_categorical:
        if is_categorical_dtype(left.dtype) or is_categorical_dtype(right.dtype):
            assert_categorical_equal(left._values, right._values, obj=f"{obj} category")
344
345
def assert_class_equal(
    left, right, exact: bool | str = True, obj: str = "Input"
) -> None:
    """
    Checks classes are equal.

    With ``exact="equiv"`` a plain ``Index`` and any ``RangeIndex`` (or
    subclass) are accepted as the same class. The dtype is compared
    separately by callers.

    Raises
    ------
    AssertionError
        Via ``raise_assert_detail`` when the classes differ.
    """
    __tracebackhide__ = True

    if type(left) == type(right):
        return

    def _equiv_class(candidate) -> bool:
        # exactly Index, or RangeIndex / one of its subclasses
        return type(candidate) is Index or isinstance(candidate, RangeIndex)

    if exact == "equiv" and _equiv_class(left) and _equiv_class(right):
        return

    def _describe(value):
        # Index objects are shown whole so their values appear in the message.
        return value if isinstance(value, Index) else type(value).__name__

    raise_assert_detail(
        obj, f"{obj} classes are different", _describe(left), _describe(right)
    )
378
379
def assert_attr_equal(attr: str, left, right, obj: str = "Attributes") -> None:
    """
    Check attributes are equal. Both objects must have attribute.

    Parameters
    ----------
    attr : str
        Attribute name being compared.
    left : object
    right : object
    obj : str, default 'Attributes'
        Specify object name being compared, internally used to show appropriate
        assertion message

    Raises
    ------
    AssertionError
        Via ``raise_assert_detail`` when the attribute values differ.
    """
    __tracebackhide__ = True

    lhs = getattr(left, attr)
    rhs = getattr(right, attr)

    # Same object, or matching missing-value sentinels (both np.nan, both NaT,
    # both pd.NA, ...): nothing more to compare.
    if lhs is rhs or is_matching_na(lhs, rhs):
        return None

    try:
        outcome = lhs == rhs
    except TypeError:
        # datetimetz on rhs may raise TypeError
        outcome = False

    exactly_one_na = (lhs is pd.NA) ^ (rhs is pd.NA)
    if exactly_one_na:
        outcome = False
    elif not isinstance(outcome, bool):
        # array-like comparison result: every element must match
        outcome = outcome.all()

    if outcome:
        return None

    raise_assert_detail(obj, f'Attribute "{attr}" are different', lhs, rhs)
    return None
417
418
def assert_is_valid_plot_return_object(objs) -> None:
    """
    Assert that ``objs`` has the shape of a plotting call's return value.

    Parameters
    ----------
    objs : Series, numpy.ndarray, or object
        If a Series or ndarray, every element (after ``ravel``) must be a
        matplotlib ``Axes`` or a ``dict``. Otherwise ``objs`` itself must be
        a matplotlib ``Artist``, a ``tuple`` or a ``dict``.

    Raises
    ------
    AssertionError
        If a checked object has an unexpected type.
    """
    # Imported lazily so matplotlib is only needed by plotting tests.
    import matplotlib.pyplot as plt

    if isinstance(objs, (Series, np.ndarray)):
        for el in objs.ravel():
            msg = (
                "one of 'objs' is not a matplotlib Axes instance, "
                f"type encountered {repr(type(el).__name__)}"
            )
            assert isinstance(el, (plt.Axes, dict)), msg
    else:
        msg = (
            "objs is neither an ndarray of Artist instances nor a single "
            "Artist instance, tuple, or dict, 'objs' is a "
            f"{repr(type(objs).__name__)}"
        )
        assert isinstance(objs, (plt.Artist, tuple, dict)), msg
436
437
def assert_is_sorted(seq) -> None:
    """
    Assert that the sequence is sorted.

    Index and Series inputs are unwrapped to their ``.values`` first. The
    values are then compared, via ``assert_numpy_array_equal``, against
    their own ``np.sort``-ed copy.
    """
    values = seq.values if isinstance(seq, (Index, Series)) else seq
    # np.sort keeps the dtype, so the comparison does not lose precision.
    expected = np.sort(np.array(values))
    assert_numpy_array_equal(values, expected)
444
445
def assert_categorical_equal(
    left,
    right,
    check_dtype: bool = True,
    check_category_order: bool = True,
    obj: str = "Categorical",
) -> None:
    """
    Test that Categoricals are equivalent.

    Parameters
    ----------
    left : Categorical
    right : Categorical
    check_dtype : bool, default True
        Check that integer dtype of the codes are the same.
    check_category_order : bool, default True
        Whether the order of the categories should be compared, which
        implies identical integer codes. If False, only the resulting
        values are compared. The ordered attribute is
        checked regardless.
    obj : str, default 'Categorical'
        Specify object name being compared, internally used to show appropriate
        assertion message.
    """
    _check_isinstance(left, right, Categorical)

    # RangeIndex categories may match an equivalent int Index ("equiv");
    # otherwise the categories' Index class must match exactly.
    uses_range = any(
        isinstance(side.categories, RangeIndex) for side in (left, right)
    )
    exact: bool | str = "equiv" if uses_range else True

    if not check_category_order:
        try:
            left_cats = left.categories.sort_values()
            right_cats = right.categories.sort_values()
        except TypeError:
            # e.g. '<' not supported between instances of 'int' and 'str'
            left_cats, right_cats = left.categories, right.categories
        assert_index_equal(
            left_cats, right_cats, obj=f"{obj}.categories", exact=exact
        )
        # Compare the decoded values rather than the codes.
        left_values = left.categories.take(left.codes)
        right_values = right.categories.take(right.codes)
        assert_index_equal(
            left_values, right_values, obj=f"{obj}.values", exact=exact
        )
    else:
        assert_index_equal(
            left.categories, right.categories, obj=f"{obj}.categories", exact=exact
        )
        assert_numpy_array_equal(
            left.codes, right.codes, check_dtype=check_dtype, obj=f"{obj}.codes"
        )

    assert_attr_equal("ordered", left, right, obj=obj)
505
506
def assert_interval_array_equal(
    left, right, exact: bool | Literal["equiv"] = "equiv", obj: str = "IntervalArray"
) -> None:
    """
    Test that two IntervalArrays are equivalent.

    Compares the left endpoints, then the right endpoints, then ``closed``.

    Parameters
    ----------
    left, right : IntervalArray
        The IntervalArrays to compare.
    exact : bool or {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical. If 'equiv', then RangeIndex can be substituted for
        Index with an int64 dtype as well.
        NOTE(review): not currently used by this function; kept for API
        compatibility.
    obj : str, default 'IntervalArray'
        Specify object name being compared, internally used to show appropriate
        assertion message

    Raises
    ------
    AssertionError
        If either input is not an IntervalArray, or the endpoints or
        ``closed`` differ.
    """
    _check_isinstance(left, right, IntervalArray)

    kwargs = {}
    if left._left.dtype.kind in ["m", "M"]:
        # We have a DatetimeArray or TimedeltaArray
        kwargs["check_freq"] = False

    assert_equal(left._left, right._left, obj=f"{obj}.left", **kwargs)
    # Label the right endpoints as ".right" so failures point at the right side.
    assert_equal(left._right, right._right, obj=f"{obj}.right", **kwargs)

    assert_attr_equal("closed", left, right, obj=obj)
536
537
def assert_period_array_equal(left, right, obj: str = "PeriodArray") -> None:
    """
    Assert that two PeriodArrays have equal underlying data and the same freq.

    Parameters
    ----------
    left, right : PeriodArray
    obj : str, default 'PeriodArray'
        Object name used in assertion messages.
    """
    _check_isinstance(left, right, PeriodArray)

    ndarray_label = f"{obj}._ndarray"
    assert_numpy_array_equal(left._ndarray, right._ndarray, obj=ndarray_label)
    assert_attr_equal("freq", left, right, obj=obj)
543
544
def assert_datetime_array_equal(
    left, right, obj: str = "DatetimeArray", check_freq: bool = True
) -> None:
    """
    Assert that two DatetimeArrays have equal data, freq (optionally) and tz.

    Parameters
    ----------
    left, right : DatetimeArray
    obj : str, default 'DatetimeArray'
        Object name used in assertion messages.
    check_freq : bool, default True
        Whether to compare the ``freq`` attribute.
    """
    __tracebackhide__ = True
    _check_isinstance(left, right, DatetimeArray)

    assert_numpy_array_equal(left._ndarray, right._ndarray, obj=f"{obj}._ndarray")

    # freq (when requested) is checked before tz.
    attrs_to_check = ("freq", "tz") if check_freq else ("tz",)
    for attr_name in attrs_to_check:
        assert_attr_equal(attr_name, left, right, obj=obj)
555
556
def assert_timedelta_array_equal(
    left, right, obj: str = "TimedeltaArray", check_freq: bool = True
) -> None:
    """
    Assert that two TimedeltaArrays have equal data and, optionally, freq.

    Parameters
    ----------
    left, right : TimedeltaArray
    obj : str, default 'TimedeltaArray'
        Object name used in assertion messages.
    check_freq : bool, default True
        Whether to compare the ``freq`` attribute.
    """
    __tracebackhide__ = True
    _check_isinstance(left, right, TimedeltaArray)

    ndarray_label = f"{obj}._ndarray"
    assert_numpy_array_equal(left._ndarray, right._ndarray, obj=ndarray_label)

    if not check_freq:
        return
    assert_attr_equal("freq", left, right, obj=obj)
565
566
def raise_assert_detail(
    obj, message, left, right, diff=None, first_diff=None, index_values=None
):
    """
    Build a multi-line comparison report and raise it as an AssertionError.

    Parameters
    ----------
    obj : str
        Name of the compared objects (first line of the report).
    message : str
        Description of what differs.
    left, right : object
        Values shown on the ``[left]`` / ``[right]`` lines. numpy arrays are
        rendered with ``pprint_thing``; Categorical/Pandas/String dtypes with
        ``repr``.
    diff, first_diff : optional
        Extra lines appended when not None.
    index_values : numpy.ndarray, optional
        Shown on an ``[index]`` line when it is an ndarray.

    Raises
    ------
    AssertionError
        Always.
    """
    __tracebackhide__ = True

    def _render(value):
        if isinstance(value, np.ndarray):
            return pprint_thing(value)
        if isinstance(value, (CategoricalDtype, PandasDtype, StringDtype)):
            return repr(value)
        return value

    pieces = [f"{obj} are different\n\n{message}"]

    if isinstance(index_values, np.ndarray):
        pieces.append(f"\n[index]: {pprint_thing(index_values)}")

    pieces.append(f"\n[left]: {_render(left)}\n[right]: {_render(right)}")

    if diff is not None:
        pieces.append(f"\n[diff]: {diff}")

    if first_diff is not None:
        pieces.append(f"\n{first_diff}")

    raise AssertionError("".join(pieces))
600
601
def assert_numpy_array_equal(
    left,
    right,
    strict_nan: bool = False,
    check_dtype: bool | Literal["equiv"] = True,
    err_msg=None,
    check_same=None,
    obj: str = "numpy array",
    index_values=None,
) -> None:
    """
    Check that 'np.ndarray' is equivalent.

    Parameters
    ----------
    left, right : numpy.ndarray or iterable
        The two arrays to be compared.
    strict_nan : bool, default False
        If True, consider NaN and None to be different.
    check_dtype : bool, default True
        Check dtype if both a and b are np.ndarray.
    err_msg : str, default None
        If provided, used as assertion message.
    check_same : None|'copy'|'same', default None
        Ensure left and right refer/do not refer to the same memory area.
    obj : str, default 'numpy array'
        Specify object name being compared, internally used to show appropriate
        assertion message.
    index_values : numpy.ndarray, default None
        optional index (shared by both left and right), used in output.
    """
    __tracebackhide__ = True

    # Report a class mismatch with a detailed message first, then insist that
    # both sides really are ndarrays.
    assert_class_equal(left, right, obj=obj)
    _check_isinstance(left, right, np.ndarray)

    def _owner(arr):
        # The array that owns the memory: ``.base`` for views, else itself.
        base = getattr(arr, "base", None)
        return arr if base is None else arr.base

    left_owner = _owner(left)
    right_owner = _owner(right)

    if check_same == "same" and left_owner is not right_owner:
        raise AssertionError(f"{repr(left_owner)} is not {repr(right_owner)}")
    if check_same == "copy" and left_owner is right_owner:
        raise AssertionError(f"{repr(left_owner)} is {repr(right_owner)}")

    def _fail(lhs, rhs, message):
        if message is not None:
            raise AssertionError(message)

        if lhs.shape != rhs.shape:
            raise_assert_detail(
                obj, f"{obj} shapes are different", lhs.shape, rhs.shape
            )

        # Count mismatches along the first axis.
        mismatched = sum(
            1
            for lhs_item, rhs_item in zip(lhs, rhs)
            if not array_equivalent(lhs_item, rhs_item, strict_nan=strict_nan)
        )
        percent = mismatched * 100.0 / lhs.size
        raise_assert_detail(
            obj,
            f"{obj} values are different ({np.round(percent, 5)} %)",
            lhs,
            rhs,
            index_values=index_values,
        )

    # compare shape and values
    if not array_equivalent(left, right, strict_nan=strict_nan):
        _fail(left, right, err_msg)

    both_ndarrays = isinstance(left, np.ndarray) and isinstance(right, np.ndarray)
    if check_dtype and both_ndarrays:
        assert_attr_equal("dtype", left, right, obj=obj)
680
681
def assert_extension_array_equal(
    left,
    right,
    check_dtype: bool | Literal["equiv"] = True,
    index_values=None,
    check_exact: bool = False,
    rtol: float = 1.0e-5,
    atol: float = 1.0e-8,
    obj: str = "ExtensionArray",
) -> None:
    """
    Check that left and right ExtensionArrays are equal.

    Parameters
    ----------
    left, right : ExtensionArray
        The two arrays to compare.
    check_dtype : bool, default True
        Whether to check if the ExtensionArray dtypes are identical.
    index_values : numpy.ndarray, default None
        Optional index (shared by both left and right), used in output.
    check_exact : bool, default False
        Whether to compare number exactly.
    rtol : float, default 1e-5
        Relative tolerance. Only used when check_exact is False.

        .. versionadded:: 1.1.0
    atol : float, default 1e-8
        Absolute tolerance. Only used when check_exact is False.

        .. versionadded:: 1.1.0
    obj : str, default 'ExtensionArray'
        Specify object name being compared, internally used to show appropriate
        assertion message.

        .. versionadded:: 2.0.0

    Notes
    -----
    Missing values are checked separately from valid values.
    A mask of missing values is computed for each and checked to match.
    The remaining all-valid values are cast to object dtype and checked.

    For datetime-like arrays of the same class compared with
    ``check_dtype=False``, arrays whose units differ but whose values are
    equal are accepted.

    Examples
    --------
    >>> from pandas import testing as tm
    >>> a = pd.Series([1, 2, 3, 4])
    >>> b, c = a.array, a.array
    >>> tm.assert_extension_array_equal(b, c)
    """
    assert isinstance(left, ExtensionArray), "left is not an ExtensionArray"
    assert isinstance(right, ExtensionArray), "right is not an ExtensionArray"
    if check_dtype:
        assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}")

    if (
        isinstance(left, DatetimeLikeArrayMixin)
        and isinstance(right, DatetimeLikeArrayMixin)
        and type(right) == type(left)
    ):
        # GH 52449
        if not check_dtype and left.dtype.kind in "mM":
            # tz-aware dtypes expose ``.unit``; numpy datetime64/timedelta64
            # dtypes need np.datetime_data.
            if not isinstance(left.dtype, np.dtype):
                l_unit = cast(DatetimeTZDtype, left.dtype).unit
            else:
                l_unit = np.datetime_data(left.dtype)[0]
            if not isinstance(right.dtype, np.dtype):
                # Must read the unit from *right*'s dtype; reading left's
                # here made every tz-aware pair look like the same unit.
                r_unit = cast(DatetimeTZDtype, right.dtype).unit
            else:
                r_unit = np.datetime_data(right.dtype)[0]
            if (
                l_unit != r_unit
                and compare_mismatched_resolutions(
                    left._ndarray, right._ndarray, operator.eq
                ).all()
            ):
                return
        # Avoid slow object-dtype comparisons
        # np.asarray for case where we have a np.MaskedArray
        assert_numpy_array_equal(
            np.asarray(left.asi8),
            np.asarray(right.asi8),
            index_values=index_values,
            obj=obj,
        )
        return

    left_na = np.asarray(left.isna())
    right_na = np.asarray(right.isna())
    assert_numpy_array_equal(
        left_na, right_na, obj=f"{obj} NA mask", index_values=index_values
    )

    left_valid = left[~left_na].to_numpy(dtype=object)
    right_valid = right[~right_na].to_numpy(dtype=object)
    if check_exact:
        assert_numpy_array_equal(
            left_valid, right_valid, obj=obj, index_values=index_values
        )
    else:
        _testing.assert_almost_equal(
            left_valid,
            right_valid,
            check_dtype=bool(check_dtype),
            rtol=rtol,
            atol=atol,
            obj=obj,
            index_values=index_values,
        )
791
792
793# This could be refactored to use the NDFrame.equals method
def assert_series_equal(
    left,
    right,
    check_dtype: bool | Literal["equiv"] = True,
    check_index_type: bool | Literal["equiv"] = "equiv",
    check_series_type: bool = True,
    check_names: bool = True,
    check_exact: bool = False,
    check_datetimelike_compat: bool = False,
    check_categorical: bool = True,
    check_category_order: bool = True,
    check_freq: bool = True,
    check_flags: bool = True,
    rtol: float = 1.0e-5,
    atol: float = 1.0e-8,
    obj: str = "Series",
    *,
    check_index: bool = True,
    check_like: bool = False,
) -> None:
    """
    Check that left and right Series are equal.

    Parameters
    ----------
    left : Series
    right : Series
    check_dtype : bool, default True
        Whether to check the Series dtype is identical.
    check_index_type : bool or {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical.
    check_series_type : bool, default True
        Whether to check the Series class is identical.
    check_names : bool, default True
        Whether to check the Series and Index names attribute.
    check_exact : bool, default False
        Whether to compare number exactly. For numeric dtypes this applies to
        numpy-backed and extension-array-backed values alike.
    check_datetimelike_compat : bool, default False
        Compare datetime-like which is comparable ignoring dtype.
    check_categorical : bool, default True
        Whether to compare internal Categorical exactly.
    check_category_order : bool, default True
        Whether to compare category order of internal Categoricals.

        .. versionadded:: 1.0.2
    check_freq : bool, default True
        Whether to check the `freq` attribute on a DatetimeIndex or TimedeltaIndex.

        .. versionadded:: 1.1.0
    check_flags : bool, default True
        Whether to check the `flags` attribute.

        .. versionadded:: 1.2.0

    rtol : float, default 1e-5
        Relative tolerance. Only used when check_exact is False.

        .. versionadded:: 1.1.0
    atol : float, default 1e-8
        Absolute tolerance. Only used when check_exact is False.

        .. versionadded:: 1.1.0
    obj : str, default 'Series'
        Specify object name being compared, internally used to show appropriate
        assertion message.
    check_index : bool, default True
        Whether to check index equivalence. If False, then compare only values.

        .. versionadded:: 1.3.0
    check_like : bool, default False
        If True, ignore the order of the index. Must be False if check_index is False.
        Note: same labels must be with the same data.

        .. versionadded:: 1.5.0

    Raises
    ------
    ValueError
        If ``check_like`` is True while ``check_index`` is False.
    AssertionError
        If the Series differ in any of the checked aspects.

    Examples
    --------
    >>> from pandas import testing as tm
    >>> a = pd.Series([1, 2, 3, 4])
    >>> b = pd.Series([1, 2, 3, 4])
    >>> tm.assert_series_equal(a, b)
    """
    __tracebackhide__ = True

    if not check_index and check_like:
        raise ValueError("check_like must be False if check_index is False")

    # instance validation
    _check_isinstance(left, right, Series)

    if check_series_type:
        assert_class_equal(left, right, obj=obj)

    # length comparison
    if len(left) != len(right):
        msg1 = f"{len(left)}, {left.index}"
        msg2 = f"{len(right)}, {right.index}"
        raise_assert_detail(obj, "Series length are different", msg1, msg2)

    if check_flags:
        assert left.flags == right.flags, f"{repr(left.flags)} != {repr(right.flags)}"

    if check_index:
        # GH #38183
        assert_index_equal(
            left.index,
            right.index,
            exact=check_index_type,
            check_names=check_names,
            check_exact=check_exact,
            check_categorical=check_categorical,
            check_order=not check_like,
            rtol=rtol,
            atol=atol,
            obj=f"{obj}.index",
        )

    if check_like:
        left = left.reindex_like(right)

    if check_freq and isinstance(left.index, (DatetimeIndex, TimedeltaIndex)):
        lidx = left.index
        ridx = right.index
        assert lidx.freq == ridx.freq, (lidx.freq, ridx.freq)

    if check_dtype:
        # We want to skip exact dtype checking when `check_categorical`
        # is False. We'll still raise if only one is a `Categorical`,
        # regardless of `check_categorical`
        if (
            isinstance(left.dtype, CategoricalDtype)
            and isinstance(right.dtype, CategoricalDtype)
            and not check_categorical
        ):
            pass
        else:
            assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}")

    if check_exact and is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype):
        left_values = left._values
        right_values = right._values
        # Only check exact if dtype is numeric
        if isinstance(left_values, ExtensionArray) and isinstance(
            right_values, ExtensionArray
        ):
            assert_extension_array_equal(
                left_values,
                right_values,
                check_dtype=check_dtype,
                # Without this the EA helper falls back to its default
                # tolerance-based comparison despite check_exact=True.
                check_exact=True,
                index_values=np.asarray(left.index),
                obj=str(obj),
            )
        else:
            assert_numpy_array_equal(
                left_values,
                right_values,
                check_dtype=check_dtype,
                obj=str(obj),
                index_values=np.asarray(left.index),
            )
    elif check_datetimelike_compat and (
        needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype)
    ):
        # we want to check only if we have compat dtypes
        # e.g. integer and M|m are NOT compat, but we can simply check
        # the values in that case

        # datetimelike may have different objects (e.g. datetime.datetime
        # vs Timestamp) but will compare equal
        if not Index(left._values).equals(Index(right._values)):
            msg = (
                f"[datetimelike_compat=True] {left._values} "
                f"is not equal to {right._values}."
            )
            raise AssertionError(msg)
    elif is_interval_dtype(left.dtype) and is_interval_dtype(right.dtype):
        assert_interval_array_equal(left.array, right.array)
    elif isinstance(left.dtype, CategoricalDtype) or isinstance(
        right.dtype, CategoricalDtype
    ):
        _testing.assert_almost_equal(
            left._values,
            right._values,
            rtol=rtol,
            atol=atol,
            check_dtype=bool(check_dtype),
            obj=str(obj),
            index_values=np.asarray(left.index),
        )
    elif is_extension_array_dtype(left.dtype) and is_extension_array_dtype(right.dtype):
        assert_extension_array_equal(
            left._values,
            right._values,
            rtol=rtol,
            atol=atol,
            check_dtype=check_dtype,
            index_values=np.asarray(left.index),
            obj=str(obj),
        )
    elif is_extension_array_dtype_and_needs_i8_conversion(
        left.dtype, right.dtype
    ) or is_extension_array_dtype_and_needs_i8_conversion(right.dtype, left.dtype):
        assert_extension_array_equal(
            left._values,
            right._values,
            check_dtype=check_dtype,
            index_values=np.asarray(left.index),
            obj=str(obj),
        )
    elif needs_i8_conversion(left.dtype) and needs_i8_conversion(right.dtype):
        # DatetimeArray or TimedeltaArray
        assert_extension_array_equal(
            left._values,
            right._values,
            check_dtype=check_dtype,
            index_values=np.asarray(left.index),
            obj=str(obj),
        )
    else:
        _testing.assert_almost_equal(
            left._values,
            right._values,
            rtol=rtol,
            atol=atol,
            check_dtype=bool(check_dtype),
            obj=str(obj),
            index_values=np.asarray(left.index),
        )

    # metadata comparison
    if check_names:
        assert_attr_equal("name", left, right, obj=obj)

    if check_categorical:
        if isinstance(left.dtype, CategoricalDtype) or isinstance(
            right.dtype, CategoricalDtype
        ):
            assert_categorical_equal(
                left._values,
                right._values,
                obj=f"{obj} category",
                check_category_order=check_category_order,
            )
1038
1039
# This could be refactored to use the NDFrame.equals method
def assert_frame_equal(
    left,
    right,
    check_dtype: bool | Literal["equiv"] = True,
    check_index_type: bool | Literal["equiv"] = "equiv",
    check_column_type: bool | Literal["equiv"] = "equiv",
    check_frame_type: bool = True,
    check_names: bool = True,
    by_blocks: bool = False,
    check_exact: bool = False,
    check_datetimelike_compat: bool = False,
    check_categorical: bool = True,
    check_like: bool = False,
    check_freq: bool = True,
    check_flags: bool = True,
    rtol: float = 1.0e-5,
    atol: float = 1.0e-8,
    obj: str = "DataFrame",
) -> None:
    """
    Check that left and right DataFrame are equal.

    This function is intended to compare two DataFrames and output any
    differences. It is mostly intended for use in unit tests.
    Additional parameters allow varying the strictness of the
    equality checks performed.

    Parameters
    ----------
    left : DataFrame
        First DataFrame to compare.
    right : DataFrame
        Second DataFrame to compare.
    check_dtype : bool or {'equiv'}, default True
        Whether to check the DataFrame dtype is identical. Forwarded to the
        per-column (or per-block) comparison.
    check_index_type : bool or {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical.
    check_column_type : bool or {'equiv'}, default 'equiv'
        Whether to check the columns class, dtype and inferred_type
        are identical. Is passed as the ``exact`` argument of
        :func:`assert_index_equal`.
    check_frame_type : bool, default True
        Whether to check the DataFrame class is identical.
    check_names : bool, default True
        Whether to check that the `names` attribute for both the `index`
        and `column` attributes of the DataFrame is identical.
    by_blocks : bool, default False
        Specify how to compare internal data. If False, compare by columns.
        If True, compare by blocks (one sub-frame per dtype); the other
        comparison options given here are applied to every block.
    check_exact : bool, default False
        Whether to compare numbers exactly.
    check_datetimelike_compat : bool, default False
        Compare datetime-like which is comparable ignoring dtype.
    check_categorical : bool, default True
        Whether to compare internal Categorical exactly.
    check_like : bool, default False
        If True, ignore the order of index & columns.
        Note: index labels must match their respective rows
        (same as in columns) - same labels must be with the same data.
    check_freq : bool, default True
        Whether to check the `freq` attribute on a DatetimeIndex or TimedeltaIndex.

        .. versionadded:: 1.1.0
    check_flags : bool, default True
        Whether to check the `flags` attribute.
    rtol : float, default 1e-5
        Relative tolerance. Only used when check_exact is False.

        .. versionadded:: 1.1.0
    atol : float, default 1e-8
        Absolute tolerance. Only used when check_exact is False.

        .. versionadded:: 1.1.0
    obj : str, default 'DataFrame'
        Specify object name being compared, internally used to show appropriate
        assertion message.

    See Also
    --------
    assert_series_equal : Equivalent method for asserting Series equality.
    DataFrame.equals : Check DataFrame equality.

    Examples
    --------
    This example shows comparing two DataFrames that are equal
    but with columns of differing dtypes.

    >>> from pandas.testing import assert_frame_equal
    >>> df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
    >>> df2 = pd.DataFrame({'a': [1, 2], 'b': [3.0, 4.0]})

    df1 equals itself.

    >>> assert_frame_equal(df1, df1)

    df1 differs from df2 as column 'b' is of a different type.

    >>> assert_frame_equal(df1, df2)
    Traceback (most recent call last):
    ...
    AssertionError: Attributes of DataFrame.iloc[:, 1] (column name="b") are different

    Attribute "dtype" are different
    [left]: int64
    [right]: float64

    Ignore differing dtypes in columns with check_dtype.

    >>> assert_frame_equal(df1, df2, check_dtype=False)
    """
    __tracebackhide__ = True

    # instance validation
    _check_isinstance(left, right, DataFrame)

    if check_frame_type:
        assert isinstance(left, type(right))

    # shape comparison
    if left.shape != right.shape:
        raise_assert_detail(
            obj, f"{obj} shape mismatch", f"{repr(left.shape)}", f"{repr(right.shape)}"
        )

    if check_flags:
        assert left.flags == right.flags, f"{repr(left.flags)} != {repr(right.flags)}"

    # index comparison
    assert_index_equal(
        left.index,
        right.index,
        exact=check_index_type,
        check_names=check_names,
        check_exact=check_exact,
        check_categorical=check_categorical,
        check_order=not check_like,
        rtol=rtol,
        atol=atol,
        obj=f"{obj}.index",
    )

    # column comparison
    assert_index_equal(
        left.columns,
        right.columns,
        exact=check_column_type,
        check_names=check_names,
        check_exact=check_exact,
        check_categorical=check_categorical,
        check_order=not check_like,
        rtol=rtol,
        atol=atol,
        obj=f"{obj}.columns",
    )

    if check_like:
        # Labels were compared order-insensitively above; align left to
        # right so the positional comparisons below pair matching labels.
        left = left.reindex_like(right)

    # compare by blocks
    if by_blocks:
        rblocks = right._to_dict_of_blocks()
        lblocks = left._to_dict_of_blocks()
        for dtype in lblocks.keys() | rblocks.keys():
            assert dtype in lblocks, f"{obj}: block dtype {dtype!r} missing on left"
            assert dtype in rblocks, f"{obj}: block dtype {dtype!r} missing on right"
            # Forward the caller's strictness/tolerance options so that
            # by_blocks=True honours them just like the per-column path.
            # Flags were already compared above (when requested).
            assert_frame_equal(
                lblocks[dtype],
                rblocks[dtype],
                check_dtype=check_dtype,
                check_index_type=check_index_type,
                check_column_type=check_column_type,
                check_frame_type=check_frame_type,
                check_names=check_names,
                check_exact=check_exact,
                check_datetimelike_compat=check_datetimelike_compat,
                check_categorical=check_categorical,
                check_freq=check_freq,
                check_flags=False,
                rtol=rtol,
                atol=atol,
                obj=obj,
            )

    # compare by columns
    else:
        for i, col in enumerate(left.columns):
            # We have already checked that columns match, so we can do
            # fast location-based lookups
            lcol = left._ixs(i, axis=1)
            rcol = right._ixs(i, axis=1)

            # GH #38183
            # use check_index=False, because we do not want to run
            # assert_index_equal for each column,
            # as we already checked it for the whole dataframe before.
            assert_series_equal(
                lcol,
                rcol,
                check_dtype=check_dtype,
                check_index_type=check_index_type,
                check_exact=check_exact,
                check_names=check_names,
                check_datetimelike_compat=check_datetimelike_compat,
                check_categorical=check_categorical,
                check_freq=check_freq,
                obj=f'{obj}.iloc[:, {i}] (column name="{col}")',
                rtol=rtol,
                atol=atol,
                check_index=False,
                check_flags=False,
            )
1240
1241
def assert_equal(left, right, **kwargs) -> None:
    """
    Dispatch to the ``assert_*_equal`` helper that matches the type of ``left``.

    Parameters
    ----------
    left, right : Index, Series, DataFrame, ExtensionArray, or np.ndarray
        The objects to compare. Any other type (including ``str``) is
        compared directly and then accepts no keyword arguments.
    **kwargs
        Forwarded verbatim to the selected type-specific assertion.
    """
    __tracebackhide__ = True

    # Checked in order and the first isinstance() hit wins, so the concrete
    # ExtensionArray subclasses must stay ahead of ExtensionArray itself.
    handlers = (
        (Index, assert_index_equal),
        (Series, assert_series_equal),
        (DataFrame, assert_frame_equal),
        (IntervalArray, assert_interval_array_equal),
        (PeriodArray, assert_period_array_equal),
        (DatetimeArray, assert_datetime_array_equal),
        (TimedeltaArray, assert_timedelta_array_equal),
        (ExtensionArray, assert_extension_array_equal),
        (np.ndarray, assert_numpy_array_equal),
    )
    for klass, checker in handlers:
        if not isinstance(left, klass):
            continue
        checker(left, right, **kwargs)
        # Only reachable via the Index entry: additionally pin the freq.
        if isinstance(left, (DatetimeIndex, TimedeltaIndex)):
            assert left.freq == right.freq, (left.freq, right.freq)
        return

    # Scalars / strings: no options are meaningful here.
    assert kwargs == {}
    if isinstance(left, str):
        assert left == right
    else:
        assert_almost_equal(left, right)
1281
1282
def assert_sp_array_equal(left, right) -> None:
    """
    Check that the left and right SparseArray are equal.

    The comparison covers, in order: the stored sparse values
    (``sp_values``), the sparse index (``sp_index``), ``fill_value``,
    ``dtype`` and finally the fully densified arrays.

    Parameters
    ----------
    left : SparseArray
    right : SparseArray
    """
    __tracebackhide__ = True

    _check_isinstance(left, right, pd.arrays.SparseArray)

    assert_numpy_array_equal(left.sp_values, right.sp_values)

    # SparseIndex comparison
    assert isinstance(left.sp_index, SparseIndex)
    assert isinstance(right.sp_index, SparseIndex)

    left_index = left.sp_index
    right_index = right.sp_index

    if not left_index.equals(right_index):
        raise_assert_detail(
            "SparseArray.index", "index are not equal", left_index, right_index
        )

    assert_attr_equal("fill_value", left, right)
    assert_attr_equal("dtype", left, right)
    assert_numpy_array_equal(left.to_dense(), right.to_dense())
1314
1315
def assert_contains_all(iterable, dic) -> None:
    """
    Assert that every element produced by ``iterable`` is ``in`` ``dic``.

    The first missing element triggers an AssertionError naming it.
    """
    for item in iterable:
        assert item in dic, f"Did not contain item: {item!r}"
1319
1320
def assert_copy(iter1, iter2, **eql_kwargs) -> None:
    """
    Check that paired elements are equal but are not the same object.

    Parameters
    ----------
    iter1, iter2 : iterable
        Iterables producing elements comparable with assert_almost_equal.
        They are paired with ``zip``, so iteration stops at the shorter one.
    **eql_kwargs
        Forwarded to assert_almost_equal for each pair.

    Notes
    -----
    Only the top-level elements are checked for identity; items nested inside
    those elements may still be shared.
    """
    for first, second in zip(iter1, iter2):
        assert_almost_equal(first, second, **eql_kwargs)
        assert first is not second, (
            f"Expected object {type(first)!r} and object {type(second)!r} to be "
            "different objects, but they were the same object."
        )
1337
1338
def is_extension_array_dtype_and_needs_i8_conversion(left_dtype, right_dtype) -> bool:
    """
    Return True when ``left_dtype`` is an extension dtype and ``right_dtype``
    is one that ``needs_i8_conversion`` (i.e. is stored as int64 internally).

    Returns
    -------
    bool

    Related to issue #37609
    """
    if not is_extension_array_dtype(left_dtype):
        return False
    return needs_i8_conversion(right_dtype)
1351
1352
def assert_indexing_slices_equivalent(ser: Series, l_slc: slice, i_slc: slice) -> None:
    """
    Check that label-based and positional slicing select the same rows.

    ``ser.loc[l_slc]`` must equal ``ser.iloc[i_slc]``; when the index is not
    integer-typed, ``ser[l_slc]`` must equal it as well.

    Parameters
    ----------
    ser : Series
    l_slc : slice
        Label slice, applied via ``.loc`` (and plain ``[]``).
    i_slc : slice
        Positional slice, applied via ``.iloc``.
    """
    by_position = ser.iloc[i_slc]
    assert_series_equal(ser.loc[l_slc], by_position)

    # On an integer index, ``.loc`` treats slice bounds as labels while plain
    # ``ser[slice]`` may treat them as positions, so the two need not agree;
    # skip the getitem comparison in that case.
    if is_integer_dtype(ser.index):
        return
    assert_series_equal(ser[l_slc], by_position)
1365
1366
def assert_metadata_equivalent(
    left: DataFrame | Series, right: DataFrame | Series | None = None
) -> None:
    """
    Check that every attribute named in ``left._metadata`` matches.

    Missing attributes are treated as None. When ``right`` is None, each
    metadata attribute of ``left`` must itself be None (identity check);
    otherwise it must compare ``==`` to the same attribute on ``right``.
    """
    names = left._metadata
    if right is None:
        for name in names:
            assert getattr(left, name, None) is None
    else:
        for name in names:
            assert getattr(left, name, None) == getattr(right, name, None)