1from __future__ import annotations
2
3import bz2
4from functools import wraps
5import gzip
6import io
7import socket
8import tarfile
9from typing import (
10 TYPE_CHECKING,
11 Any,
12 Callable,
13)
14import zipfile
15
16from pandas._typing import (
17 FilePath,
18 ReadPickleBuffer,
19)
20from pandas.compat import get_lzma_file
21from pandas.compat._optional import import_optional_dependency
22
23import pandas as pd
24from pandas._testing._random import rands
25from pandas._testing.contexts import ensure_clean
26
27from pandas.io.common import urlopen
28
29if TYPE_CHECKING:
30 from pandas import (
31 DataFrame,
32 Series,
33 )
34
# Substrings of exception messages that mark a failure as network-related.
# ``network`` skips (rather than fails) a test whose exception text contains
# any of these, compared case-insensitively. "timed out" also covers messages
# such as "urlopen error timed out" and "socket.timeout: timed out".
_network_error_messages = (
    "timed out",
    "Server Hangup",
    "HTTP Error 503: Service Unavailable",
    "502: Proxy Error",
    "HTTP Error 502: internal error",
    "HTTP Error 502",
    "HTTP Error 503",
    "HTTP Error 403",
    "HTTP Error 400",
    "Temporary failure in name resolution",
    "Name or service not known",
    "Connection refused",
    "certificate verify",
)

# errno values, intended to be matched against ``e.errno`` or
# ``e.reason.errno``, that mark a failure as network-related.
_network_errno_vals = (
    101,  # Network is unreachable
    111,  # Connection refused
    110,  # Connection timed out
    104,  # Connection reset Error
    54,  # Connection reset by peer
    60,  # urllib.error.URLError: [Errno 60] Connection timed out
)

# Neither tuple is meant to mask real problems such as HTTP 404s; they exist
# because some tests contact very flaky servers.
# NOTE(review): an earlier comment also said refused connections (changed
# DNS) should not be masked, yet "Connection refused" and errno 111 are both
# listed above -- confirm which behavior is intended.
# Which exception *types* are tolerated is decided separately, in
# _get_default_network_errors.
71
72
def _get_default_network_errors():
    """
    Return the exception types treated as connectivity failures by default.

    Used by ``network`` and ``can_connect`` when ``error_classes`` is None.
    ``http.client`` and ``urllib.error`` are imported lazily here because
    importing them pulls in many other standard-library modules.
    """
    from http.client import HTTPException
    from urllib.error import URLError

    default_errors = (
        OSError,
        HTTPException,
        TimeoutError,
        URLError,
        socket.timeout,
    )
    return default_errors
86
87
def optional_args(decorator):
    """
    Let ``decorator`` be used either bare or with arguments.

    Both ``@my_decorator`` and ``@my_decorator(*args, **kwargs)`` end up
    calling ``decorator(f, *args, **kwargs)`` on the decorated function ``f``.

    A call with exactly one positional argument that is callable, and no
    keyword arguments, is assumed to be the bare form, i.e.::

        @my_decorator
        def function(): pass

    Consequently a lone callable cannot be passed as a configuration argument.
    """

    @wraps(decorator)
    def wrapper(*args, **kwargs):
        if len(args) == 1 and not kwargs and callable(args[0]):
            # Bare ``@decorator`` usage: the only argument is the target.
            return decorator(args[0])

        def apply(func):
            return decorator(func, *args, **kwargs)

        return apply

    return wrapper
