1""" manage PyTables query interface via Expressions """
2from __future__ import annotations
3
4import ast
5from functools import partial
6from typing import Any
7
8import numpy as np
9
10from pandas._libs.tslibs import (
11 Timedelta,
12 Timestamp,
13)
14from pandas._typing import npt
15from pandas.errors import UndefinedVariableError
16
17from pandas.core.dtypes.common import is_list_like
18
19import pandas.core.common as com
20from pandas.core.computation import (
21 expr,
22 ops,
23 scope as _scope,
24)
25from pandas.core.computation.common import ensure_decoded
26from pandas.core.computation.expr import BaseExprVisitor
27from pandas.core.computation.ops import is_term
28from pandas.core.construction import extract_array
29from pandas.core.indexes.base import Index
30
31from pandas.io.formats.printing import (
32 pprint_thing,
33 pprint_thing_encoded,
34)
35
36
class PyTablesScope(_scope.Scope):
    """
    Scope that also carries the mapping of queryable names.

    Parameters
    ----------
    level : int
        Frame level handed on (incremented by one) to the parent scope.
    global_dict, local_dict : optional
        Forwarded unchanged to the parent scope.
    queryables : dict or None
        Mapping of queryable names; an empty dict is used when omitted
        or empty.
    """

    __slots__ = ("queryables",)

    queryables: dict[str, Any]

    def __init__(
        self,
        level: int,
        global_dict=None,
        local_dict=None,
        queryables: dict[str, Any] | None = None,
    ) -> None:
        # NOTE(review): the +1 presumably skips this constructor's own frame
        # when the parent looks up the caller's namespace -- confirm in Scope.
        super().__init__(level + 1, global_dict=global_dict, local_dict=local_dict)
        self.queryables = queryables if queryables else {}
51
52
class Term(ops.Term):
    """
    A named term of a PyTables query.

    Construction with a non-string ``name`` yields a ``Constant`` instead,
    so a ``Term`` instance always carries a string name.
    """

    env: PyTablesScope

    def __new__(cls, name, env, side=None, encoding=None):
        # anything that is not a string is a literal value
        klass = cls if isinstance(name, str) else Constant
        return object.__new__(klass)

    def __init__(self, name, env: PyTablesScope, side=None, encoding=None) -> None:
        super().__init__(name, env, side=side, encoding=encoding)

    def _resolve_name(self):
        """
        Resolve the name of this term.

        A right-hand (or unsided) name is looked up in the environment and
        falls back to the bare name when it is undefined.  A left-hand name
        must be one of the queryables and is returned as-is.

        Raises
        ------
        NameError
            If a left-hand name is not a queryable.
        """
        if self.side != "left":
            # resolve the rhs; an unknown name is kept verbatim
            try:
                return self.env.resolve(self.name, is_local=False)
            except UndefinedVariableError:
                return self.name

        # __new__ guarantees that self.name is a str at this point
        if self.name not in self.env.queryables:
            raise NameError(f"name {repr(self.name)} is not defined")
        return self.name

    # read-only property overwriting read/write property
    @property  # type: ignore[misc]
    def value(self):
        """The resolved value of this term."""
        return self._value
84
85
class Constant(Term):
    """A literal value; it resolves to exactly what it was built from."""

    def __init__(self, value, env: PyTablesScope, side=None, encoding=None) -> None:
        # a constant is only meaningful inside a PyTables scope
        assert isinstance(env, PyTablesScope), type(env)
        super().__init__(value, env, side=side, encoding=encoding)

    def _resolve_name(self):
        """Return the stored value unchanged; no environment lookup is done."""
        return self._name
