Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/pandas/_testing/__init__.py: 51%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

195 statements  

1from __future__ import annotations 

2 

3from decimal import Decimal 

4import operator 

5import os 

6from sys import byteorder 

7from typing import ( 

8 TYPE_CHECKING, 

9 Callable, 

10 ContextManager, 

11 cast, 

12) 

13import warnings 

14 

15import numpy as np 

16 

17from pandas._config.localization import ( 

18 can_set_locale, 

19 get_locales, 

20 set_locale, 

21) 

22 

23from pandas.compat import pa_version_under10p1 

24 

25from pandas.core.dtypes.common import is_string_dtype 

26 

27import pandas as pd 

28from pandas import ( 

29 ArrowDtype, 

30 DataFrame, 

31 Index, 

32 MultiIndex, 

33 RangeIndex, 

34 Series, 

35) 

36from pandas._testing._io import ( 

37 round_trip_localpath, 

38 round_trip_pathlib, 

39 round_trip_pickle, 

40 write_to_compressed, 

41) 

42from pandas._testing._warnings import ( 

43 assert_produces_warning, 

44 maybe_produces_warning, 

45) 

46from pandas._testing.asserters import ( 

47 assert_almost_equal, 

48 assert_attr_equal, 

49 assert_categorical_equal, 

50 assert_class_equal, 

51 assert_contains_all, 

52 assert_copy, 

53 assert_datetime_array_equal, 

54 assert_dict_equal, 

55 assert_equal, 

56 assert_extension_array_equal, 

57 assert_frame_equal, 

58 assert_index_equal, 

59 assert_indexing_slices_equivalent, 

60 assert_interval_array_equal, 

61 assert_is_sorted, 

62 assert_is_valid_plot_return_object, 

63 assert_metadata_equivalent, 

64 assert_numpy_array_equal, 

65 assert_period_array_equal, 

66 assert_series_equal, 

67 assert_sp_array_equal, 

68 assert_timedelta_array_equal, 

69 raise_assert_detail, 

70) 

71from pandas._testing.compat import ( 

72 get_dtype, 

73 get_obj, 

74) 

75from pandas._testing.contexts import ( 

76 assert_cow_warning, 

77 decompress_file, 

78 ensure_clean, 

79 raises_chained_assignment_error, 

80 set_timezone, 

81 use_numexpr, 

82 with_csv_dialect, 

83) 

84from pandas.core.arrays import ( 

85 BaseMaskedArray, 

86 ExtensionArray, 

87 NumpyExtensionArray, 

88) 

89from pandas.core.arrays._mixins import NDArrayBackedExtensionArray 

90from pandas.core.construction import extract_array 

91 

92if TYPE_CHECKING: 

93 from pandas._typing import ( 

94 Dtype, 

95 NpDtype, 

96 ) 

97 

98 from pandas.core.arrays import ArrowExtensionArray 

99 

# Integer dtypes, as numpy dtype names and as masked extension-dtype names.
UNSIGNED_INT_NUMPY_DTYPES: list[NpDtype] = [
    f"uint{bits}" for bits in (8, 16, 32, 64)
]
UNSIGNED_INT_EA_DTYPES: list[Dtype] = [f"UInt{bits}" for bits in (8, 16, 32, 64)]
SIGNED_INT_NUMPY_DTYPES: list[NpDtype] = [
    int,
    *(f"int{bits}" for bits in (8, 16, 32, 64)),
]
SIGNED_INT_EA_DTYPES: list[Dtype] = [f"Int{bits}" for bits in (8, 16, 32, 64)]
ALL_INT_NUMPY_DTYPES = [*UNSIGNED_INT_NUMPY_DTYPES, *SIGNED_INT_NUMPY_DTYPES]
ALL_INT_EA_DTYPES = [*UNSIGNED_INT_EA_DTYPES, *SIGNED_INT_EA_DTYPES]
ALL_INT_DTYPES: list[Dtype] = ALL_INT_NUMPY_DTYPES + ALL_INT_EA_DTYPES

# Floating point dtypes.
FLOAT_NUMPY_DTYPES: list[NpDtype] = [float, "float32", "float64"]
FLOAT_EA_DTYPES: list[Dtype] = ["Float32", "Float64"]
ALL_FLOAT_DTYPES: list[Dtype] = FLOAT_NUMPY_DTYPES + FLOAT_EA_DTYPES

