1from __future__ import annotations
2
3import re
4from typing import (
5 Callable,
6 Union,
7)
8
9import numpy as np
10
11from pandas._libs import (
12 lib,
13 missing as libmissing,
14)
15from pandas._typing import (
16 Dtype,
17 Scalar,
18 npt,
19)
20from pandas.compat import pa_version_under7p0
21
22from pandas.core.dtypes.common import (
23 is_bool_dtype,
24 is_dtype_equal,
25 is_integer_dtype,
26 is_object_dtype,
27 is_scalar,
28 is_string_dtype,
29 pandas_dtype,
30)
31from pandas.core.dtypes.missing import isna
32
33from pandas.core.arrays.arrow import ArrowExtensionArray
34from pandas.core.arrays.boolean import BooleanDtype
35from pandas.core.arrays.integer import Int64Dtype
36from pandas.core.arrays.numeric import NumericDtype
37from pandas.core.arrays.string_ import (
38 BaseStringArray,
39 StringDtype,
40)
41from pandas.core.strings.object_array import ObjectStringArrayMixin
42
43if not pa_version_under7p0:
44 import pyarrow as pa
45 import pyarrow.compute as pc
46
47 from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
48
# Type alias: either a ``str`` or ``libmissing.NAType``.
# NOTE(review): presumably meant for annotating scalar elements of the array;
# it is not referenced in the portion of the file visible here -- confirm usage.
ArrowStringScalarOrNAT = Union[str, libmissing.NAType]
50
51
def _chk_pyarrow_available() -> None:
    """
    Raise if the ``pa_version_under7p0`` flag reports pyarrow as unavailable.

    Raises
    ------
    ImportError
        When ``pa_version_under7p0`` is truthy.
    """
    if not pa_version_under7p0:
        return
    raise ImportError(
        "pyarrow>=7.0.0 is required for PyArrow backed ArrowExtensionArray."
    )
56
57
58# TODO: Inherit directly from BaseStringArrayMethods. Currently we inherit from
59# ObjectStringArrayMixin because we want to have the object-dtype based methods as
60# fallback for the ones that pyarrow doesn't yet support
61
62
63class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringArray):
64 """
65 Extension array for string data in a ``pyarrow.ChunkedArray``.
66
67 .. versionadded:: 1.2.0
68
69 .. warning::
70
71 ArrowStringArray is considered experimental. The implementation and
72 parts of the API may change without warning.
73
74 Parameters
75 ----------
76 values : pyarrow.Array or pyarrow.ChunkedArray
77 The array of data.
78
79 Attributes
80 ----------
81 None
82
83 Methods
84 -------
85 None
86
87 See Also
88 --------
89 :func:`pandas.array`
        The recommended function for creating an ArrowStringArray.
91 Series.str
92 The string methods are available on Series backed by
        an ArrowStringArray.
94
95 Notes
96 -----
97 ArrowStringArray returns a BooleanArray for comparison methods.
98
99 Examples
100 --------
101 >>> pd.array(['This is', 'some text', None, 'data.'], dtype="string[pyarrow]")
102 <ArrowStringArray>
103 ['This is', 'some text', <NA>, 'data.']
104 Length: 4, dtype: string
105 """
106
    # Annotation only -- no value is bound at class level. ``__init__`` assigns
    # ``self._dtype = StringDtype(storage="pyarrow")`` on each instance.
    # error: Incompatible types in assignment (expression has type "StringDtype",
    # base class "ArrowExtensionArray" defined the type as "ArrowDtype")
    _dtype: StringDtype  # type: ignore[assignment]