93
94
class BinOp(ops.BinOp):
    """
    Binary operation whose left operand names a queryable field.

    Holds what the condition and filter flavours share: looking up field
    information in ``queryables``, converting right-hand values into
    ``TermValue`` objects and collapsing an expression tree via ``prune``.
    """

    # most values that are expanded into one condition string; for more,
    # the equality operators are handled as a filter instead
    _max_selectors = 31

    op: str
    queryables: dict[str, Any]
    condition: str | None

    def __init__(self, op: str, lhs, rhs, queryables: dict[str, Any], encoding) -> None:
        super().__init__(op, lhs, rhs)
        self.queryables = queryables
        self.encoding = encoding
        self.condition = None

    def _disallow_scalar_only_bool_ops(self) -> None:
        # intentionally a no-op
        # NOTE(review): presumably overrides a check done by ops.BinOp -- confirm
        pass

    def prune(self, klass):
        """
        Collapse this operation into an instance of ``klass`` (or a joint op).

        Term operands contribute their ``value``; nested operations are
        pruned recursively with the same ``klass``.

        Returns
        -------
        The evaluated operation, one of the pruned operands, or None.
        """

        def pr(left, right):
            """create and return a new specialized BinOp from myself"""
            if left is None:
                return right
            elif right is None:
                return left

            k = klass
            if isinstance(left, ConditionBinOp):
                if isinstance(right, ConditionBinOp):
                    k = JointConditionBinOp
                elif isinstance(left, k):
                    return left
                elif isinstance(right, k):
                    return right

            elif isinstance(left, FilterBinOp):
                if isinstance(right, FilterBinOp):
                    k = JointFilterBinOp
                elif isinstance(left, k):
                    return left
                elif isinstance(right, k):
                    return right

            return k(
                self.op, left, right, queryables=self.queryables, encoding=self.encoding
            ).evaluate()

        left, right = self.lhs, self.rhs

        # left is always handled before right, whatever the operand kinds
        pruned_left = left.value if is_term(left) else left.prune(klass)
        pruned_right = right.value if is_term(right) else right.prune(klass)
        return pr(pruned_left, pruned_right)

    def conform(self, rhs):
        """
        Return ``rhs`` as a flat list-like.

        A non-list-like value is wrapped in a list and an ndarray is
        ravelled; nothing is modified in place.
        """
        if not is_list_like(rhs):
            rhs = [rhs]
        if isinstance(rhs, np.ndarray):
            rhs = rhs.ravel()
        return rhs

    @property
    def is_valid(self) -> bool:
        """return True if this is a valid field"""
        return self.lhs in self.queryables

    @property
    def is_in_table(self) -> bool:
        """
        return True if this is a valid column name for generation (e.g. an
        actual column in the table)
        """
        return self.queryables.get(self.lhs) is not None

    @property
    def kind(self):
        """the kind of my field"""
        return getattr(self.queryables.get(self.lhs), "kind", None)

    @property
    def meta(self):
        """the meta of my field"""
        return getattr(self.queryables.get(self.lhs), "meta", None)

    @property
    def metadata(self):
        """the metadata of my field"""
        return getattr(self.queryables.get(self.lhs), "metadata", None)

    def generate(self, v) -> str:
        """create and return the op string for this TermValue"""
        val = v.tostring(self.encoding)
        return f"({self.lhs} {self.op} {val})"

    def convert_value(self, v) -> TermValue:
        """
        Convert the expression that is in the term to something that is
        accepted by pytables.

        Parameters
        ----------
        v : object
            A single right-hand value.

        Returns
        -------
        TermValue

        Raises
        ------
        TypeError
            If the field kind is not handled and ``v`` is not a string.
        """

        def stringify(value):
            if self.encoding is not None:
                return pprint_thing_encoded(value, encoding=self.encoding)
            return pprint_thing(value)

        kind = ensure_decoded(self.kind)
        meta = ensure_decoded(self.meta)
        if kind in ("datetime64", "datetime"):
            if isinstance(v, (int, float)):
                v = stringify(v)
            v = ensure_decoded(v)
            v = Timestamp(v).as_unit("ns")
            if v.tz is not None:
                v = v.tz_convert("UTC")
            return TermValue(v, v._value, kind)
        elif kind in ("timedelta64", "timedelta"):
            if isinstance(v, str):
                v = Timedelta(v)
            else:
                v = Timedelta(v, unit="s")
            v = v.as_unit("ns")._value
            return TermValue(int(v), v, kind)
        elif meta == "category":
            metadata = extract_array(self.metadata, extract_numpy=True)
            result: npt.NDArray[np.intp] | np.intp | int
            if v not in metadata:
                result = -1
            else:
                result = metadata.searchsorted(v, side="left")
            return TermValue(result, result, "integer")
        elif kind == "integer":
            # Going through float first silently corrupts integers whose
            # magnitude exceeds 2**53, so convert directly when possible.
            try:
                v = int(v)
            except (TypeError, ValueError):
                # not an integer literal (e.g. "1.7" or "1e3"): parse as a
                # float and truncate; an invalid value still raises here
                v = int(float(v))
            return TermValue(v, v, kind)
        elif kind == "float":
            v = float(v)
            return TermValue(v, v, kind)
        elif kind == "bool":
            if isinstance(v, str):
                # any string that is not a recognised "false" spelling is True
                v = v.strip().lower() not in [
                    "false",
                    "f",
                    "no",
                    "n",
                    "none",
                    "0",
                    "[]",
                    "{}",
                    "",
                ]
            else:
                v = bool(v)
            return TermValue(v, v, kind)
        elif isinstance(v, str):
            # string quoting
            return TermValue(v, stringify(v), "string")
        else:
            raise TypeError(f"Cannot compare {v} of type {type(v)} to {kind} column")

    def convert_values(self) -> None:
        # intentionally a no-op; values are converted one at a time through
        # convert_value instead
        pass