# Complex and string dtypes.
COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
STRING_DTYPES: list[Dtype] = [str, "str", "U"]
COMPLEX_FLOAT_DTYPES: list[Dtype] = COMPLEX_DTYPES + FLOAT_NUMPY_DTYPES

# Temporal dtypes (nanosecond resolution only).
DATETIME64_DTYPES: list[Dtype] = ["datetime64[ns]", "M8[ns]"]
TIMEDELTA64_DTYPES: list[Dtype] = ["timedelta64[ns]", "m8[ns]"]

# Remaining scalar dtypes, each as the Python type and its string alias.
BOOL_DTYPES: list[Dtype] = [bool, "bool"]
BYTES_DTYPES: list[Dtype] = [bytes, "bytes"]
OBJECT_DTYPES: list[Dtype] = [object, "object"]

# Aggregates built from the families above (order matters for parametrization).
ALL_REAL_NUMPY_DTYPES = [*FLOAT_NUMPY_DTYPES, *ALL_INT_NUMPY_DTYPES]
ALL_REAL_EXTENSION_DTYPES = [*FLOAT_EA_DTYPES, *ALL_INT_EA_DTYPES]
ALL_REAL_DTYPES: list[Dtype] = ALL_REAL_NUMPY_DTYPES + ALL_REAL_EXTENSION_DTYPES
ALL_NUMERIC_DTYPES: list[Dtype] = ALL_REAL_DTYPES + COMPLEX_DTYPES

ALL_NUMPY_DTYPES = [
    *ALL_REAL_NUMPY_DTYPES,
    *COMPLEX_DTYPES,
    *STRING_DTYPES,
    *DATETIME64_DTYPES,
    *TIMEDELTA64_DTYPES,
    *BOOL_DTYPES,
    *OBJECT_DTYPES,
    *BYTES_DTYPES,
]

138 

# Numpy scalar types narrower than 64 bits.
NARROW_NP_DTYPES = [
    # floating point
    np.float16,
    np.float32,
    # signed integers
    np.int8,
    np.int16,
    np.int32,
    # unsigned integers
    np.uint8,
    np.uint16,
    np.uint32,
]

# Builtin Python types: scalars, containers, then binary/buffer types.
PYTHON_DATA_TYPES = [
    str,
    int,
    float,
    complex,
    list,
    tuple,
    range,
    dict,
    set,
    frozenset,
    bool,
    bytes,
    bytearray,
    memoryview,
]

# numpy byte-order character for this machine: "<" little-endian, ">" big-endian.
ENDIAN = "<" if byteorder == "little" else ">"

168 

# Scalars used to represent missing values.
NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")]

# numpy "Not a Time" scalars: every datetime64 unit listed below, then every
# timedelta64 unit, in the same unit order.
NP_NAT_OBJECTS = [
    nat_cls("NaT", unit)
    for nat_cls in (np.datetime64, np.timedelta64)
    for unit in (
        "Y", "M", "W", "D",  # calendar units
        "h", "m", "s",  # clock units
        "ms", "us", "ns", "ps", "fs", "as",  # sub-second units
    )
]

189 

190if not pa_version_under10p1: 

191 import pyarrow as pa 

192 

193 UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()] 

194 SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()] 

195 ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES 

196 ALL_INT_PYARROW_DTYPES_STR_REPR = [ 

197 str(ArrowDtype(typ)) for typ in ALL_INT_PYARROW_DTYPES 

198 ] 

199 

200 # pa.float16 doesn't seem supported 

201 # https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86 

202 FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()] 

203 FLOAT_PYARROW_DTYPES_STR_REPR = [ 

204 str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES 

205 ] 

206 DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)] 

207 STRING_PYARROW_DTYPES = [pa.string()] 

208 BINARY_PYARROW_DTYPES = [pa.binary()] 

209 

210 TIME_PYARROW_DTYPES = [ 

211 pa.time32("s"), 

212 pa.time32("ms"), 

213 pa.time64("us"), 

214 pa.time64("ns"), 

215 ] 

216 DATE_PYARROW_DTYPES = [pa.date32(), pa.date64()] 