110
111 def __init__(self, values) -> None:
112 super().__init__(values)
113 self._dtype = StringDtype(storage="pyarrow")
114
115 if not pa.types.is_string(self._data.type):
116 raise ValueError(
117 "ArrowStringArray requires a PyArrow (chunked) array of string type"
118 )
119
120 def __len__(self) -> int:
121 """
122 Length of this array.
123
124 Returns
125 -------
126 length : int
127 """
128 return len(self._data)
129
130 @classmethod
131 def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False):
132 from pandas.core.arrays.masked import BaseMaskedArray
133
134 _chk_pyarrow_available()
135
136 if dtype and not (isinstance(dtype, str) and dtype == "string"):
137 dtype = pandas_dtype(dtype)
138 assert isinstance(dtype, StringDtype) and dtype.storage == "pyarrow"
139
140 if isinstance(scalars, BaseMaskedArray):
141 # avoid costly conversion to object dtype in ensure_string_array and
142 # numerical issues with Float32Dtype
143 na_values = scalars._mask
144 result = scalars._data
145 result = lib.ensure_string_array(result, copy=copy, convert_na_value=False)
146 return cls(pa.array(result, mask=na_values, type=pa.string()))
147 elif isinstance(scalars, (pa.Array, pa.ChunkedArray)):
148 return cls(pc.cast(scalars, pa.string()))
149
150 # convert non-na-likes to str
151 result = lib.ensure_string_array(scalars, copy=copy)
152 return cls(pa.array(result, type=pa.string(), from_pandas=True))
153
154 @classmethod
155 def _from_sequence_of_strings(
156 cls, strings, dtype: Dtype | None = None, copy: bool = False
157 ):
158 return cls._from_sequence(strings, dtype=dtype, copy=copy)
159
160 @property
161 def dtype(self) -> StringDtype: # type: ignore[override]
162 """
163 An instance of 'string[pyarrow]'.
164 """
165 return self._dtype
166
167 def insert(self, loc: int, item) -> ArrowStringArray:
168 if not isinstance(item, str) and item is not libmissing.NA:
169 raise TypeError("Scalar must be NA or str")
170 return super().insert(loc, item)
171
172 def _maybe_convert_setitem_value(self, value):
173 """Maybe convert value to be pyarrow compatible."""
174 if is_scalar(value):
175 if isna(value):
176 value = None
177 elif not isinstance(value, str):
178 raise TypeError("Scalar must be NA or str")
179 else:
180 value = np.array(value, dtype=object, copy=True)
181 value[isna(value)] = None
182 for v in value:
183 if not (v is None or isinstance(v, str)):
184 raise TypeError("Scalar must be NA or str")
185 return super()._maybe_convert_setitem_value(value)
186
187 def isin(self, values) -> npt.NDArray[np.bool_]:
188 value_set = [
189 pa_scalar.as_py()
190 for pa_scalar in [pa.scalar(value, from_pandas=True) for value in values]
191 if pa_scalar.type in (pa.string(), pa.null())
192 ]
193
194 # short-circuit to return all False array.
195 if not len(value_set):
196 return np.zeros(len(self), dtype=bool)
197
198 result = pc.is_in(self._data, value_set=pa.array(value_set))
199 # pyarrow 2.0.0 returned nulls, so we explicily specify dtype to convert nulls
200 # to False
201 return np.array(result, dtype=np.bool_)
202
203 def astype(self, dtype, copy: bool = True):
204 dtype = pandas_dtype(dtype)
205
206 if is_dtype_equal(dtype, self.dtype):
207 if copy:
208 return self.copy()
209 return self
210 elif isinstance(dtype, NumericDtype):
211 data = self._data.cast(pa.from_numpy_dtype(dtype.numpy_dtype))
212 return dtype.__from_arrow__(data)
213 elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.floating):
214 return self.to_numpy(dtype=dtype, na_value=np.nan)
215
216 return super().astype(dtype, copy=copy)
217
218 # ------------------------------------------------------------------------
219 # String methods interface
220
    # Missing-value marker for the string-method interface: ``libmissing.NA``
    # is used here instead of the float declared by the mixin base class (see
    # the type-checker note below).
    # error: Incompatible types in assignment (expression has type "NAType",
    # base class "ObjectStringArrayMixin" defined the type as "float")
    _str_na_value = libmissing.NA  # type: ignore[assignment]
