1from __future__ import annotations
2
3from typing import TYPE_CHECKING
4import warnings
5
6from pandas._config import using_pyarrow_string_dtype
7
8from pandas._libs import lib
9from pandas.compat._optional import import_optional_dependency
10from pandas.errors import (
11 ParserError,
12 ParserWarning,
13)
14from pandas.util._exceptions import find_stack_level
15
16from pandas.core.dtypes.common import pandas_dtype
17from pandas.core.dtypes.inference import is_integer
18
19import pandas as pd
20from pandas import DataFrame
21
22from pandas.io._util import (
23 _arrow_dtype_mapping,
24 arrow_string_types_mapper,
25)
26from pandas.io.parsers.base_parser import ParserBase
27
28if TYPE_CHECKING:
29 from pandas._typing import ReadBuffer
30
31
32class ArrowParserWrapper(ParserBase):
33 """
34 Wrapper for the pyarrow engine for read_csv()
35 """
36
37 def __init__(self, src: ReadBuffer[bytes], **kwds) -> None:
38 super().__init__(kwds)
39 self.kwds = kwds
40 self.src = src
41
42 self._parse_kwds()
43
44 def _parse_kwds(self) -> None:
45 """
46 Validates keywords before passing to pyarrow.
47 """
48 encoding: str | None = self.kwds.get("encoding")
49 self.encoding = "utf-8" if encoding is None else encoding
50
51 na_values = self.kwds["na_values"]
52 if isinstance(na_values, dict):
53 raise ValueError(
54 "The pyarrow engine doesn't support passing a dict for na_values"
55 )
56 self.na_values = list(self.kwds["na_values"])
57
58 def _get_pyarrow_options(self) -> None:
59 """
60 Rename some arguments to pass to pyarrow
61 """
62 mapping = {
63 "usecols": "include_columns",
64 "na_values": "null_values",
65 "escapechar": "escape_char",
66 "skip_blank_lines": "ignore_empty_lines",
67 "decimal": "decimal_point",
68 "quotechar": "quote_char",
69 }
70 for pandas_name, pyarrow_name in mapping.items():
71 if pandas_name in self.kwds and self.kwds.get(pandas_name) is not None:
72 self.kwds[pyarrow_name] = self.kwds.pop(pandas_name)
73
74 # Date format handling
75 # If we get a string, we need to convert it into a list for pyarrow
76 # If we get a dict, we want to parse those separately
77 date_format = self.date_format
78 if isinstance(date_format, str):
79 date_format = [date_format]
80 else:
81 # In case of dict, we don't want to propagate through, so
82 # just set to pyarrow default of None
83
84 # Ideally, in future we disable pyarrow dtype inference (read in as string)
85 # to prevent misreads.
86 date_format = None
87 self.kwds["timestamp_parsers"] = date_format
88
89 self.parse_options = {
90 option_name: option_value
91 for option_name, option_value in self.kwds.items()
92 if option_value is not None
93 and option_name
94 in ("delimiter", "quote_char", "escape_char", "ignore_empty_lines")
95 }
96
97 on_bad_lines = self.kwds.get("on_bad_lines")
98 if on_bad_lines is not None:
99 if callable(on_bad_lines):
100 self.parse_options["invalid_row_handler"] = on_bad_lines
101 elif on_bad_lines == ParserBase.BadLineHandleMethod.ERROR:
102 self.parse_options[
103 "invalid_row_handler"
104 ] = None # PyArrow raises an exception by default
105 elif on_bad_lines == ParserBase.BadLineHandleMethod.WARN:
106
107 def handle_warning(invalid_row) -> str:
108 warnings.warn(
109 f"Expected {invalid_row.expected_columns} columns, but found "
110 f"{invalid_row.actual_columns}: {invalid_row.text}",
111 ParserWarning,
112 stacklevel=find_stack_level(),
113 )
114 return "skip"
115
116 self.parse_options["invalid_row_handler"] = handle_warning
117 elif on_bad_lines == ParserBase.BadLineHandleMethod.SKIP:
118 self.parse_options["invalid_row_handler"] = lambda _: "skip"
119
120 self.convert_options = {
121 option_name: option_value
122 for option_name, option_value in self.kwds.items()
123 if option_value is not None
124 and option_name
125 in (
126 "include_columns",
127 "null_values",
128 "true_values",
129 "false_values",
130 "decimal_point",
131 "timestamp_parsers",
132 )
133 }
134 self.convert_options["strings_can_be_null"] = "" in self.kwds["null_values"]
135 # autogenerated column names are prefixed with 'f' in pyarrow.csv
136 if self.header is None and "include_columns" in self.convert_options:
137 self.convert_options["include_columns"] = [
138 f"f{n}" for n in self.convert_options["include_columns"]
139 ]
140
141 self.read_options = {
142 "autogenerate_column_names": self.header is None,
143 "skip_rows": self.header
144 if self.header is not None
145 else self.kwds["skiprows"],
146 "encoding": self.encoding,
147 }
148
149 def _finalize_pandas_output(self, frame: DataFrame) -> DataFrame:
150 """
151 Processes data read in based on kwargs.
152
153 Parameters
154 ----------
155 frame: DataFrame
156 The DataFrame to process.
157
158 Returns
159 -------
160 DataFrame
161 The processed DataFrame.
162 """
163 num_cols = len(frame.columns)
164 multi_index_named = True
165 if self.header is None:
166 if self.names is None:
167 if self.header is None:
168 self.names = range(num_cols)
169 if len(self.names) != num_cols:
170 # usecols is passed through to pyarrow, we only handle index col here
171 # The only way self.names is not the same length as number of cols is
172 # if we have int index_col. We should just pad the names(they will get
173 # removed anyways) to expected length then.
174 self.names = list(range(num_cols - len(self.names))) + self.names
175 multi_index_named = False
176 frame.columns = self.names
177 # we only need the frame not the names
178 _, frame = self._do_date_conversions(frame.columns, frame)
179 if self.index_col is not None:
180 index_to_set = self.index_col.copy()
181 for i, item in enumerate(self.index_col):
182 if is_integer(item):
183 index_to_set[i] = frame.columns[item]
184 # String case
185 elif item not in frame.columns:
186 raise ValueError(f"Index {item} invalid")
187
188 # Process dtype for index_col and drop from dtypes
189 if self.dtype is not None:
190 key, new_dtype = (
191 (item, self.dtype.get(item))
192 if self.dtype.get(item) is not None
193 else (frame.columns[item], self.dtype.get(frame.columns[item]))
194 )
195 if new_dtype is not None:
196 frame[key] = frame[key].astype(new_dtype)
197 del self.dtype[key]
198
199 frame.set_index(index_to_set, drop=True, inplace=True)
200 # Clear names if headerless and no name given
201 if self.header is None and not multi_index_named:
202 frame.index.names = [None] * len(frame.index.names)
203
204 if self.dtype is not None:
205 # Ignore non-existent columns from dtype mapping
206 # like other parsers do
207 if isinstance(self.dtype, dict):
208 self.dtype = {
209 k: pandas_dtype(v)
210 for k, v in self.dtype.items()
211 if k in frame.columns
212 }
213 else:
214 self.dtype = pandas_dtype(self.dtype)
215 try:
216 frame = frame.astype(self.dtype)
217 except TypeError as e:
218 # GH#44901 reraise to keep api consistent
219 raise ValueError(e)
220 return frame
221
222 def _validate_usecols(self, usecols) -> None:
223 if lib.is_list_like(usecols) and not all(isinstance(x, str) for x in usecols):
224 raise ValueError(
225 "The pyarrow engine does not allow 'usecols' to be integer "
226 "column positions. Pass a list of string column names instead."
227 )
228 elif callable(usecols):
229 raise ValueError(
230 "The pyarrow engine does not allow 'usecols' to be a callable."
231 )
232
233 def read(self) -> DataFrame:
234 """
235 Reads the contents of a CSV file into a DataFrame and
236 processes it according to the kwargs passed in the
237 constructor.
238
239 Returns
240 -------
241 DataFrame
242 The DataFrame created from the CSV file.
243 """
244 pa = import_optional_dependency("pyarrow")
245 pyarrow_csv = import_optional_dependency("pyarrow.csv")
246 self._get_pyarrow_options()
247
248 try:
249 convert_options = pyarrow_csv.ConvertOptions(**self.convert_options)
250 except TypeError:
251 include = self.convert_options.get("include_columns", None)
252 if include is not None:
253 self._validate_usecols(include)
254
255 nulls = self.convert_options.get("null_values", set())
256 if not lib.is_list_like(nulls) or not all(
257 isinstance(x, str) for x in nulls
258 ):
259 raise TypeError(
260 "The 'pyarrow' engine requires all na_values to be strings"
261 )
262
263 raise
264
265 try:
266 table = pyarrow_csv.read_csv(
267 self.src,
268 read_options=pyarrow_csv.ReadOptions(**self.read_options),
269 parse_options=pyarrow_csv.ParseOptions(**self.parse_options),
270 convert_options=convert_options,
271 )
272 except pa.ArrowInvalid as e:
273 raise ParserError(e) from e
274
275 dtype_backend = self.kwds["dtype_backend"]
276
277 # Convert all pa.null() cols -> float64 (non nullable)
278 # else Int64 (nullable case, see below)
279 if dtype_backend is lib.no_default:
280 new_schema = table.schema
281 new_type = pa.float64()
282 for i, arrow_type in enumerate(table.schema.types):
283 if pa.types.is_null(arrow_type):
284 new_schema = new_schema.set(
285 i, new_schema.field(i).with_type(new_type)
286 )
287
288 table = table.cast(new_schema)
289
290 if dtype_backend == "pyarrow":
291 frame = table.to_pandas(types_mapper=pd.ArrowDtype)
292 elif dtype_backend == "numpy_nullable":
293 # Modify the default mapping to also
294 # map null to Int64 (to match other engines)
295 dtype_mapping = _arrow_dtype_mapping()
296 dtype_mapping[pa.null()] = pd.Int64Dtype()
297 frame = table.to_pandas(types_mapper=dtype_mapping.get)
298 elif using_pyarrow_string_dtype():
299 frame = table.to_pandas(types_mapper=arrow_string_types_mapper())
300
301 else:
302 frame = table.to_pandas()
303 return self._finalize_pandas_output(frame)