260
261
class FilterBinOp(BinOp):
    """
    Binary operation evaluated as a filter.

    ``filter`` is None until ``evaluate`` builds it; afterwards it is a
    tuple of (field name, callable taking ``(axis, vals)``, Index of values).
    """

    filter: tuple[Any, Any, Index] | None = None

    def __repr__(self) -> str:
        if self.filter is None:
            return "Filter: Not Initialized"
        # brackets are balanced, mirroring ConditionBinOp.__repr__
        return pprint_thing(f"[Filter : [{self.filter[0]}] -> [{self.filter[1]}]]")

    def invert(self):
        """invert the filter"""
        if self.filter is not None:
            self.filter = (
                self.filter[0],
                self.generate_filter_op(invert=True),
                self.filter[2],
            )
        return self

    def format(self):
        """return the actual filter format"""
        return [self.filter]

    def evaluate(self):
        """
        Build the filter for this operation.

        Returns
        -------
        self when a filter was built, None when the field is in the table
        and no filter is needed.

        Raises
        ------
        ValueError
            If the left-hand side is not a queryable.
        TypeError
            If a non-equality operator is used on a field that is not in
            the table.
        """
        if not self.is_valid:
            raise ValueError(f"query term is not valid [{self}]")

        values = list(self.conform(self.rhs))
        is_equality = self.op in ["==", "!="]

        if self.is_in_table:
            # if too many values to create the expression, use a filter instead
            if is_equality and len(values) > self._max_selectors:
                filter_op = self.generate_filter_op()
                self.filter = (self.lhs, filter_op, Index(values))
                return self
            return None

        # only equality conditions can be applied to a non-table field
        if not is_equality:
            raise TypeError(
                f"passing a filterable condition to a non-table indexer [{self}]"
            )

        filter_op = self.generate_filter_op()
        self.filter = (self.lhs, filter_op, Index(values))
        return self

    def generate_filter_op(self, invert: bool = False):
        """
        Return a callable ``(axis, vals)`` testing membership of ``axis``.

        The test is negated for ``!=``; ``invert=True`` flips the sense.
        """
        if (self.op == "!=" and not invert) or (self.op == "==" and invert):
            return lambda axis, vals: ~axis.isin(vals)
        else:
            return lambda axis, vals: axis.isin(vals)