224
225 def _str_map(
226 self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
227 ):
228 # TODO: de-duplicate with StringArray method. This method is moreless copy and
229 # paste.
230
231 from pandas.arrays import (
232 BooleanArray,
233 IntegerArray,
234 )
235
236 if dtype is None:
237 dtype = self.dtype
238 if na_value is None:
239 na_value = self.dtype.na_value
240
241 mask = isna(self)
242 arr = np.asarray(self)
243
244 if is_integer_dtype(dtype) or is_bool_dtype(dtype):
245 constructor: type[IntegerArray] | type[BooleanArray]
246 if is_integer_dtype(dtype):
247 constructor = IntegerArray
248 else:
249 constructor = BooleanArray
250
251 na_value_is_na = isna(na_value)
252 if na_value_is_na:
253 na_value = 1
254 result = lib.map_infer_mask(
255 arr,
256 f,
257 mask.view("uint8"),
258 convert=False,
259 na_value=na_value,
260 # error: Argument 1 to "dtype" has incompatible type
261 # "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
262 # "Type[object]"
263 dtype=np.dtype(dtype), # type: ignore[arg-type]
264 )
265
266 if not na_value_is_na:
267 mask[:] = False
268
269 return constructor(result, mask)
270
271 elif is_string_dtype(dtype) and not is_object_dtype(dtype):
272 # i.e. StringDtype
273 result = lib.map_infer_mask(
274 arr, f, mask.view("uint8"), convert=False, na_value=na_value
275 )
276 result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True)
277 return type(self)(result)
278 else:
279 # This is when the result type is object. We reach this when
280 # -> We know the result type is truly object (e.g. .encode returns bytes
281 # or .findall returns a list).
282 # -> We don't know the result type. E.g. `.get` can return anything.
283 return lib.map_infer_mask(arr, f, mask.view("uint8"))
284
285 def _str_contains(
286 self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
287 ):
288 if flags:
289 fallback_performancewarning()
290 return super()._str_contains(pat, case, flags, na, regex)
291
292 if regex:
293 if case is False:
294 fallback_performancewarning()
295 return super()._str_contains(pat, case, flags, na, regex)
296 else:
297 result = pc.match_substring_regex(self._data, pat)
298 else:
299 if case:
300 result = pc.match_substring(self._data, pat)
301 else:
302 result = pc.match_substring(pc.utf8_upper(self._data), pat.upper())
303 result = BooleanDtype().__from_arrow__(result)
304 if not isna(na):
305 result[isna(result)] = bool(na)
306 return result
307
308 def _str_startswith(self, pat: str, na=None):
309 pat = f"^{re.escape(pat)}"
310 return self._str_contains(pat, na=na, regex=True)
311
312 def _str_endswith(self, pat: str, na=None):
313 pat = f"{re.escape(pat)}$"
314 return self._str_contains(pat, na=na, regex=True)
315
316 def _str_replace(
317 self,
318 pat: str | re.Pattern,
319 repl: str | Callable,
320 n: int = -1,
321 case: bool = True,
322 flags: int = 0,
323 regex: bool = True,
324 ):
325 if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
326 fallback_performancewarning()
327 return super()._str_replace(pat, repl, n, case, flags, regex)
328
329 func = pc.replace_substring_regex if regex else pc.replace_substring
330 result = func(self._data, pattern=pat, replacement=repl, max_replacements=n)
331 return type(self)(result)
332
333 def _str_match(
334 self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
335 ):
336 if not pat.startswith("^"):
337 pat = f"^{pat}"
338 return self._str_contains(pat, case, flags, na, regex=True)
339
340 def _str_fullmatch(
341 self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
342 ):
343 if not pat.endswith("$") or pat.endswith("//$"):
344 pat = f"{pat}$"
345 return self._str_match(pat, case, flags, na)
346
347 def _str_isalnum(self):
348 result = pc.utf8_is_alnum(self._data)
349 return BooleanDtype().__from_arrow__(result)
350
351 def _str_isalpha(self):
352 result = pc.utf8_is_alpha(self._data)
353 return BooleanDtype().__from_arrow__(result)
354
355 def _str_isdecimal(self):
356 result = pc.utf8_is_decimal(self._data)
357 return BooleanDtype().__from_arrow__(result)
358
359 def _str_isdigit(self):
360 result = pc.utf8_is_digit(self._data)
361 return BooleanDtype().__from_arrow__(result)
362
363 def _str_islower(self):
364 result = pc.utf8_is_lower(self._data)
365 return BooleanDtype().__from_arrow__(result)
366
367 def _str_isnumeric(self):
368 result = pc.utf8_is_numeric(self._data)
369 return BooleanDtype().__from_arrow__(result)
370
371 def _str_isspace(self):
372 result = pc.utf8_is_space(self._data)
373 return BooleanDtype().__from_arrow__(result)
374
375 def _str_istitle(self):
376 result = pc.utf8_is_title(self._data)
377 return BooleanDtype().__from_arrow__(result)
378
379 def _str_isupper(self):
380 result = pc.utf8_is_upper(self._data)
381 return BooleanDtype().__from_arrow__(result)
382
383 def _str_len(self):
384 result = pc.utf8_length(self._data)
385 return Int64Dtype().__from_arrow__(result)
386
387 def _str_lower(self):
388 return type(self)(pc.utf8_lower(self._data))
389
390 def _str_upper(self):
391 return type(self)(pc.utf8_upper(self._data))
392
393 def _str_strip(self, to_strip=None):
394 if to_strip is None:
395 result = pc.utf8_trim_whitespace(self._data)
396 else:
397 result = pc.utf8_trim(self._data, characters=to_strip)
398 return type(self)(result)
399
400 def _str_lstrip(self, to_strip=None):
401 if to_strip is None:
402 result = pc.utf8_ltrim_whitespace(self._data)
403 else:
404 result = pc.utf8_ltrim(self._data, characters=to_strip)
405 return type(self)(result)
406
407 def _str_rstrip(self, to_strip=None):
408 if to_strip is None:
409 result = pc.utf8_rtrim_whitespace(self._data)
410 else:
411 result = pc.utf8_rtrim(self._data, characters=to_strip)
412 return type(self)(result)