114
115
# error: Untyped decorator makes function "network" untyped
@optional_args  # type: ignore[misc]
def network(
    t,
    url: str = "https://www.google.com",
    raise_on_error: bool = False,
    check_before_test: bool = False,
    error_classes=None,
    skip_errnos=_network_errno_vals,
    _skip_on_messages=_network_error_messages,
):
    """
    Label a test as requiring network connection and, if an error is
    encountered, only raise if it does not find a network connection.

    Using this decorator adds a contract to your test: you must assert that,
    under normal conditions, your test will ONLY fail if it does not have
    network connectivity.

    You can call this in 3 ways: as a standard decorator, with keyword
    arguments, or with a positional argument that is the url to check.

    Parameters
    ----------
    t : callable
        The test requiring network connectivity.
    url : str
        The url to test via ``pandas.io.common.urlopen`` to check
        for connectivity. Defaults to 'https://www.google.com'.
    raise_on_error : bool
        If True, an error is re-raised instead of being turned into a skip
        merely because it is an instance of ``error_classes``. Errors whose
        errno is in ``skip_errnos`` or whose message matches
        ``_skip_on_messages`` are still skipped.
    check_before_test : bool
        If True, checks connectivity to ``url`` before running the test case
        and skips the test if it cannot connect.
    error_classes : tuple or Exception
        Error classes to ignore. If not in ``error_classes``, raises the error.
        Defaults to the tuple returned by ``_get_default_network_errors``
        (``OSError``, ``http.client.HTTPException``, ``TimeoutError``,
        ``urllib.error.URLError``, ``socket.timeout``). Be careful about
        changing the error classes here.
    skip_errnos : iterable of int
        Any exception that has ``.errno`` or ``.reason.errno`` set to one
        of these values will be skipped with an appropriate message.
    _skip_on_messages : iterable of str
        Any exception ``e`` for which one of the strings is a
        (case-insensitive) substring of ``str(e)`` will be skipped with an
        appropriate message. Intended to suppress errors where an errno
        isn't available.

    Returns
    -------
    t : callable
        The decorated test ``t``, with checks for connectivity errors.

    Notes
    -----
    * ``raise_on_error`` supersedes ``check_before_test``

    Examples
    --------

    Tests decorated with @network will fail if it's possible to make a network
    connection to another URL (defaults to google.com)::

      >>> from pandas import _testing as tm
      >>> @tm.network
      ... def test_network():
      ...     with pd.io.common.urlopen("rabbit://bonanza.com"):
      ...         pass
      >>> test_network()  # doctest: +SKIP
      Traceback
         ...
      URLError: <urlopen error unknown url type: rabbit>

      You can specify alternative URLs::

        >>> @tm.network("https://www.yahoo.com")
        ... def test_something_with_yahoo():
        ...    raise OSError("Failure Message")
        >>> test_something_with_yahoo()  # doctest: +SKIP
        Traceback (most recent call last):
            ...
        OSError: Failure Message

    If you set check_before_test, it will check the url first and not run the
    test on failure::

        >>> @tm.network("failing://url.blaher", check_before_test=True)
        ... def test_something():
        ...     print("I ran!")
        ...     raise ValueError("Failure")
        >>> test_something()  # doctest: +SKIP
        Traceback (most recent call last):
            ...

    Errors not related to networking will always be raised.
    """
    import pytest

    if error_classes is None:
        error_classes = _get_default_network_errors()

    # Flag the test function with a ``network`` attribute.
    t.network = True

    @wraps(t)
    def wrapper(*args, **kwargs):
        if (
            check_before_test
            and not raise_on_error
            and not can_connect(url, error_classes)
        ):
            pytest.skip(
                f"May not have network connectivity because cannot connect to {url}"
            )
        try:
            return t(*args, **kwargs)
        except Exception as err:
            errno = getattr(err, "errno", None)
            # Wrapper exceptions such as urllib.error.URLError keep the
            # underlying error in ``.reason``; look for an errno there too.
            # (Previously this checked ``hasattr(errno, "reason")``, which
            # could never be true, so ``.reason.errno`` was never consulted.)
            if not errno and hasattr(err, "reason"):
                # error: "Exception" has no attribute "reason"
                errno = getattr(err.reason, "errno", None)  # type: ignore[attr-defined]

            if errno in skip_errnos:
                pytest.skip(f"Skipping test due to known errno and error {err}")

            # Lowercase once rather than once per candidate message.
            e_str = str(err).lower()
            if any(m.lower() in e_str for m in _skip_on_messages):
                pytest.skip(
                    f"Skipping test because exception message is known and error {err}"
                )

            if not isinstance(err, error_classes) or raise_on_error:
                raise
            pytest.skip(f"Skipping test due to lack of connectivity and error {err}")

    return wrapper
249
250
def can_connect(url, error_classes=None) -> bool:
    """
    Check whether ``url`` can be opened successfully.

    Parameters
    ----------
    url : str
        The URL to try to connect to.
    error_classes : tuple of Exception types, optional
        Exceptions that are taken to mean "cannot connect". Defaults to the
        tuple returned by ``_get_default_network_errors``.

    Returns
    -------
    bool
        True if ``url`` was opened and the response status is 200. False if
        the status is anything else or one of ``error_classes`` was raised.
        Any other exception propagates.
    """
    if error_classes is None:
        error_classes = _get_default_network_errors()

    try:
        # A timeout, just in case the server applies rate-limiting.
        with urlopen(url, timeout=20) as response:
            reachable = response.status == 200
    except error_classes:
        return False
    return reachable