317
318
class JointFilterBinOp(FilterBinOp):
    """Combination of two filters; it cannot be rendered by ``format``."""

    def format(self):
        """Always raise: joint filters have no collapsed form."""
        raise NotImplementedError("unable to collapse Joint Filters")

    def evaluate(self):
        """Return this object unchanged; nothing is computed."""
        return self
325
326
class ConditionBinOp(BinOp):
    """
    Binary operation rendered as a condition string.

    ``condition`` holds the generated string once ``evaluate`` succeeded.
    """

    def __repr__(self) -> str:
        return pprint_thing(f"[Condition : [{self.condition}]]")

    def invert(self):
        """invert the condition (not supported)"""
        raise NotImplementedError(
            "cannot use an invert condition when passing to numexpr"
        )

    def format(self):
        """return the actual ne format"""
        return self.condition

    def evaluate(self):
        """
        Generate the condition string for this operation.

        Returns
        -------
        self when a condition was generated; None when the field is not in
        the table or when an equality test has more than ``_max_selectors``
        values.

        Raises
        ------
        ValueError
            If the left-hand side is not a queryable.
        """
        if not self.is_valid:
            raise ValueError(f"query term is not valid [{self}]")

        # convert values if we are in the table
        if not self.is_in_table:
            return None

        rhs = self.conform(self.rhs)
        values = [self.convert_value(v) for v in rhs]

        if self.op not in ["==", "!="]:
            # only the first value is used for the remaining operators
            self.condition = self.generate(values[0])
            return self

        # too many values to create the expression?
        if len(values) > self._max_selectors:
            return None

        # "equal to any of the values" is a disjunction, whereas "different
        # from all of the values" must be a conjunction: OR-ing several !=
        # tests is true for every row and disagrees with the filter form
        # (~axis.isin(vals)) used when there are too many values.
        joiner = " & " if self.op == "!=" else " | "
        vs = [self.generate(v) for v in values]
        self.condition = f"({joiner.join(vs)})"
        return self
369
370
class JointConditionBinOp(ConditionBinOp):
    """Condition made of the conditions of both operands."""

    def evaluate(self):
        """Join both operands' conditions with the operator and return self."""
        left_condition = self.lhs.condition
        right_condition = self.rhs.condition
        self.condition = f"({left_condition} {self.op} {right_condition})"
        return self
375
376
class UnaryOp(ops.UnaryOp):
    """Unary operation; only inversion (``~``) can be pruned."""

    def prune(self, klass):
        """
        Prune the operand and invert the result when there is something to
        invert.

        Returns
        -------
        The inverted pruned operand, or None when the operand pruned to
        None or carries no condition (for condition classes) / no filter
        (for filter classes).

        Raises
        ------
        NotImplementedError
            If the operator is not ``~``.
        """
        if self.op != "~":
            raise NotImplementedError("UnaryOp only support invert type ops")

        pruned = self.operand.prune(klass)
        if pruned is None:
            return None

        # condition classes are checked first; filters only matter otherwise
        if issubclass(klass, ConditionBinOp):
            invertible = pruned.condition is not None
        elif issubclass(klass, FilterBinOp):
            invertible = pruned.filter is not None
        else:
            invertible = False

        return pruned.invert() if invertible else None
