1from __future__ import annotations
2
3from pandas._typing import ReadBuffer
4from pandas.compat._optional import import_optional_dependency
5
6from pandas.core.dtypes.inference import is_integer
7
8import pandas as pd
9from pandas import DataFrame
10
11from pandas.io._util import _arrow_dtype_mapping
12from pandas.io.parsers.base_parser import ParserBase
13
14
class ArrowParserWrapper(ParserBase):
    """
    Wrapper for the pyarrow engine for read_csv()

    Keyword arguments are validated when the wrapper is constructed; the
    CSV source itself is only handed to pyarrow when :meth:`read` is called.
    """

    def __init__(self, src: ReadBuffer[bytes], **kwds) -> None:
        """
        Store the source and keyword arguments, then validate the keywords.

        Parameters
        ----------
        src : ReadBuffer[bytes]
            The object passed to ``pyarrow.csv.read_csv`` by :meth:`read`.
        **kwds
            Keyword arguments; forwarded to ``ParserBase.__init__`` and kept
            on ``self.kwds``.
        """
        super().__init__(kwds)
        self.kwds = kwds
        self.src = src

        self._parse_kwds()

    def _parse_kwds(self) -> None:
        """
        Validates keywords before passing to pyarrow.

        Sets ``self.encoding`` (``"utf-8"`` when no encoding was given),
        ``self.usecols`` / ``self.usecols_dtype`` and ``self.na_values``.

        Raises
        ------
        ValueError
            If ``na_values`` is a dict.
        """
        encoding: str | None = self.kwds.get("encoding")
        self.encoding = "utf-8" if encoding is None else encoding

        # NOTE(review): _validate_usecols_arg is not defined in this class;
        # presumably inherited from ParserBase -- confirm.
        self.usecols, self.usecols_dtype = self._validate_usecols_arg(
            self.kwds["usecols"]
        )
        na_values = self.kwds["na_values"]
        if isinstance(na_values, dict):
            raise ValueError(
                "The pyarrow engine doesn't support passing a dict for na_values"
            )
        self.na_values = list(na_values)

    def _get_pyarrow_options(self) -> None:
        """
        Rename some arguments to pass to pyarrow

        Keywords that have a pyarrow counterpart are renamed in place in
        ``self.kwds`` (only when their value is not None). The keywords are
        then split into ``self.parse_options``, ``self.convert_options`` and
        ``self.read_options``, which :meth:`read` unpacks into the matching
        ``pyarrow.csv`` option classes.
        """
        mapping = {
            "usecols": "include_columns",
            "na_values": "null_values",
            "escapechar": "escape_char",
            "skip_blank_lines": "ignore_empty_lines",
            "decimal": "decimal_point",
        }
        for pandas_name, pyarrow_name in mapping.items():
            # ``.get`` yields None for a missing key as well, so a separate
            # membership test is unnecessary.
            if self.kwds.get(pandas_name) is not None:
                self.kwds[pyarrow_name] = self.kwds.pop(pandas_name)

        parse_option_names = (
            "delimiter",
            "quote_char",
            "escape_char",
            "ignore_empty_lines",
        )
        convert_option_names = (
            "include_columns",
            "null_values",
            "true_values",
            "false_values",
            "decimal_point",
        )
        # Options whose value is None are dropped rather than forwarded.
        self.parse_options = {
            option_name: option_value
            for option_name, option_value in self.kwds.items()
            if option_value is not None and option_name in parse_option_names
        }
        self.convert_options = {
            option_name: option_value
            for option_name, option_value in self.kwds.items()
            if option_value is not None and option_name in convert_option_names
        }
        self.read_options = {
            "autogenerate_column_names": self.header is None,
            # When a header is set, its value is used as the number of rows
            # to skip; "skiprows" is only consulted when header is None.
            "skip_rows": self.header
            if self.header is not None
            else self.kwds["skiprows"],
            "encoding": self.encoding,
        }

    def _finalize_pandas_output(self, frame: DataFrame) -> DataFrame:
        """
        Processes data read in based on kwargs.

        Parameters
        ----------
        frame: DataFrame
            The DataFrame to process.

        Returns
        -------
        DataFrame
            The processed DataFrame.

        Raises
        ------
        ValueError
            If a non-integer ``index_col`` entry is not a column of the
            frame, or if casting to the requested ``dtype`` raises a
            TypeError.
        """
        num_cols = len(frame.columns)
        multi_index_named = True
        if self.header is None:
            if self.names is None:
                self.names = range(num_cols)
            if len(self.names) != num_cols:
                # usecols is passed through to pyarrow, we only handle index col here
                # The only way self.names is not the same length as number of cols is
                # if we have int index_col. We should just pad the names(they will get
                # removed anyways) to expected length then.
                self.names = list(range(num_cols - len(self.names))) + self.names
                multi_index_named = False
            frame.columns = self.names
        # we only need the frame not the names
        frame.columns, frame = self._do_date_conversions(frame.columns, frame)
        if self.index_col is not None:
            for i, item in enumerate(self.index_col):
                if is_integer(item):
                    # Replace positional entries by the column label so that
                    # set_index below receives labels only.
                    self.index_col[i] = frame.columns[item]
                elif item not in frame.columns:
                    # String case
                    raise ValueError(f"Index {item} invalid")
            frame.set_index(self.index_col, drop=True, inplace=True)
            # Clear names if headerless and no name given
            if self.header is None and not multi_index_named:
                frame.index.names = [None] * len(frame.index.names)

        dtype = self.kwds.get("dtype")
        if dtype is not None:
            try:
                frame = frame.astype(dtype)
            except TypeError as e:
                # GH#44901 reraise to keep api consistent; chain explicitly
                # so the original TypeError is recorded as the cause.
                raise ValueError(e) from e
        return frame

    def read(self) -> DataFrame:
        """
        Reads the contents of a CSV file into a DataFrame and
        processes it according to the kwargs passed in the
        constructor.

        Returns
        -------
        DataFrame
            The DataFrame created from the CSV file.
        """
        pyarrow_csv = import_optional_dependency("pyarrow.csv")
        self._get_pyarrow_options()

        table = pyarrow_csv.read_csv(
            self.src,
            read_options=pyarrow_csv.ReadOptions(**self.read_options),
            parse_options=pyarrow_csv.ParseOptions(**self.parse_options),
            convert_options=pyarrow_csv.ConvertOptions(**self.convert_options),
        )
        # Pick the pyarrow -> pandas dtype mapping requested by the caller.
        dtype_backend = self.kwds["dtype_backend"]
        if dtype_backend == "pyarrow":
            frame = table.to_pandas(types_mapper=pd.ArrowDtype)
        elif dtype_backend == "numpy_nullable":
            frame = table.to_pandas(types_mapper=_arrow_dtype_mapping().get)
        else:
            frame = table.to_pandas()
        return self._finalize_pandas_output(frame)