1from __future__ import annotations
2
3import gzip
4import io
5import pathlib
6import tarfile
7from typing import (
8 TYPE_CHECKING,
9 Any,
10 Callable,
11)
12import uuid
13import zipfile
14
15from pandas.compat import (
16 get_bz2_file,
17 get_lzma_file,
18)
19from pandas.compat._optional import import_optional_dependency
20
21import pandas as pd
22from pandas._testing.contexts import ensure_clean
23
24if TYPE_CHECKING:
25 from pandas._typing import (
26 FilePath,
27 ReadPickleBuffer,
28 )
29
30 from pandas import (
31 DataFrame,
32 Series,
33 )
34
35# ------------------------------------------------------------------
36# File-IO
37
38
def round_trip_pickle(
    obj: Any, path: FilePath | ReadPickleBuffer | None = None
) -> DataFrame | Series:
    """
    Serialize an object with pickle and load it straight back.

    Parameters
    ----------
    obj : any object
        The object to write with ``pd.to_pickle`` and re-read with
        ``pd.read_pickle``.
    path : str, path object or file-like object, default None
        Location handed to ``ensure_clean`` for the intermediate pickle.
        When omitted, a name of the form ``__<uuid4>__.pickle`` is generated.

    Returns
    -------
    pandas object
        The object obtained by reading the pickle back.
    """
    # Only generate a random name when the caller did not supply a location.
    target = f"__{uuid.uuid4()}__.pickle" if path is None else path
    with ensure_clean(target) as scratch:
        pd.to_pickle(obj, scratch)
        restored = pd.read_pickle(scratch)
    return restored
63
64
def round_trip_pathlib(writer, reader, path: str | None = None):
    """
    Write through a ``pathlib.Path`` and read the result back.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv ); called with a
        ``pathlib.Path``.
    reader : callable
        IO reading function (e.g. pd.read_csv ); called with a
        ``pathlib.Path``.
    path : str, default None
        Name handed to ``ensure_clean`` for the intermediate file.
        Defaults to ``"___pathlib___"``.

    Returns
    -------
    pandas object
        Whatever ``reader`` returns for the file produced by ``writer``.
    """
    requested = "___pathlib___" if path is None else path
    with ensure_clean(requested) as scratch:
        location = pathlib.Path(scratch)
        writer(location)
        return reader(location)
90
91
def round_trip_localpath(writer, reader, path: str | None = None):
    """
    Write through a py.path ``LocalPath`` and read the result back.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv ); called with a
        ``LocalPath``.
    reader : callable
        IO reading function (e.g. pd.read_csv ); called with a
        ``LocalPath``.
    path : str, default None
        Name handed to ``ensure_clean`` for the intermediate file.
        Defaults to ``"___localpath___"``.

    Returns
    -------
    pandas object
        Whatever ``reader`` returns for the file produced by ``writer``.
    """
    import pytest

    # importorskip skips the calling test when py.path cannot be imported.
    local_cls = pytest.importorskip("py.path").local
    requested = "___localpath___" if path is None else path
    with ensure_clean(requested) as scratch:
        location = local_cls(scratch)
        writer(location)
        return reader(location)
119
120
121def write_to_compressed(compression, path, data, dest: str = "test") -> None:
122 """
123 Write data to a compressed file.
124
125 Parameters
126 ----------
127 compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd'}
128 The compression type to use.
129 path : str
130 The file path to write the data.
131 data : str
132 The data to write.
133 dest : str, default "test"
134 The destination file (for ZIP only)
135
136 Raises
137 ------
138 ValueError : An invalid compression value was passed in.
139 """
140 args: tuple[Any, ...] = (data,)
141 mode = "wb"
142 method = "write"
143 compress_method: Callable
144
145 if compression == "zip":
146 compress_method = zipfile.ZipFile
147 mode = "w"
148 args = (dest, data)
149 method = "writestr"
150 elif compression == "tar":
151 compress_method = tarfile.TarFile
152 mode = "w"
153 file = tarfile.TarInfo(name=dest)
154 bytes = io.BytesIO(data)
155 file.size = len(data)
156 args = (file, bytes)
157 method = "addfile"
158 elif compression == "gzip":
159 compress_method = gzip.GzipFile
160 elif compression == "bz2":
161 compress_method = get_bz2_file()
162 elif compression == "zstd":
163 compress_method = import_optional_dependency("zstandard").open
164 elif compression == "xz":
165 compress_method = get_lzma_file()
166 else:
167 raise ValueError(f"Unrecognized compression type: {compression}")
168
169 with compress_method(path, mode=mode) as f:
170 getattr(f, method)(*args)