1from __future__ import annotations
2
3from functools import partial
4import operator
5import re
6from typing import (
7 TYPE_CHECKING,
8 Callable,
9 Union,
10)
11import warnings
12
13import numpy as np
14
15from pandas._libs import (
16 lib,
17 missing as libmissing,
18)
19from pandas.compat import (
20 pa_version_under10p1,
21 pa_version_under13p0,
22)
23from pandas.util._exceptions import find_stack_level
24
25from pandas.core.dtypes.common import (
26 is_bool_dtype,
27 is_integer_dtype,
28 is_object_dtype,
29 is_scalar,
30 is_string_dtype,
31 pandas_dtype,
32)
33from pandas.core.dtypes.missing import isna
34
35from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
36from pandas.core.arrays.arrow import ArrowExtensionArray
37from pandas.core.arrays.boolean import BooleanDtype
38from pandas.core.arrays.integer import Int64Dtype
39from pandas.core.arrays.numeric import NumericDtype
40from pandas.core.arrays.string_ import (
41 BaseStringArray,
42 StringDtype,
43)
44from pandas.core.ops import invalid_comparison
45from pandas.core.strings.object_array import ObjectStringArrayMixin
46
47if not pa_version_under10p1:
48 import pyarrow as pa
49 import pyarrow.compute as pc
50
51 from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
52
53
54if TYPE_CHECKING:
55 from collections.abc import Sequence
56
57 from pandas._typing import (
58 ArrayLike,
59 AxisInt,
60 Dtype,
61 Scalar,
62 npt,
63 )
64
65 from pandas import Series
66
67
# Type alias for a scalar that is either a Python ``str`` or an instance of
# ``libmissing.NAType``.
ArrowStringScalarOrNAT = Union[str, libmissing.NAType]
69
70
71def _chk_pyarrow_available() -> None:
72 if pa_version_under10p1:
73 msg = "pyarrow>=10.0.1 is required for PyArrow backed ArrowExtensionArray."
74 raise ImportError(msg)
75
76
77# TODO: Inherit directly from BaseStringArrayMethods. Currently we inherit from
78# ObjectStringArrayMixin because we want to have the object-dtype based methods as
79# fallback for the ones that pyarrow doesn't yet support
80
81
class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringArray):
    """
    Extension array for string data in a ``pyarrow.ChunkedArray``.

    .. warning::

       ArrowStringArray is considered experimental. The implementation and
       parts of the API may change without warning.

    Parameters
    ----------
    values : pyarrow.Array or pyarrow.ChunkedArray
        The array of data.

    Attributes
    ----------
    None

    Methods
    -------
    None

    See Also
    --------
    :func:`pandas.array`
        The recommended function for creating an ArrowStringArray.
    Series.str
        The string methods are available on Series backed by
        an ArrowStringArray.

    Notes
    -----
    ArrowStringArray returns a BooleanArray for comparison methods.

    Examples
    --------
    >>> pd.array(['This is', 'some text', None, 'data.'], dtype="string[pyarrow]")
    <ArrowStringArray>
    ['This is', 'some text', <NA>, 'data.']
    Length: 4, dtype: string
    """

    # error: Incompatible types in assignment (expression has type "StringDtype",
    # base class "ArrowExtensionArray" defined the type as "ArrowDtype")
    _dtype: StringDtype  # type: ignore[assignment]
    _storage = "pyarrow"

    def __init__(self, values) -> None:
        """
        Wrap ``values``, casting pyarrow ``string`` data to ``large_string``.

        Raises
        ------
        ImportError
            If the pyarrow version check in ``_chk_pyarrow_available`` fails.
        ValueError
            If the stored array is neither ``large_string`` nor a dictionary
            array whose value type is ``large_string``.
        """
        _chk_pyarrow_available()
        if isinstance(values, (pa.Array, pa.ChunkedArray)) and pa.types.is_string(
            values.type
        ):
            values = pc.cast(values, pa.large_string())

        super().__init__(values)
        self._dtype = StringDtype(storage=self._storage)

        if not pa.types.is_large_string(self._pa_array.type) and not (
            pa.types.is_dictionary(self._pa_array.type)
            and pa.types.is_large_string(self._pa_array.type.value_type)
        ):
            raise ValueError(
                "ArrowStringArray requires a PyArrow (chunked) array of "
                "large_string type"
            )

    @classmethod
    def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
        """
        Box ``value`` via the parent class, then cast a ``string`` scalar to
        ``large_string`` when no explicit ``pa_type`` was requested.
        """
        pa_scalar = super()._box_pa_scalar(value, pa_type)
        if pa.types.is_string(pa_scalar.type) and pa_type is None:
            pa_scalar = pc.cast(pa_scalar, pa.large_string())
        return pa_scalar

    @classmethod
    def _box_pa_array(
        cls, value, pa_type: pa.DataType | None = None, copy: bool = False
    ) -> pa.Array | pa.ChunkedArray:
        """
        Box ``value`` via the parent class, then cast a ``string`` array to
        ``large_string`` when no explicit ``pa_type`` was requested.

        ``copy`` is not forwarded to the parent implementation.
        """
        pa_array = super()._box_pa_array(value, pa_type)
        if pa.types.is_string(pa_array.type) and pa_type is None:
            pa_array = pc.cast(pa_array, pa.large_string())
        return pa_array

    def __len__(self) -> int:
        """
        Length of this array.

        Returns
        -------
        length : int
        """
        return len(self._pa_array)

    @classmethod
    def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False):
        """
        Build an array from ``scalars``, storing the data as ``large_string``.

        A truthy ``dtype`` other than the literal ``"string"`` must resolve to a
        ``StringDtype`` with ``"pyarrow"`` or ``"pyarrow_numpy"`` storage.
        Masked arrays and pyarrow arrays take dedicated conversion paths; any
        other input goes through ``lib.ensure_string_array``.
        """
        from pandas.core.arrays.masked import BaseMaskedArray

        _chk_pyarrow_available()

        if dtype and not (isinstance(dtype, str) and dtype == "string"):
            dtype = pandas_dtype(dtype)
            assert isinstance(dtype, StringDtype) and dtype.storage in (
                "pyarrow",
                "pyarrow_numpy",
            )

        if isinstance(scalars, BaseMaskedArray):
            # avoid costly conversion to object dtype in ensure_string_array and
            # numerical issues with Float32Dtype
            na_values = scalars._mask
            result = scalars._data
            result = lib.ensure_string_array(result, copy=copy, convert_na_value=False)
            return cls(pa.array(result, mask=na_values, type=pa.large_string()))
        elif isinstance(scalars, (pa.Array, pa.ChunkedArray)):
            return cls(pc.cast(scalars, pa.large_string()))

        # convert non-na-likes to str
        result = lib.ensure_string_array(scalars, copy=copy)
        return cls(pa.array(result, type=pa.large_string(), from_pandas=True))

    @classmethod
    def _from_sequence_of_strings(
        cls, strings, dtype: Dtype | None = None, copy: bool = False
    ):
        """Delegate to ``_from_sequence`` with the same arguments."""
        return cls._from_sequence(strings, dtype=dtype, copy=copy)

    @property
    def dtype(self) -> StringDtype:  # type: ignore[override]
        """
        An instance of 'string[pyarrow]'.
        """
        return self._dtype

    def insert(self, loc: int, item) -> ArrowStringArray:
        """
        Insert ``item`` at position ``loc`` via the parent implementation.

        Raises
        ------
        TypeError
            If ``item`` is neither a ``str`` nor ``libmissing.NA``.
        """
        if not isinstance(item, str) and item is not libmissing.NA:
            raise TypeError("Scalar must be NA or str")
        return super().insert(loc, item)

    @classmethod
    def _result_converter(cls, values, na=None):
        """
        Convert ``values`` with ``BooleanDtype().__from_arrow__``.

        ``na`` is accepted but not used by this implementation.
        """
        return BooleanDtype().__from_arrow__(values)

    def _maybe_convert_setitem_value(self, value):
        """Maybe convert value to be pyarrow compatible."""
        if is_scalar(value):
            if isna(value):
                value = None
            elif not isinstance(value, str):
                raise TypeError("Scalar must be NA or str")
        else:
            # Work on an object copy so the caller's data is left untouched.
            value = np.array(value, dtype=object, copy=True)
            value[isna(value)] = None
            for v in value:
                if not (v is None or isinstance(v, str)):
                    raise TypeError("Scalar must be NA or str")
        return super()._maybe_convert_setitem_value(value)

    def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
        """
        Return a boolean NumPy array marking elements found in ``values``.

        Only candidates that pyarrow infers as ``string``, ``large_string`` or
        ``null`` are kept; if none remain, the result is all ``False``.
        """
        value_set = [
            pa_scalar.as_py()
            for pa_scalar in [pa.scalar(value, from_pandas=True) for value in values]
            if pa_scalar.type in (pa.string(), pa.null(), pa.large_string())
        ]

        # short-circuit to return all False array.
        if not len(value_set):
            return np.zeros(len(self), dtype=bool)

        result = pc.is_in(
            self._pa_array, value_set=pa.array(value_set, type=self._pa_array.type)
        )
        # pyarrow 2.0.0 returned nulls, so we explicitly specify dtype to convert
        # nulls to False
        return np.array(result, dtype=np.bool_)

    def astype(self, dtype, copy: bool = True):
        """
        Cast to ``dtype``.

        The same dtype returns ``self`` (or a copy when ``copy`` is true).
        Masked numeric dtypes are cast through pyarrow, NumPy floating dtypes go
        through ``to_numpy`` with ``np.nan`` for missing values, and everything
        else is handled by the parent implementation.
        """
        dtype = pandas_dtype(dtype)

        if dtype == self.dtype:
            if copy:
                return self.copy()
            return self
        elif isinstance(dtype, NumericDtype):
            data = self._pa_array.cast(pa.from_numpy_dtype(dtype.numpy_dtype))
            return dtype.__from_arrow__(data)
        elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.floating):
            return self.to_numpy(dtype=dtype, na_value=np.nan)

        return super().astype(dtype, copy=copy)

    @property
    def _data(self):
        """Deprecated alias of ``_pa_array``; emits a ``FutureWarning``."""
        # dask accesses ._data directly
        warnings.warn(
            f"{type(self).__name__}._data is a deprecated and will be removed "
            "in a future version, use ._pa_array instead",
            FutureWarning,
            stacklevel=find_stack_level(),
        )
        return self._pa_array

    # ------------------------------------------------------------------------
    # String methods interface

    # error: Incompatible types in assignment (expression has type "NAType",
    # base class "ObjectStringArrayMixin" defined the type as "float")
    _str_na_value = libmissing.NA  # type: ignore[assignment]

    def _str_map(
        self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
    ):
        """
        Apply ``f`` to every non-missing element.

        ``dtype`` defaults to this array's dtype and ``na_value`` to that dtype's
        ``na_value``. Integer and boolean targets return a masked
        ``IntegerArray`` / ``BooleanArray``; a non-object string dtype returns a
        new array of this type; anything else returns the raw result of
        ``lib.map_infer_mask``. ``convert`` is not used by this implementation.
        """
        # TODO: de-duplicate with StringArray method. This method is more or less
        # copy and paste.

        from pandas.arrays import (
            BooleanArray,
            IntegerArray,
        )

        if dtype is None:
            dtype = self.dtype
        if na_value is None:
            na_value = self.dtype.na_value

        mask = isna(self)
        arr = np.asarray(self)

        if is_integer_dtype(dtype) or is_bool_dtype(dtype):
            constructor: type[IntegerArray | BooleanArray]
            if is_integer_dtype(dtype):
                constructor = IntegerArray
            else:
                constructor = BooleanArray

            na_value_is_na = isna(na_value)
            if na_value_is_na:
                # Placeholder for masked slots; the mask passed to the
                # constructor is what marks them as missing.
                na_value = 1
            result = lib.map_infer_mask(
                arr,
                f,
                mask.view("uint8"),
                convert=False,
                na_value=na_value,
                # error: Argument 1 to "dtype" has incompatible type
                # "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
                # "Type[object]"
                dtype=np.dtype(dtype),  # type: ignore[arg-type]
            )

            if not na_value_is_na:
                # An explicit fill value was written, so nothing stays masked.
                mask[:] = False

            return constructor(result, mask)

        elif is_string_dtype(dtype) and not is_object_dtype(dtype):
            # i.e. StringDtype
            result = lib.map_infer_mask(
                arr, f, mask.view("uint8"), convert=False, na_value=na_value
            )
            result = pa.array(
                result, mask=mask, type=pa.large_string(), from_pandas=True
            )
            return type(self)(result)
        else:
            # This is when the result type is object. We reach this when
            # -> We know the result type is truly object (e.g. .encode returns bytes
            #    or .findall returns a list).
            # -> We don't know the result type. E.g. `.get` can return anything.
            return lib.map_infer_mask(arr, f, mask.view("uint8"))

    def _str_contains(
        self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
    ):
        """
        Test whether ``pat`` occurs in each string.

        Non-zero ``flags`` fall back to the object-dtype implementation (with a
        performance warning). When ``na`` is not missing, missing results are
        replaced by ``bool(na)``.
        """
        if flags:
            fallback_performancewarning()
            return super()._str_contains(pat, case, flags, na, regex)

        if regex:
            result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
        else:
            result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
        result = self._result_converter(result, na=na)
        if not isna(na):
            result[isna(result)] = bool(na)
        return result

    def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
        """
        Test whether each string starts with ``pat``.

        ``pat`` may be a single string or a tuple of alternatives, which are
        OR-ed together. An empty tuple gives ``False`` for every non-missing
        element. Nulls are filled with ``na`` when ``na`` is not missing.
        """
        if isinstance(pat, str):
            result = pc.starts_with(self._pa_array, pattern=pat)
        else:
            if len(pat) == 0:
                # mimic existing behaviour of string extension array
                # and python string method
                result = pa.array(
                    np.zeros(len(self._pa_array), dtype=bool), mask=isna(self._pa_array)
                )
            else:
                result = pc.starts_with(self._pa_array, pattern=pat[0])

                for p in pat[1:]:
                    result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
        if not isna(na):
            result = result.fill_null(na)
        return self._result_converter(result)

    def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
        """
        Test whether each string ends with ``pat``.

        ``pat`` may be a single string or a tuple of alternatives, which are
        OR-ed together. An empty tuple gives ``False`` for every non-missing
        element. Nulls are filled with ``na`` when ``na`` is not missing.
        """
        if isinstance(pat, str):
            result = pc.ends_with(self._pa_array, pattern=pat)
        else:
            if len(pat) == 0:
                # mimic existing behaviour of string extension array
                # and python string method
                result = pa.array(
                    np.zeros(len(self._pa_array), dtype=bool), mask=isna(self._pa_array)
                )
            else:
                result = pc.ends_with(self._pa_array, pattern=pat[0])

                for p in pat[1:]:
                    result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
        if not isna(na):
            result = result.fill_null(na)
        return self._result_converter(result)

    def _str_replace(
        self,
        pat: str | re.Pattern,
        repl: str | Callable,
        n: int = -1,
        case: bool = True,
        flags: int = 0,
        regex: bool = True,
    ):
        """
        Replace occurrences of ``pat`` with ``repl``.

        Compiled patterns, callable replacements, case-insensitive matching and
        non-zero ``flags`` fall back to the object-dtype implementation (with a
        performance warning). Otherwise pyarrow's ``replace_substring_regex`` or
        ``replace_substring`` is used with ``max_replacements=n``.
        """
        if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
            fallback_performancewarning()
            return super()._str_replace(pat, repl, n, case, flags, regex)

        func = pc.replace_substring_regex if regex else pc.replace_substring
        result = func(self._pa_array, pattern=pat, replacement=repl, max_replacements=n)
        return type(self)(result)

    def _str_repeat(self, repeats: int | Sequence[int]):
        """
        Repeat each string ``repeats`` times.

        A scalar ``int`` uses ``pc.binary_repeat``; any other input is handled
        by the object-dtype implementation.
        """
        if not isinstance(repeats, int):
            return super()._str_repeat(repeats)
        else:
            return type(self)(pc.binary_repeat(self._pa_array, repeats))

    def _str_match(
        self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
    ):
        """Match ``pat`` at the start of each string by anchoring it with ``^``."""
        if not pat.startswith("^"):
            pat = f"^{pat}"
        return self._str_contains(pat, case, flags, na, regex=True)

    def _str_fullmatch(
        self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
    ):
        """
        Match ``pat`` against each whole string.

        A trailing ``$`` is appended unless ``pat`` already ends with an
        unescaped ``$``; the work is then delegated to ``_str_match``.
        """
        if not pat.endswith("$") or pat.endswith("\\$"):
            pat = f"{pat}$"
        return self._str_match(pat, case, flags, na)

    def _str_slice(
        self, start: int | None = None, stop: int | None = None, step: int | None = None
    ):
        """
        Slice each string.

        When ``stop`` is ``None`` the object-dtype implementation is used.
        Otherwise ``pc.utf8_slice_codeunits`` is called with ``start``
        defaulting to 0 and ``step`` defaulting to 1.
        """
        if stop is None:
            return super()._str_slice(start, stop, step)
        if start is None:
            start = 0
        if step is None:
            step = 1
        return type(self)(
            pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step)
        )

    def _str_isalnum(self):
        """Element-wise ``pc.utf8_is_alnum``, passed through ``_result_converter``."""
        result = pc.utf8_is_alnum(self._pa_array)
        return self._result_converter(result)

    def _str_isalpha(self):
        """Element-wise ``pc.utf8_is_alpha``, passed through ``_result_converter``."""
        result = pc.utf8_is_alpha(self._pa_array)
        return self._result_converter(result)

    def _str_isdecimal(self):
        """Element-wise ``pc.utf8_is_decimal``, passed through ``_result_converter``."""
        result = pc.utf8_is_decimal(self._pa_array)
        return self._result_converter(result)

    def _str_isdigit(self):
        """Element-wise ``pc.utf8_is_digit``, passed through ``_result_converter``."""
        result = pc.utf8_is_digit(self._pa_array)
        return self._result_converter(result)

    def _str_islower(self):
        """Element-wise ``pc.utf8_is_lower``, passed through ``_result_converter``."""
        result = pc.utf8_is_lower(self._pa_array)
        return self._result_converter(result)

    def _str_isnumeric(self):
        """Element-wise ``pc.utf8_is_numeric``, passed through ``_result_converter``."""
        result = pc.utf8_is_numeric(self._pa_array)
        return self._result_converter(result)

    def _str_isspace(self):
        """Element-wise ``pc.utf8_is_space``, passed through ``_result_converter``."""
        result = pc.utf8_is_space(self._pa_array)
        return self._result_converter(result)

    def _str_istitle(self):
        """Element-wise ``pc.utf8_is_title``, passed through ``_result_converter``."""
        result = pc.utf8_is_title(self._pa_array)
        return self._result_converter(result)

    def _str_isupper(self):
        """Element-wise ``pc.utf8_is_upper``, passed through ``_result_converter``."""
        result = pc.utf8_is_upper(self._pa_array)
        return self._result_converter(result)

    def _str_len(self):
        """Element-wise ``pc.utf8_length``, passed through ``_convert_int_dtype``."""
        result = pc.utf8_length(self._pa_array)
        return self._convert_int_dtype(result)

    def _str_lower(self):
        """Return a new array holding ``pc.utf8_lower`` of each string."""
        return type(self)(pc.utf8_lower(self._pa_array))

    def _str_upper(self):
        """Return a new array holding ``pc.utf8_upper`` of each string."""
        return type(self)(pc.utf8_upper(self._pa_array))

    def _str_strip(self, to_strip=None):
        """
        Trim both ends of each string.

        Whitespace is trimmed when ``to_strip`` is ``None``; otherwise the
        characters in ``to_strip`` are trimmed.
        """
        if to_strip is None:
            result = pc.utf8_trim_whitespace(self._pa_array)
        else:
            result = pc.utf8_trim(self._pa_array, characters=to_strip)
        return type(self)(result)

    def _str_lstrip(self, to_strip=None):
        """
        Trim the left end of each string.

        Whitespace is trimmed when ``to_strip`` is ``None``; otherwise the
        characters in ``to_strip`` are trimmed.
        """
        if to_strip is None:
            result = pc.utf8_ltrim_whitespace(self._pa_array)
        else:
            result = pc.utf8_ltrim(self._pa_array, characters=to_strip)
        return type(self)(result)

    def _str_rstrip(self, to_strip=None):
        """
        Trim the right end of each string.

        Whitespace is trimmed when ``to_strip`` is ``None``; otherwise the
        characters in ``to_strip`` are trimmed.
        """
        if to_strip is None:
            result = pc.utf8_rtrim_whitespace(self._pa_array)
        else:
            result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
        return type(self)(result)

    def _str_removeprefix(self, prefix: str):
        """
        Remove ``prefix`` from the strings that start with it.

        The pyarrow path is used only when ``pa_version_under13p0`` is false;
        otherwise the object-dtype implementation is used.
        """
        if not pa_version_under13p0:
            starts_with = pc.starts_with(self._pa_array, pattern=prefix)
            removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
            result = pc.if_else(starts_with, removed, self._pa_array)
            return type(self)(result)
        return super()._str_removeprefix(prefix)

    def _str_removesuffix(self, suffix: str):
        """
        Remove ``suffix`` from the strings that end with it.

        An empty ``suffix`` leaves the data unchanged, as ``str.removesuffix``
        does.
        """
        if not suffix:
            # ``-len("") == 0`` would make the slice below ``[0:0]`` and blank
            # out every string, so return the data untouched instead.
            return type(self)(self._pa_array)
        ends_with = pc.ends_with(self._pa_array, pattern=suffix)
        removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
        result = pc.if_else(ends_with, removed, self._pa_array)
        return type(self)(result)

    def _str_count(self, pat: str, flags: int = 0):
        """
        Count regex matches of ``pat`` in each string.

        Non-zero ``flags`` are handled by the object-dtype implementation.
        """
        if flags:
            return super()._str_count(pat, flags)
        result = pc.count_substring_regex(self._pa_array, pat)
        return self._convert_int_dtype(result)

    def _str_find(self, sub: str, start: int = 0, end: int | None = None):
        """
        Return the lowest index of ``sub`` within ``[start:end]`` of each string.

        Indices are relative to the full string and ``-1`` means "not found",
        as with ``str.find``. pyarrow is used for an unbounded search and for a
        positive ``start`` combined with an explicit ``end`` and a non-empty
        ``sub``. All other combinations use the object-dtype implementation.
        """
        if start == 0 and end is None:
            result = pc.find_substring(self._pa_array, sub)
        elif start is not None and start > 0 and end is not None and sub:
            slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
            result = pc.find_substring(slices, sub)
            not_found = pc.equal(result, -1)
            # The match index is relative to the slice; shifting by ``start``
            # makes it relative to the original string again.
            offset_result = pc.add(result, start)
            result = pc.if_else(not_found, result, offset_result)
        else:
            # A negative ``start`` needs a per-element offset and an empty
            # ``sub`` "matches" inside an empty slice, so a constant shift
            # cannot give the right answer in those cases.
            return super()._str_find(sub, start, end)
        return self._convert_int_dtype(result)

    def _str_get_dummies(self, sep: str = "|"):
        """
        Split each string on ``sep`` and return ``(dummies, labels)``.

        ``dummies`` is an ``int64`` NumPy array built from the
        ``ArrowExtensionArray`` result; with no labels it is an empty
        ``(0, 0)`` array.
        """
        dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
        if len(labels) == 0:
            return np.empty(shape=(0, 0), dtype=np.int64), labels
        dummies = np.vstack(dummies_pa.to_numpy())
        return dummies.astype(np.int64, copy=False), labels

    def _convert_int_dtype(self, result):
        """Convert ``result`` with ``Int64Dtype().__from_arrow__``."""
        return Int64Dtype().__from_arrow__(result)

    def _reduce(
        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
    ):
        """
        Run the reduction ``name`` via ``_reduce_calc`` and wrap the result.

        A ``pa.Array`` result is converted with ``_convert_int_dtype`` for
        ``argmin`` / ``argmax`` and wrapped in this array type otherwise; any
        other result is returned as-is.
        """
        result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
        if name in ("argmin", "argmax") and isinstance(result, pa.Array):
            return self._convert_int_dtype(result)
        elif isinstance(result, pa.Array):
            return type(self)(result)
        else:
            return result

    def _rank(
        self,
        *,
        axis: AxisInt = 0,
        method: str = "average",
        na_option: str = "keep",
        ascending: bool = True,
        pct: bool = False,
    ):
        """
        See Series.rank.__doc__.
        """
        return self._convert_int_dtype(
            self._rank_calc(
                axis=axis,
                method=method,
                na_option=na_option,
                ascending=ascending,
                pct=pct,
            )
        )