217 DATETIME_PYARROW_DTYPES = [ 

218 pa.timestamp(unit=unit, tz=tz) 

219 for unit in ["s", "ms", "us", "ns"] 

220 for tz in [None, "UTC", "US/Pacific", "US/Eastern"] 

221 ] 

222 TIMEDELTA_PYARROW_DTYPES = [pa.duration(unit) for unit in ["s", "ms", "us", "ns"]] 

223 

224 BOOL_PYARROW_DTYPES = [pa.bool_()] 

225 

226 # TODO: Add container like pyarrow types: 

227 # https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions 

228 ALL_PYARROW_DTYPES = ( 

229 ALL_INT_PYARROW_DTYPES 

230 + FLOAT_PYARROW_DTYPES 

231 + DECIMAL_PYARROW_DTYPES 

232 + STRING_PYARROW_DTYPES 

233 + BINARY_PYARROW_DTYPES 

234 + TIME_PYARROW_DTYPES 

235 + DATE_PYARROW_DTYPES 

236 + DATETIME_PYARROW_DTYPES 

237 + TIMEDELTA_PYARROW_DTYPES 

238 + BOOL_PYARROW_DTYPES 

239 ) 

240 ALL_REAL_PYARROW_DTYPES_STR_REPR = ( 

241 ALL_INT_PYARROW_DTYPES_STR_REPR + FLOAT_PYARROW_DTYPES_STR_REPR 

242 ) 

243else: 

244 FLOAT_PYARROW_DTYPES_STR_REPR = [] 

245 ALL_INT_PYARROW_DTYPES_STR_REPR = [] 

246 ALL_PYARROW_DTYPES = [] 

247 ALL_REAL_PYARROW_DTYPES_STR_REPR = [] 

248 

# Real-valued dtypes grouped as "nullable": numpy floats, the masked
# Float/Int extension dtypes, and the pyarrow int/float dtype strings (the
# pyarrow part is empty when pyarrow is unavailable).
ALL_REAL_NULLABLE_DTYPES = [
    *FLOAT_NUMPY_DTYPES,
    *ALL_REAL_EXTENSION_DTYPES,
    *ALL_REAL_PYARROW_DTYPES_STR_REPR,
]

252 

# Binary arithmetic dunders, each immediately followed by its reflected
# ("r"-prefixed) counterpart.
arithmetic_dunder_methods = [
    f"__{prefix}{op}__"
    for op in ("add", "sub", "mul", "floordiv", "truediv", "pow", "mod")
    for prefix in ("", "r")
]

# Rich-comparison dunders.
comparison_dunder_methods = ["__eq__", "__ne__", "__le__", "__lt__", "__ge__", "__gt__"]

271 

272 

273# ----------------------------------------------------------------------------- 

274# Comparators 

275 

276 

277def box_expected(expected, box_cls, transpose: bool = True): 

278 """ 

279 Helper function to wrap the expected output of a test in a given box_class. 

280 

281 Parameters 

282 ---------- 

283 expected : np.ndarray, Index, Series 

284 box_cls : {Index, Series, DataFrame} 

285 

286 Returns 

287 ------- 

288 subclass of box_cls 

289 """ 

290 if box_cls is pd.array: 

291 if isinstance(expected, RangeIndex): 

292 # pd.array would return an IntegerArray 

293 expected = NumpyExtensionArray(np.asarray(expected._values)) 

294 else: 

295 expected = pd.array(expected, copy=False) 

296 elif box_cls is Index: 

297 with warnings.catch_warnings(): 

298 warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning) 

299 expected = Index(expected) 

300 elif box_cls is Series: 

301 with warnings.catch_warnings(): 

302 warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning) 

303 expected = Series(expected) 

304 elif box_cls is DataFrame: 

305 with warnings.catch_warnings(): 

306 warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning) 

307 expected = Series(expected).to_frame() 

308 if transpose: 

309 # for vector operations, we need a DataFrame to be a single-row, 

310 # not a single-column, in order to operate against non-DataFrame 

311 # vectors of the same length. But convert to two rows to avoid 

312 # single-row special cases in datetime arithmetic 

313 expected = expected.T 

314 expected = pd.concat([expected] * 2, ignore_index=True) 

