1""" manage PyTables query interface via Expressions """
2from __future__ import annotations
3
4import ast
5from decimal import (
6 Decimal,
7 InvalidOperation,
8)
9from functools import partial
10from typing import (
11 TYPE_CHECKING,
12 Any,
13 ClassVar,
14)
15
16import numpy as np
17
18from pandas._libs.tslibs import (
19 Timedelta,
20 Timestamp,
21)
22from pandas.errors import UndefinedVariableError
23
24from pandas.core.dtypes.common import is_list_like
25
26import pandas.core.common as com
27from pandas.core.computation import (
28 expr,
29 ops,
30 scope as _scope,
31)
32from pandas.core.computation.common import ensure_decoded
33from pandas.core.computation.expr import BaseExprVisitor
34from pandas.core.computation.ops import is_term
35from pandas.core.construction import extract_array
36from pandas.core.indexes.base import Index
37
38from pandas.io.formats.printing import (
39 pprint_thing,
40 pprint_thing_encoded,
41)
42
43if TYPE_CHECKING:
44 from pandas._typing import (
45 Self,
46 npt,
47 )
48
49
class PyTablesScope(_scope.Scope):
    """
    Scope that also carries the mapping of queryable names.

    Parameters
    ----------
    level : int
        Frame level requested by the caller. One is added before it is
        forwarded to the parent constructor (presumably to skip this
        constructor's own frame -- confirm against ``Scope``).
    global_dict, local_dict : optional
        Forwarded unchanged to the parent constructor.
    queryables : dict or None
        Mapping of queryable names. Any falsy value (``None`` or an empty
        dict) is replaced by a new empty dict.
    """

    __slots__ = ("queryables",)

    queryables: dict[str, Any]

    def __init__(
        self,
        level: int,
        global_dict=None,
        local_dict=None,
        queryables: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(level + 1, global_dict=global_dict, local_dict=local_dict)
        self.queryables = queryables if queryables else {}
64
65
class Term(ops.Term):
    """
    A named term of a PyTables query expression.

    A term on the left-hand side must be one of the scope's queryables.
    Any other term is looked up in the scope and falls back to its own
    name when the lookup fails.
    """

    env: PyTablesScope

    def __new__(cls, name, env, side=None, encoding=None):
        # Only string names are built as ``cls``; anything else is a Constant.
        klass = cls if isinstance(name, str) else Constant
        return object.__new__(klass)

    def __init__(self, name, env: PyTablesScope, side=None, encoding=None) -> None:
        super().__init__(name, env, side=side, encoding=encoding)

    def _resolve_name(self):
        """
        Return the value this term stands for.

        Raises
        ------
        NameError
            If the term is on the left-hand side and is not a queryable.
        """
        if self.side != "left":
            # Not the lhs: resolve through the scope, and allow an unknown
            # name to stand for itself.
            try:
                return self.env.resolve(self.name, is_local=False)
            except UndefinedVariableError:
                return self.name

        # The lhs must be a queryable. __new__ guarantees self.name is a str
        # for instances of this class.
        if self.name in self.env.queryables:
            return self.name
        raise NameError(f"name {repr(self.name)} is not defined")

    # read-only property overwriting read/write property
    @property  # type: ignore[misc]
    def value(self):
        """Return the resolved value stored on the instance."""
        return self._value
97
98
class Constant(Term):
    """A term whose resolved value is the object it was constructed with."""

    def __init__(self, name, env: PyTablesScope, side=None, encoding=None) -> None:
        env_type = type(env)
        # Constants are only expected to live in a PyTablesScope.
        assert isinstance(env, PyTablesScope), env_type
        super().__init__(name, env, side=side, encoding=encoding)

    def _resolve_name(self):
        """Return the stored name unchanged; no scope lookup is performed."""
        return self._name
106
107
class BinOp(ops.BinOp):
    """
    Binary operation node of a PyTables query expression.

    Parameters
    ----------
    op : str
        The operator symbol.
    lhs, rhs
        The operands.
    queryables : dict
        Mapping of queryable name to its column description (or None).
    encoding
        Encoding used when converting string values.
    """

    # Upper bound on the number of values the subclasses will place in a
    # single equality expression.
    _max_selectors = 31

    op: str
    queryables: dict[str, Any]
    condition: str | None

    def __init__(self, op: str, lhs, rhs, queryables: dict[str, Any], encoding) -> None:
        super().__init__(op, lhs, rhs)
        self.queryables = queryables
        self.encoding = encoding
        self.condition = None

    def _disallow_scalar_only_bool_ops(self) -> None:
        # Deliberately a no-op for this class.
        pass

    def prune(self, klass):
        """
        Reduce this node to an evaluated ``klass``-style operation.

        Term operands contribute their ``value``; every other operand is
        pruned recursively first.
        """

        def pr(left, right):
            """create and return a new specialized BinOp from myself"""
            if left is None:
                return right
            if right is None:
                return left

            target = klass
            families = (
                (ConditionBinOp, JointConditionBinOp),
                (FilterBinOp, JointFilterBinOp),
            )
            for family, joint in families:
                if not isinstance(left, family):
                    continue
                if isinstance(right, family):
                    # both sides are already specialized: combine them
                    target = joint
                elif isinstance(left, klass):
                    return left
                elif isinstance(right, klass):
                    return right
                # only the first family matching ``left`` is considered
                break

            combined = target(
                self.op, left, right, queryables=self.queryables, encoding=self.encoding
            )
            return combined.evaluate()

        lhs, rhs = self.lhs, self.rhs
        left_arg = lhs.value if is_term(lhs) else lhs.prune(klass)
        right_arg = rhs.value if is_term(rhs) else rhs.prune(klass)
        return pr(left_arg, right_arg)

    def conform(self, rhs):
        """
        Return rhs in a form that can be iterated over.

        A non list-like value is wrapped in a one-element list, and a numpy
        array is flattened with ``ravel``; anything else is returned as is.
        """
        if not is_list_like(rhs):
            return [rhs]
        return rhs.ravel() if isinstance(rhs, np.ndarray) else rhs

    def _queryable_attr(self, attr: str):
        """Return ``attr`` of the queryable registered for lhs, else None."""
        return getattr(self.queryables.get(self.lhs), attr, None)

    @property
    def is_valid(self) -> bool:
        """return True if this is a valid field"""
        return self.lhs in self.queryables

    @property
    def is_in_table(self) -> bool:
        """
        return True if this is a valid column name for generation (e.g. an
        actual column in the table)
        """
        return self.queryables.get(self.lhs) is not None

    @property
    def kind(self):
        """the kind of my field"""
        return self._queryable_attr("kind")

    @property
    def meta(self):
        """the meta of my field"""
        return self._queryable_attr("meta")

    @property
    def metadata(self):
        """the metadata of my field"""
        return self._queryable_attr("metadata")

    def generate(self, v) -> str:
        """create and return the op string for this TermValue"""
        return f"({self.lhs} {self.op} {v.tostring(self.encoding)})"

    def convert_value(self, v) -> TermValue:
        """
        Convert a right-hand-side value to a TermValue matching the kind
        (or meta) of the left-hand-side field.

        Raises
        ------
        TypeError
            If no conversion applies and ``v`` is not a string.
        """

        def stringify(value):
            if self.encoding is None:
                return pprint_thing(value)
            return pprint_thing_encoded(value, encoding=self.encoding)

        kind = ensure_decoded(self.kind)
        meta = ensure_decoded(self.meta)

        if kind == "datetime" or (kind and kind.startswith("datetime64")):
            if isinstance(v, (int, float)):
                v = stringify(v)
            v = ensure_decoded(v)
            v = Timestamp(v).as_unit("ns")
            if v.tz is not None:
                v = v.tz_convert("UTC")
            return TermValue(v, v._value, kind)

        if kind in ("timedelta64", "timedelta"):
            # strings are parsed; anything else is taken as a number of seconds
            if isinstance(v, str):
                delta = Timedelta(v)
            else:
                delta = Timedelta(v, unit="s")
            nanos = delta.as_unit("ns")._value
            return TermValue(int(nanos), nanos, kind)

        if meta == "category":
            categories = extract_array(self.metadata, extract_numpy=True)
            code: npt.NDArray[np.intp] | np.intp | int
            if v in categories:
                code = categories.searchsorted(v, side="left")
            else:
                # value is not one of the categories
                code = -1
            return TermValue(code, code, "integer")

        if kind == "integer":
            try:
                v_dec = Decimal(v)
            except InvalidOperation:
                # GH 54186
                # convert v to float to raise float's ValueError
                float(v)
            else:
                v = int(v_dec.to_integral_exact(rounding="ROUND_HALF_EVEN"))
            return TermValue(v, v, kind)

        if kind == "float":
            v = float(v)
            return TermValue(v, v, kind)

        if kind == "bool":
            if isinstance(v, str):
                falsy_strings = ("false", "f", "no", "n", "none", "0", "[]", "{}", "")
                v = v.strip().lower() not in falsy_strings
            else:
                v = bool(v)
            return TermValue(v, v, kind)

        if isinstance(v, str):
            # string quoting
            return TermValue(v, stringify(v), "string")

        raise TypeError(f"Cannot compare {v} of type {type(v)} to {kind} column")

    def convert_values(self) -> None:
        # Deliberately a no-op for this class.
        pass
280
281
class FilterBinOp(BinOp):
    """
    Binary operation evaluated as a filter.

    Once evaluated, ``filter`` holds ``(lhs, callable, Index(values))`` where
    the callable takes ``(axis, vals)`` and returns a membership mask.
    """

    filter: tuple[Any, Any, Index] | None = None

    def __repr__(self) -> str:
        current = self.filter
        if current is None:
            return "Filter: Not Initialized"
        return pprint_thing(f"[Filter : [{current[0]}] -> [{current[1]}]")

    def invert(self) -> Self:
        """invert the filter"""
        current = self.filter
        if current is not None:
            # same field and values, with the opposite membership test
            field, _, values = current
            self.filter = (field, self.generate_filter_op(invert=True), values)
        return self

    def format(self):
        """return the actual filter format"""
        return [self.filter]

    # error: Signature of "evaluate" incompatible with supertype "BinOp"
    def evaluate(self) -> Self | None:  # type: ignore[override]
        """
        Build the filter for this operation.

        Returns
        -------
        self, or None when the field is in the table and no filter is needed.

        Raises
        ------
        ValueError
            If lhs is not a queryable.
        TypeError
            If the field is not in the table and the op is not ``==``/``!=``.
        """
        if not self.is_valid:
            raise ValueError(f"query term is not valid [{self}]")

        values = list(self.conform(self.rhs))
        is_equality = self.op in ["==", "!="]

        if self.is_in_table:
            # if too many values to create the expression, use a filter instead
            if is_equality and len(values) > self._max_selectors:
                self.filter = (self.lhs, self.generate_filter_op(), Index(values))
                return self
            return None

        if not is_equality:
            raise TypeError(
                f"passing a filterable condition to a non-table indexer [{self}]"
            )

        # equality conditions
        self.filter = (self.lhs, self.generate_filter_op(), Index(values))
        return self

    def generate_filter_op(self, invert: bool = False):
        """
        Return a callable ``(axis, vals)`` testing membership of axis in vals.

        The test is negated for ``!=`` (or for ``==`` when ``invert`` is set).
        """
        negated = self.op == ("==" if invert else "!=")
        if negated:
            return lambda axis, vals: ~axis.isin(vals)
        return lambda axis, vals: axis.isin(vals)
338
339
class JointFilterBinOp(FilterBinOp):
    """Combination of two filter operations; it cannot be formatted."""

    def format(self):
        """Always raise: joint filters cannot be collapsed into one filter."""
        message = "unable to collapse Joint Filters"
        raise NotImplementedError(message)

    # error: Signature of "evaluate" incompatible with supertype "BinOp"
    def evaluate(self) -> Self:  # type: ignore[override]
        """Return self unchanged; nothing is computed here."""
        return self
347
348
class ConditionBinOp(BinOp):
    """
    Binary operation evaluated as a condition string.

    After a successful ``evaluate`` the ``condition`` attribute holds the
    generated expression text.
    """

    def __repr__(self) -> str:
        return pprint_thing(f"[Condition : [{self.condition}]]")

    def invert(self):
        """
        Inverting a condition is not supported.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "cannot use an invert condition when passing to numexpr"
        )

    def format(self):
        """return the actual ne format"""
        return self.condition

    # error: Signature of "evaluate" incompatible with supertype "BinOp"
    def evaluate(self) -> Self | None:  # type: ignore[override]
        """
        Build the condition string for this operation.

        Returns
        -------
        self, or None when the field is not in the table or when an
        equality test has more than ``_max_selectors`` values (in which case
        a filter is expected to be used after reading).

        Raises
        ------
        ValueError
            If lhs is not a queryable.
        """
        if not self.is_valid:
            raise ValueError(f"query term is not valid [{self}]")

        # convert values if we are in the table
        if not self.is_in_table:
            return None

        rhs = self.conform(self.rhs)
        values = [self.convert_value(v) for v in rhs]

        if self.op not in ["==", "!="]:
            # ordering comparisons only use the first value
            self.condition = self.generate(values[0])
            return self

        # too many values to create the expression? use a filter after reading
        if len(values) > self._max_selectors:
            return None

        # "==" over several values means "equal to any of them" (OR).
        # "!=" must mean "equal to none of them" (AND): OR-ing the
        # inequalities would be true for every row as soon as two distinct
        # values are given, and would disagree with the ``~isin`` semantics
        # FilterBinOp uses for "!=".
        joiner = " | " if self.op == "==" else " & "
        self.condition = f"({joiner.join(self.generate(v) for v in values)})"
        return self
392
393
class JointConditionBinOp(ConditionBinOp):
    """Condition formed by joining the conditions of both operands."""

    # error: Signature of "evaluate" incompatible with supertype "BinOp"
    def evaluate(self) -> Self:  # type: ignore[override]
        """Set ``condition`` to ``(<lhs condition> <op> <rhs condition>)``."""
        left_condition = self.lhs.condition
        right_condition = self.rhs.condition
        self.condition = f"({left_condition} {self.op} {right_condition})"
        return self
399
400
class UnaryOp(ops.UnaryOp):
    """Unary operation of a PyTables query; only ``~`` can be pruned."""

    def prune(self, klass):
        """
        Prune the operand and return its inversion, if there is one.

        Returns
        -------
        The inverted pruned operand, or None when the operand pruned to None
        or carries no condition/filter for the requested ``klass``.

        Raises
        ------
        NotImplementedError
            If the operator is not ``~``.
        """
        if self.op != "~":
            raise NotImplementedError("UnaryOp only support invert type ops")

        pruned = self.operand.prune(klass)
        if pruned is None:
            return None

        # Conditions take precedence; filters are only looked at when klass
        # is not a condition class.
        if issubclass(klass, ConditionBinOp):
            invertible = pruned.condition is not None
        elif issubclass(klass, FilterBinOp):
            invertible = pruned.filter is not None
        else:
            invertible = False

        return pruned.invert() if invertible else None
418
419
class PyTablesExprVisitor(BaseExprVisitor):
    """
    Expression visitor that builds PyTables query nodes.

    Extra keyword arguments given to the constructor are passed on to every
    ``BinOp`` created for a binary operator.
    """

    const_type: ClassVar[type[ops.Term]] = Constant
    term_type: ClassVar[type[Term]] = Term

    def __init__(self, env, engine, parser, **kwargs) -> None:
        super().__init__(env, engine, parser)
        for bin_op in self.binary_ops:
            visitor_name = f"visit_{self.binary_op_nodes_map[bin_op]}"
            # The default argument pins the operator for this visitor; each
            # visitor returns a BinOp factory pre-loaded with the operator
            # and the constructor's keyword arguments.
            setattr(
                self,
                visitor_name,
                lambda node, bin_op=bin_op: partial(BinOp, bin_op, **kwargs),
            )

    def visit_UnaryOp(self, node, **kwargs) -> ops.Term | UnaryOp | None:
        """
        Visit a unary operation.

        ``not``/``~`` become an invert UnaryOp, unary minus is folded into a
        negated constant and unary plus is rejected.
        """
        unary = node.op
        if isinstance(unary, (ast.Not, ast.Invert)):
            return UnaryOp("~", self.visit(node.operand))
        if isinstance(unary, ast.USub):
            return self.const_type(-self.visit(node.operand).value, self.env)
        if isinstance(unary, ast.UAdd):
            raise NotImplementedError("Unary addition not supported")
        # TODO: return None might never be reached
        return None

    def visit_Index(self, node, **kwargs):
        return self.visit(node.value).value

    def visit_Assign(self, node, **kwargs):
        """Treat an assignment ``a = b`` as the comparison ``a == b``."""
        target = node.targets[0]
        comparison = ast.Compare(
            ops=[ast.Eq()], left=target, comparators=[node.value]
        )
        return self.visit(comparison)

    def visit_Subscript(self, node, **kwargs) -> ops.Term:
        """
        Evaluate a simple subscript and wrap the result as a constant.

        Raises
        ------
        ValueError
            If the subscript operation raises TypeError.
        """
        # only allow simple subscripts
        container = self.visit(node.value)
        slobj = self.visit(node.slice)

        # unwrap the container when it exposes a ``value``
        container = getattr(container, "value", container)

        if isinstance(slobj, Term):
            # In py39 np.ndarray lookups with Term containing int raise
            slobj = slobj.value

        try:
            return self.const_type(container[slobj], self.env)
        except TypeError as err:
            raise ValueError(
                f"cannot subscript {repr(container)} with {repr(slobj)}"
            ) from err

    def visit_Attribute(self, node, **kwargs):
        """
        Resolve an attribute access to a term.

        Raises
        ------
        ValueError
            If the node is not in a load context, or the attribute cannot be
            resolved.
        """
        attr = node.attr
        owner = node.value

        ctx = type(node.ctx)
        if ctx != ast.Load:
            raise ValueError(f"Invalid Attribute context {ctx.__name__}")

        # resolve the owner, then unwrap it if it is another expression
        resolved = self.visit(owner)
        resolved = getattr(resolved, "value", resolved)

        try:
            return self.term_type(getattr(resolved, attr), self.env)
        except AttributeError:
            # something like datetime.datetime where scope is overridden
            if isinstance(owner, ast.Name) and owner.id == attr:
                return resolved

        raise ValueError(f"Invalid Attribute context {ctx.__name__}")

    def translate_In(self, op):
        """Map an ``in`` operator to ``==``; return any other op unchanged."""
        if isinstance(op, ast.In):
            return ast.Eq()
        return op

    def _rewrite_membership_op(self, node, left, right):
        op_node = node.op
        return self.visit(op_node), op_node, left, right
503
504
505def _validate_where(w):
506 """
507 Validate that the where statement is of the right type.
508
509 The type may either be String, Expr, or list-like of Exprs.
510
511 Parameters
512 ----------
513 w : String term expression, Expr, or list-like of Exprs.
514
515 Returns
516 -------
517 where : The original where clause if the check was successful.
518
519 Raises
520 ------
521 TypeError : An invalid data type was passed in for w (e.g. dict).
522 """
523 if not (isinstance(w, (PyTablesExpr, str)) or is_list_like(w)):
524 raise TypeError(
525 "where must be passed as a string, PyTablesExpr, "
526 "or list-like of PyTablesExpr"
527 )
528
529 return w
530
531
532class PyTablesExpr(expr.Expr):
533 """
534 Hold a pytables-like expression, comprised of possibly multiple 'terms'.
535
536 Parameters
537 ----------
538 where : string term expression, PyTablesExpr, or list-like of PyTablesExprs
539 queryables : a "kinds" map (dict of column name -> kind), or None if column
540 is non-indexable
541 encoding : an encoding that will encode the query terms
542
543 Returns
544 -------
545 a PyTablesExpr object
546
547 Examples
548 --------
549 'index>=date'
550 "columns=['A', 'D']"
551 'columns=A'
552 'columns==A'
553 "~(columns=['A','B'])"
554 'index>df.index[3] & string="bar"'
555 '(index>df.index[3] & index<=df.index[6]) | string="bar"'
556 "ts>=Timestamp('2012-02-01')"
557 "major_axis>=20130101"
558 """
559
560 _visitor: PyTablesExprVisitor | None
561 env: PyTablesScope
562 expr: str
563
564 def __init__(
565 self,
566 where,
567 queryables: dict[str, Any] | None = None,
568 encoding=None,
569 scope_level: int = 0,
570 ) -> None:
571 where = _validate_where(where)
572
573 self.encoding = encoding
574 self.condition = None
575 self.filter = None
576 self.terms = None
577 self._visitor = None
578
579 # capture the environment if needed
580 local_dict: _scope.DeepChainMap[Any, Any] | None = None
581
582 if isinstance(where, PyTablesExpr):
583 local_dict = where.env.scope
584 _where = where.expr
585
586 elif is_list_like(where):
587 where = list(where)
588 for idx, w in enumerate(where):
589 if isinstance(w, PyTablesExpr):
590 local_dict = w.env.scope
591 else:
592 where[idx] = _validate_where(w)
593 _where = " & ".join([f"({w})" for w in com.flatten(where)])
594 else:
595 # _validate_where ensures we otherwise have a string
596 _where = where
597
598 self.expr = _where
599 self.env = PyTablesScope(scope_level + 1, local_dict=local_dict)
600
601 if queryables is not None and isinstance(self.expr, str):
602 self.env.queryables.update(queryables)
603 self._visitor = PyTablesExprVisitor(
604 self.env,
605 queryables=queryables,
606 parser="pytables",
607 engine="pytables",
608 encoding=encoding,
609 )
610 self.terms = self.parse()
611
612 def __repr__(self) -> str:
613 if self.terms is not None:
614 return pprint_thing(self.terms)
615 return pprint_thing(self.expr)
616
617 def evaluate(self):
618 """create and return the numexpr condition and filter"""
619 try:
620 self.condition = self.terms.prune(ConditionBinOp)
621 except AttributeError as err:
622 raise ValueError(
623 f"cannot process expression [{self.expr}], [{self}] "
624 "is not a valid condition"
625 ) from err
626 try:
627 self.filter = self.terms.prune(FilterBinOp)
628 except AttributeError as err:
629 raise ValueError(
630 f"cannot process expression [{self.expr}], [{self}] "
631 "is not a valid filter"
632 ) from err
633
634 return self.condition, self.filter
635
636
class TermValue:
    """hold a term value that we use to construct a condition/filter"""

    def __init__(self, value, converted, kind: str) -> None:
        assert isinstance(kind, str), kind
        self.value = value
        self.converted = converted
        self.kind = kind

    def tostring(self, encoding) -> str:
        """
        Return the converted value as text for a condition string.

        Floats use ``repr``, strings are wrapped in double quotes when no
        encoding is given, and everything else uses ``str``.
        """
        if self.kind == "float":
            # repr() is used for floats so the text round-trips
            return repr(self.converted)
        if self.kind == "string" and encoding is None:
            return f'"{self.converted}"'
        return str(self.converted)
657
658
def maybe_expression(s) -> bool:
    """
    Loosely check whether s could be a pytables-acceptable expression.

    Only strings qualify, and only when they contain at least one of the
    visitor's binary or unary operator symbols, or ``=``.
    """
    return isinstance(s, str) and any(
        symbol in s
        for symbol in (
            PyTablesExprVisitor.binary_ops + PyTablesExprVisitor.unary_ops + ("=",)
        )
    )