598
599
class ArrowStringArrayNumpySemantics(ArrowStringArray):
    """
    ``ArrowStringArray`` subclass registered under ``"pyarrow_numpy"`` storage.

    The overrides below return NumPy arrays from boolean, integer and
    comparison helpers.
    """

    _storage = "pyarrow_numpy"

    @classmethod
    def _result_converter(cls, values, na=None):
        """
        Convert ``values`` to a NumPy array with ``np.nan`` for missing entries.

        When ``na`` is not missing, nulls are first filled with ``bool(na)``.
        """
        filled = values if isna(na) else values.fill_null(bool(na))
        return ArrowExtensionArray(filled).to_numpy(na_value=np.nan)

    def __getattribute__(self, item):
        """
        Resolve attributes, preferring ``ArrowStringArrayMixin``.

        Any name defined directly on ``ArrowStringArrayMixin`` (other than
        ``_pa_array`` and ``__dict__``) is returned as that mixin attribute
        bound to ``self`` through ``functools.partial``.
        """
        # ArrowStringArray and we both inherit from ArrowExtensionArray, which
        # creates inheritance problems (Diamond inheritance)
        excluded = ("_pa_array", "__dict__")
        if item in ArrowStringArrayMixin.__dict__ and item not in excluded:
            mixin_attr = getattr(ArrowStringArrayMixin, item)
            return partial(mixin_attr, self)
        return super().__getattribute__(item)

    def _str_map(
        self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
    ):
        """
        Apply ``f`` to every non-missing element.

        Integer and boolean targets return a NumPy array, with ``np.nan``
        (integer) or ``False`` (boolean) in missing slots. If the typed mapping
        raises ``ValueError``, the mapping is redone without a dtype and
        optionally passed through ``lib.maybe_convert_objects``. A non-object
        string dtype returns a new array of this type; anything else returns
        the raw result of ``lib.map_infer_mask``.
        """
        if dtype is None:
            dtype = self.dtype
        if na_value is None:
            na_value = self.dtype.na_value

        mask = isna(self)
        arr = np.asarray(self)
        byte_mask = mask.view("uint8")

        wants_integer = is_integer_dtype(dtype)
        if wants_integer or is_bool_dtype(dtype):
            # The caller-supplied na_value is replaced for these targets.
            fill = np.nan if wants_integer else False
            try:
                return lib.map_infer_mask(
                    arr,
                    f,
                    byte_mask,
                    convert=False,
                    na_value=fill,
                    dtype=np.dtype(dtype),  # type: ignore[arg-type]
                )
            except ValueError:
                untyped = lib.map_infer_mask(
                    arr,
                    f,
                    byte_mask,
                    convert=False,
                    na_value=fill,
                )
                if convert and untyped.dtype == object:
                    untyped = lib.maybe_convert_objects(untyped)
                return untyped

        if is_string_dtype(dtype) and not is_object_dtype(dtype):
            # i.e. StringDtype
            mapped = lib.map_infer_mask(
                arr, f, byte_mask, convert=False, na_value=na_value
            )
            as_arrow = pa.array(
                mapped, mask=mask, type=pa.large_string(), from_pandas=True
            )
            return type(self)(as_arrow)

        # This is when the result type is object. We reach this when
        # -> We know the result type is truly object (e.g. .encode returns bytes
        #    or .findall returns a list).
        # -> We don't know the result type. E.g. `.get` can return anything.
        return lib.map_infer_mask(arr, f, byte_mask)

    def _convert_int_dtype(self, result):
        """
        Convert ``result`` to a NumPy array, widening ``int32`` to ``int64``.
        """
        if isinstance(result, pa.Array):
            as_numpy = result.to_numpy(zero_copy_only=False)
        else:
            as_numpy = result.to_numpy()
        return as_numpy.astype(np.int64) if as_numpy.dtype == np.int32 else as_numpy

    def _cmp_method(self, other, op):
        """
        Compare with ``other`` and return a NumPy boolean array.

        Missing comparison results become ``True`` for ``operator.ne`` and
        ``False`` for every other operator. A ``pa.ArrowNotImplementedError``
        from the parent comparison is routed to ``invalid_comparison``.
        """
        try:
            compared = super()._cmp_method(other, op)
        except pa.ArrowNotImplementedError:
            return invalid_comparison(self, other, op)
        missing_as = bool(op == operator.ne)
        return compared.to_numpy(np.bool_, na_value=missing_as)

    def value_counts(self, dropna: bool = True) -> Series:
        """
        Return the parent's value counts as a Series whose values have been
        converted with ``to_numpy()``.
        """
        from pandas import Series

        counts = super().value_counts(dropna)
        return Series(
            counts._values.to_numpy(),
            index=counts.index,
            name=counts.name,
            copy=False,
        )

    def _reduce(
        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
    ):
        """
        Reduce the array.

        ``any`` / ``all`` are evaluated on a boolean array that is true where
        the string is not ``""``. For ``all`` with ``skipna=False`` that array
        is additionally AND-ed (Kleene) with a not-null indicator. Every other
        reduction is delegated to the parent class.
        """
        if name not in ["any", "all"]:
            return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)

        non_empty = pc.not_equal(self._pa_array, "")
        if not skipna and name == "all":
            not_null = pc.invert(pc.is_null(self._pa_array))
            truthy = pc.and_kleene(not_null, non_empty)
        else:
            truthy = non_empty
        return ArrowExtensionArray(truthy)._reduce(
            name, skipna=skipna, keepdims=keepdims, **kwargs
        )

    def insert(self, loc: int, item) -> ArrowStringArrayNumpySemantics:
        """
        Insert ``item`` at ``loc``, replacing the ``np.nan`` object with
        ``libmissing.NA`` before delegating to the parent implementation.
        """
        value = libmissing.NA if item is np.nan else item
        return super().insert(loc, value)  # type: ignore[return-value]