315 elif box_cls is np.ndarray or box_cls is np.array: 

316 expected = np.array(expected) 

317 elif box_cls is to_array: 

318 expected = to_array(expected) 

319 else: 

320 raise NotImplementedError(box_cls) 

321 return expected 

322 

323 

def to_array(obj):
    """
    Convert ``obj`` to an array like ``pd.array``, but without casting numpy
    dtypes to the nullable extension dtypes.

    Objects with no ``dtype`` attribute (or a ``dtype`` of None) go through
    ``np.asarray``; everything else is unwrapped with
    ``extract_array(obj, extract_numpy=True)``.
    """
    # temporary implementation until we get pd.array in place
    if getattr(obj, "dtype", None) is not None:
        return extract_array(obj, extract_numpy=True)
    return np.asarray(obj)

335 

336 

class SubclassedSeries(Series):
    """Series subclass used in tests; declares ``testattr`` and ``name`` as metadata."""

    _metadata = ["testattr", "name"]

    @property
    def _constructor(self):
        # For testing, those properties return a generic callable, and not
        # the actual class. In this case that is equivalent, but it is to
        # ensure we don't rely on the property returning a class
        # See https://github.com/pandas-dev/pandas/pull/46018 and
        # https://github.com/pandas-dev/pandas/issues/32638 and linked issues
        def make_series(*args, **kwargs):
            return SubclassedSeries(*args, **kwargs)

        return make_series

    @property
    def _constructor_expanddim(self):
        def make_frame(*args, **kwargs):
            return SubclassedDataFrame(*args, **kwargs)

        return make_frame


class SubclassedDataFrame(DataFrame):
    """DataFrame subclass used in tests; declares ``testattr`` as metadata."""

    _metadata = ["testattr"]

    @property
    def _constructor(self):
        # Plain callable rather than the class itself (see SubclassedSeries).
        def make_frame(*args, **kwargs):
            return SubclassedDataFrame(*args, **kwargs)

        return make_frame

    @property
    def _constructor_sliced(self):
        def make_series(*args, **kwargs):
            return SubclassedSeries(*args, **kwargs)

        return make_series

364 

365 

def convert_rows_list_to_csv_str(rows_list: list[str]) -> str:
    """
    Join CSV rows into one string using the current OS line separator.

    Used to build the expected value of ``to_csv()``.

    Parameters
    ----------
    rows_list : list[str]
        One string per CSV row.

    Returns
    -------
    str
        The rows joined by ``os.linesep``, followed by a trailing
        ``os.linesep`` (an empty list yields just ``os.linesep``).
    """
    line_end = os.linesep
    return f"{line_end.join(rows_list)}{line_end}"

384 

385 

def external_error_raised(expected_exception: type[Exception]) -> ContextManager:
    """
    Build a ``pytest.raises`` context for an error whose message comes from
    an external library, so the message is not matched.

    Parameters
    ----------
    expected_exception : type[Exception]
        Exception class expected to be raised.

    Returns
    -------
    ContextManager
        ``pytest.raises(expected_exception, match=None)``.
    """
    # Imported inside the function so pytest is only needed when this is called.
    from pytest import raises

    return raises(expected_exception, match=None)

403 

404 

# (callable, method name) pairs from pandas.core.common._cython_table.
cython_table = pd.core.common._cython_table.items()


def get_cython_table_params(ndframe, func_names_and_expected):
    """
    Build ``(ndframe, func, expected)`` test parameters for each method name.

    For every ``(func_name, expected)`` pair, emit one entry that uses the
    method name string itself, followed by one entry for each callable in
    ``cython_table`` registered under that same name.

    Parameters
    ----------
    ndframe : DataFrame or Series
    func_names_and_expected : Sequence of two items
        The first item is a name of a NDFrame method ('sum', 'prod') etc.
        The second item is the expected return value.

    Returns
    -------
    list of tuple
        ``(ndframe, func, expected)`` triples, where ``func`` is either the
        method name or a matching callable.
    """
    params = []
    for func_name, expected in func_names_and_expected:
        aliases = [func for func, name in cython_table if name == func_name]
        params.extend((ndframe, func, expected) for func in [func_name, *aliases])
    return params

434 

435 