279
280
281# ------------------------------------------------------------------
282# File-IO
283
284
def round_trip_pickle(
    obj: Any, path: FilePath | ReadPickleBuffer | None = None
) -> DataFrame | Series:
    """
    Serialize ``obj`` with ``pd.to_pickle`` and load it back with
    ``pd.read_pickle``.

    Parameters
    ----------
    obj : any object
        The object to pickle and then re-read.
    path : str, path object or file-like object, default None
        Where the pickle is written and then read from. When None, a
        randomly named ``__<random>__.pickle`` file is used. The location is
        wrapped in ``ensure_clean`` (presumably so it is removed afterwards).

    Returns
    -------
    pandas object
        The object obtained by unpickling what was written.
    """
    target = path if path is not None else f"__{rands(10)}__.pickle"
    with ensure_clean(target) as temp_path:
        pd.to_pickle(obj, temp_path)
        return pd.read_pickle(temp_path)
309
310
def round_trip_pathlib(writer, reader, path: str | None = None):
    """
    Write an object to file specified by a pathlib.Path and read it back.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv )
    reader : callable
        IO reading function (e.g. pd.read_csv )
    path : str, default None
        The path where the object is written and then read. Defaults to
        ``"___pathlib___"``.

    Returns
    -------
    pandas object
        The original object that was serialized and then re-read.
    """
    # pathlib is part of the standard library, so it is always importable;
    # no need to route it through pytest.importorskip.
    from pathlib import Path

    if path is None:
        path = "___pathlib___"
    # Use a distinct name so the ``path`` argument is not shadowed.
    with ensure_clean(path) as temp_path:
        writer(Path(temp_path))
        return reader(Path(temp_path))
338
339
def round_trip_localpath(writer, reader, path: str | None = None):
    """
    Serialize an object to a ``py.path.local`` location and read it back.

    The test is skipped (via ``pytest.importorskip``) when ``py.path`` is not
    installed.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv )
    reader : callable
        IO reading function (e.g. pd.read_csv )
    path : str, default None
        Location to write to and read from; ``"___localpath___"`` when None.

    Returns
    -------
    pandas object
        Whatever ``reader`` returns for the freshly written file.
    """
    import pytest

    local_path_cls = pytest.importorskip("py.path").local
    filename = "___localpath___" if path is None else path
    with ensure_clean(filename) as clean_path:
        writer(local_path_cls(clean_path))
        return reader(local_path_cls(clean_path))
367
368
def write_to_compressed(compression, path, data, dest: str = "test"):
    """
    Write data to a compressed file.

    Parameters
    ----------
    compression : {'gzip', 'bz2', 'zip', 'tar', 'xz', 'zstd'}
        The compression type to use.
    path : str
        The file path to write the data.
    data : bytes
        The data to write. The gzip, bz2, xz and tar writers require bytes;
        zip additionally accepts str.
    dest : str, default "test"
        The name of the archive member to create (for ZIP and TAR only).

    Raises
    ------
    ValueError : An invalid compression value was passed in.
    """
    args: tuple[Any, ...] = (data,)
    mode = "wb"
    method = "write"
    compress_method: Callable

    if compression == "zip":
        compress_method = zipfile.ZipFile
        mode = "w"
        args = (dest, data)
        method = "writestr"
    elif compression == "tar":
        compress_method = tarfile.TarFile
        mode = "w"
        # TarFile.addfile needs a TarInfo with the member size set plus a
        # file object to read exactly that many bytes from.
        member = tarfile.TarInfo(name=dest)
        member.size = len(data)
        args = (member, io.BytesIO(data))
        method = "addfile"
    elif compression == "gzip":
        compress_method = gzip.GzipFile
    elif compression == "bz2":
        compress_method = bz2.BZ2File
    elif compression == "zstd":
        compress_method = import_optional_dependency("zstandard").open
    elif compression == "xz":
        compress_method = get_lzma_file()
    else:
        raise ValueError(f"Unrecognized compression type: {compression}")

    with compress_method(path, mode=mode) as f:
        getattr(f, method)(*args)
419
420
421# ------------------------------------------------------------------
422# Plotting
423
424
def close(fignum=None) -> None:
    """
    Close matplotlib figures.

    Parameters
    ----------
    fignum : optional
        Figure to close. When None, every figure reported by
        ``matplotlib.pyplot.get_fignums`` is closed one by one.
    """
    from matplotlib.pyplot import (
        close as _close,
        get_fignums,
    )

    targets = get_fignums() if fignum is None else [fignum]
    for number in targets:
        _close(number)