394
395
class PyTablesExprVisitor(BaseExprVisitor):
    """
    Expression visitor that builds PyTables terms and operations.

    Extra keyword arguments given to the constructor are forwarded to every
    ``BinOp`` created for a binary operator node.
    """

    const_type = Constant
    term_type = Term

    def __init__(self, env, engine, parser, **kwargs) -> None:
        super().__init__(env, engine, parser)
        for bin_op in self.binary_ops:
            node_name = self.binary_op_nodes_map[bin_op]

            # bin_op is bound as a default so that every visitor keeps its
            # own operator rather than the last one of the loop
            def _visit_bin_op(node, bin_op=bin_op):
                return partial(BinOp, bin_op, **kwargs)

            setattr(self, f"visit_{node_name}", _visit_bin_op)

    def visit_UnaryOp(self, node, **kwargs):
        """Handle ``not``/``~`` and unary minus; unary plus is rejected."""
        unary = node.op
        if isinstance(unary, (ast.Not, ast.Invert)):
            return UnaryOp("~", self.visit(node.operand))
        if isinstance(unary, ast.USub):
            negated = -self.visit(node.operand).value
            return self.const_type(negated, self.env)
        if isinstance(unary, ast.UAdd):
            raise NotImplementedError("Unary addition not supported")
        return None

    def visit_Index(self, node, **kwargs):
        return self.visit(node.value).value

    def visit_Assign(self, node, **kwargs):
        """Treat ``target = value`` as the comparison ``target == value``."""
        comparison = ast.Compare(
            ops=[ast.Eq()], left=node.targets[0], comparators=[node.value]
        )
        return self.visit(comparison)

    def visit_Subscript(self, node, **kwargs):
        """
        Evaluate a simple subscript and wrap the result in a constant.

        Raises
        ------
        ValueError
            If the subscription raises a TypeError.
        """
        # only allow simple subscripts
        container = self.visit(node.value)
        key = self.visit(node.slice)

        # unwrap a term into the value it holds, if it is one
        container = getattr(container, "value", container)

        if isinstance(key, Term):
            # In py39 np.ndarray lookups with Term containing int raise
            key = key.value

        try:
            return self.const_type(container[key], self.env)
        except TypeError as err:
            raise ValueError(
                f"cannot subscript {repr(container)} with {repr(key)}"
            ) from err

    def visit_Attribute(self, node, **kwargs):
        """
        Resolve an attribute access in load context into a term.

        Raises
        ------
        ValueError
            If the context is not a load, or the attribute cannot be
            resolved.
        """
        ctx = type(node.ctx)
        if ctx == ast.Load:
            owner = node.value

            # resolve the owner, unwrapping it if it is another expression
            resolved = self.visit(owner)
            resolved = getattr(resolved, "value", resolved)

            try:
                return self.term_type(getattr(resolved, node.attr), self.env)
            except AttributeError:
                # something like datetime.datetime where scope is overridden
                if isinstance(owner, ast.Name) and owner.id == node.attr:
                    return resolved

        raise ValueError(f"Invalid Attribute context {ctx.__name__}")

    def translate_In(self, op):
        """Map ``in`` onto equality; any other operator is returned as-is."""
        if isinstance(op, ast.In):
            return ast.Eq()
        return op

    def _rewrite_membership_op(self, node, left, right):
        return self.visit(node.op), node.op, left, right