def _lookup_operator(name: str):
    """Return ``operator.<name>`` or ``operator.<name>_`` if either exists, else None."""
    # ``operator`` spells keyword-named ops with a trailing underscore
    # (and_, or_, not_), so try both spellings.
    for candidate in (name, f"{name}_"):
        func = getattr(operator, candidate, None)
        if func is not None:
            return func
    return None


def get_op_from_name(op_name: str) -> Callable:
    """
    The operator function for a given op name.

    Parameters
    ----------
    op_name : str
        The op name, in form of "add" or "__add__". Reflected names such as
        "radd" / "__radd__" give a function with swapped arguments, and
        keyword-named ops such as "and" / "__and__" resolve to the
        trailing-underscore functions in ``operator`` (e.g. ``operator.and_``).

    Returns
    -------
    function
        A function performing the operation.

    Raises
    ------
    AttributeError
        If no matching function exists in ``operator``.
    """
    short_opname = op_name.strip("_")

    op = _lookup_operator(short_opname)
    if op is not None:
        return op

    if short_opname.startswith("r"):
        # Assume it is the reverse operator, e.g. "rsub" -> sub(y, x)
        rop = _lookup_operator(short_opname[1:])
        if rop is not None:
            return lambda x, y: rop(y, x)

    raise AttributeError(f"module 'operator' has no attribute {short_opname!r}")

459 

460 

461# ----------------------------------------------------------------------------- 

462# Indexing test helpers 

463 

464 

def getitem(x):
    """Return ``x`` unchanged: the indexer for plain ``x[...]`` access."""
    return x


def setitem(x):
    """Return ``x`` unchanged: the indexer for plain ``x[...] = value`` access."""
    return x


def loc(x):
    """Return the label-based ``.loc`` indexer of ``x``."""
    return x.loc


def iloc(x):
    """Return the position-based ``.iloc`` indexer of ``x``."""
    return x.iloc


def at(x):
    """Return the scalar label-based ``.at`` indexer of ``x``."""
    return x.at


def iat(x):
    """Return the scalar position-based ``.iat`` indexer of ``x``."""
    return x.iat

487 

488 

489# ----------------------------------------------------------------------------- 

490 

# datetime64 units ordered from coarsest to finest resolution.
_UNITS = ["s", "ms", "us", "ns"]


def get_finest_unit(left: str, right: str):
    """
    Return whichever of two datetime64 units has the finer resolution.

    ``left`` is returned on a tie. Raises ValueError if either unit is not
    one of "s", "ms", "us" or "ns".
    """
    # max() keeps the first of equal keys, so ``left`` wins ties.
    return max(left, right, key=_UNITS.index)

501 

502 

def shares_memory(left, right) -> bool:
    """
    Pandas-compat for np.shares_memory.

    Unwraps Index/Series/MultiIndex, single-array DataFrames and several
    extension arrays down to their backing buffers before comparing.

    Parameters
    ----------
    left, right : array-like
        numpy arrays, pandas containers or extension arrays.

    Returns
    -------
    bool

    Raises
    ------
    NotImplementedError
        If ``left`` (after unwrapping) is of an unsupported type.
    """
    if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
        return np.shares_memory(left, right)
    elif isinstance(left, np.ndarray):
        # Call with reversed args to get to unpacking logic below.
        return shares_memory(right, left)

    if isinstance(left, RangeIndex):
        return False
    if isinstance(left, MultiIndex):
        # _codes is a list-like of ndarrays, not an ndarray itself, so check
        # each level's codes individually (recursing on the container would
        # always hit NotImplementedError).
        return any(shares_memory(codes, right) for codes in left._codes)
    if isinstance(left, (Index, Series)):
        return shares_memory(left._values, right)

    if isinstance(left, NDArrayBackedExtensionArray):
        return shares_memory(left._ndarray, right)
    if isinstance(left, pd.core.arrays.SparseArray):
        return shares_memory(left.sp_values, right)
    if isinstance(left, pd.core.arrays.IntervalArray):
        return shares_memory(left._left, right) or shares_memory(left._right, right)

    if (
        isinstance(left, ExtensionArray)
        and is_string_dtype(left.dtype)
        and left.dtype.storage in ("pyarrow", "pyarrow_numpy")  # type: ignore[attr-defined]
    ):
        # https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
        left = cast("ArrowExtensionArray", left)
        if (
            isinstance(right, ExtensionArray)
            and is_string_dtype(right.dtype)
            and right.dtype.storage in ("pyarrow", "pyarrow_numpy")  # type: ignore[attr-defined]
        ):
            right = cast("ArrowExtensionArray", right)
            left_pa_data = left._pa_array
            right_pa_data = right._pa_array
            if left_pa_data.num_chunks == 0 or right_pa_data.num_chunks == 0:
                # No chunks means no buffers to share (and chunk(0) would raise).
                return False
            left_buf1 = left_pa_data.chunk(0).buffers()[1]
            right_buf1 = right_pa_data.chunk(0).buffers()[1]
            if left_buf1 is None or right_buf1 is None:
                return False
            # Compare the memory regions, not the contents: Buffer.__eq__
            # compares bytes, so two independent arrays whose offsets happen
            # to be equal (same string lengths) would wrongly look shared.
            return (
                left_buf1.address < right_buf1.address + right_buf1.size
                and right_buf1.address < left_buf1.address + left_buf1.size
            )

    if isinstance(left, BaseMaskedArray) and isinstance(right, BaseMaskedArray):
        # By convention, we'll say these share memory if they share *either*
        # the _data or the _mask
        return np.shares_memory(left._data, right._data) or np.shares_memory(
            left._mask, right._mask
        )

    if isinstance(left, DataFrame) and len(left._mgr.arrays) == 1:
        arr = left._mgr.arrays[0]
        return shares_memory(arr, right)

    raise NotImplementedError(type(left), type(right))

558 

559 

# Public names re-exported by ``from pandas._testing import *``.
__all__ = [
    "ALL_INT_EA_DTYPES",
    "ALL_INT_NUMPY_DTYPES",
    "ALL_NUMPY_DTYPES",
    "ALL_REAL_NUMPY_DTYPES",
    "assert_almost_equal",
    "assert_attr_equal",
    "assert_categorical_equal",
    "assert_class_equal",
    "assert_contains_all",
    "assert_copy",
    "assert_datetime_array_equal",
    "assert_dict_equal",
    "assert_equal",
    "assert_extension_array_equal",
    "assert_frame_equal",
    "assert_index_equal",
    "assert_indexing_slices_equivalent",
    "assert_interval_array_equal",
    "assert_is_sorted",
    "assert_is_valid_plot_return_object",
    "assert_metadata_equivalent",
    "assert_numpy_array_equal",
    "assert_period_array_equal",
    "assert_produces_warning",
    "assert_series_equal",
    "assert_sp_array_equal",
    "assert_timedelta_array_equal",
    "assert_cow_warning",
    "at",
    "BOOL_DTYPES",
    "box_expected",
    "BYTES_DTYPES",
    "can_set_locale",
    "COMPLEX_DTYPES",
    "convert_rows_list_to_csv_str",
    "DATETIME64_DTYPES",
    "decompress_file",
    "ENDIAN",
    "ensure_clean",
    "external_error_raised",
    "FLOAT_EA_DTYPES",
    "FLOAT_NUMPY_DTYPES",
    "get_cython_table_params",
    "get_dtype",
    "getitem",
    "get_locales",
    "get_finest_unit",
    "get_obj",
    "get_op_from_name",
    "iat",
    "iloc",
    "loc",
    "maybe_produces_warning",
    "NARROW_NP_DTYPES",
    "NP_NAT_OBJECTS",
    "NULL_OBJECTS",
    "OBJECT_DTYPES",
    "raise_assert_detail",
    "raises_chained_assignment_error",
    "round_trip_localpath",
    "round_trip_pathlib",
    "round_trip_pickle",
    "setitem",
    "set_locale",
    "set_timezone",
    "shares_memory",
    "SIGNED_INT_EA_DTYPES",
    "SIGNED_INT_NUMPY_DTYPES",
    "STRING_DTYPES",
    "SubclassedDataFrame",
    "SubclassedSeries",
    "TIMEDELTA64_DTYPES",
    "to_array",
    "UNSIGNED_INT_EA_DTYPES",
    "UNSIGNED_INT_NUMPY_DTYPES",
    "use_numexpr",
    "with_csv_dialect",
    "write_to_compressed",
]

639]