477
478
479def _validate_where(w):
480 """
481 Validate that the where statement is of the right type.
482
483 The type may either be String, Expr, or list-like of Exprs.
484
485 Parameters
486 ----------
487 w : String term expression, Expr, or list-like of Exprs.
488
489 Returns
490 -------
491 where : The original where clause if the check was successful.
492
493 Raises
494 ------
495 TypeError : An invalid data type was passed in for w (e.g. dict).
496 """
497 if not (isinstance(w, (PyTablesExpr, str)) or is_list_like(w)):
498 raise TypeError(
499 "where must be passed as a string, PyTablesExpr, "
500 "or list-like of PyTablesExpr"
501 )
502
503 return w
504
505
506class PyTablesExpr(expr.Expr):
507 """
508 Hold a pytables-like expression, comprised of possibly multiple 'terms'.
509
510 Parameters
511 ----------
512 where : string term expression, PyTablesExpr, or list-like of PyTablesExprs
513 queryables : a "kinds" map (dict of column name -> kind), or None if column
514 is non-indexable
515 encoding : an encoding that will encode the query terms
516
517 Returns
518 -------
519 a PyTablesExpr object
520
521 Examples
522 --------
523 'index>=date'
524 "columns=['A', 'D']"
525 'columns=A'
526 'columns==A'
527 "~(columns=['A','B'])"
528 'index>df.index[3] & string="bar"'
529 '(index>df.index[3] & index<=df.index[6]) | string="bar"'
530 "ts>=Timestamp('2012-02-01')"
531 "major_axis>=20130101"
532 """
533
534 _visitor: PyTablesExprVisitor | None
535 env: PyTablesScope
536 expr: str
537
538 def __init__(
539 self,
540 where,
541 queryables: dict[str, Any] | None = None,
542 encoding=None,
543 scope_level: int = 0,
544 ) -> None:
545 where = _validate_where(where)
546
547 self.encoding = encoding
548 self.condition = None
549 self.filter = None
550 self.terms = None
551 self._visitor = None
552
553 # capture the environment if needed
554 local_dict: _scope.DeepChainMap[Any, Any] | None = None
555
556 if isinstance(where, PyTablesExpr):
557 local_dict = where.env.scope
558 _where = where.expr
559
560 elif is_list_like(where):
561 where = list(where)
562 for idx, w in enumerate(where):
563 if isinstance(w, PyTablesExpr):
564 local_dict = w.env.scope
565 else:
566 w = _validate_where(w)
567 where[idx] = w
568 _where = " & ".join([f"({w})" for w in com.flatten(where)])
569 else:
570 # _validate_where ensures we otherwise have a string
571 _where = where
572
573 self.expr = _where
574 self.env = PyTablesScope(scope_level + 1, local_dict=local_dict)
575
576 if queryables is not None and isinstance(self.expr, str):
577 self.env.queryables.update(queryables)
578 self._visitor = PyTablesExprVisitor(
579 self.env,
580 queryables=queryables,
581 parser="pytables",
582 engine="pytables",
583 encoding=encoding,
584 )
585 self.terms = self.parse()
586
587 def __repr__(self) -> str:
588 if self.terms is not None:
589 return pprint_thing(self.terms)
590 return pprint_thing(self.expr)
591
592 def evaluate(self):
593 """create and return the numexpr condition and filter"""
594 try:
595 self.condition = self.terms.prune(ConditionBinOp)
596 except AttributeError as err:
597 raise ValueError(
598 f"cannot process expression [{self.expr}], [{self}] "
599 "is not a valid condition"
600 ) from err
601 try:
602 self.filter = self.terms.prune(FilterBinOp)
603 except AttributeError as err:
604 raise ValueError(
605 f"cannot process expression [{self.expr}], [{self}] "
606 "is not a valid filter"
607 ) from err
608
609 return self.condition, self.filter
610
611
class TermValue:
    """
    Hold a term value that we use to construct a condition/filter.

    Parameters
    ----------
    value : object
        The value as given.
    converted : object
        The form of the value that is rendered by ``tostring``.
    kind : str
        The kind of the value, e.g. "string" or "float".
    """

    def __init__(self, value, converted, kind: str) -> None:
        assert isinstance(kind, str), kind
        self.value = value
        self.converted = converted
        self.kind = kind

    def tostring(self, encoding) -> str:
        """
        Render the converted value as text.

        Floats are rendered with ``repr``.  Strings are wrapped in double
        quotes unless an encoding is given; everything else goes through
        ``str``.
        """
        if self.kind == "float":
            return repr(self.converted)
        if self.kind == "string" and encoding is None:
            return f'"{self.converted}"'
        return str(self.converted)
632
633
def maybe_expression(s) -> bool:
    """
    Loosely check whether ``s`` could be a pytables-acceptable expression.

    Only strings qualify, and only when they contain at least one of the
    visitor's binary or unary operators, or ``=``.  This is a plain
    substring test, so a word such as "index" (which contains "in") passes.
    """
    if not isinstance(s, str):
        return False

    operations = PyTablesExprVisitor.binary_ops + PyTablesExprVisitor.unary_ops + ("=",)

    # make sure we have an op at least
    for op in operations:
        if op in s:
            return True
    return False