Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/pandas/core/computation/ops.py: 38%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

293 statements  

1""" 

2Operator classes for eval. 

3""" 

4 

5from __future__ import annotations 

6 

7from datetime import datetime 

8from functools import partial 

9import operator 

10from typing import ( 

11 Callable, 

12 Iterable, 

13 Iterator, 

14 Literal, 

15) 

16 

17import numpy as np 

18 

19from pandas._libs.tslibs import Timestamp 

20 

21from pandas.core.dtypes.common import ( 

22 is_list_like, 

23 is_scalar, 

24) 

25 

26import pandas.core.common as com 

27from pandas.core.computation.common import ( 

28 ensure_decoded, 

29 result_type_many, 

30) 

31from pandas.core.computation.scope import DEFAULT_GLOBALS 

32 

33from pandas.io.formats.printing import ( 

34 pprint_thing, 

35 pprint_thing_encoded, 

36) 

37 

# Reduction names. Not referenced in this section of the module --
# presumably consumed by other eval code (verify before relying on it).
REDUCTIONS = ("sum", "prod", "min", "max")

# Function names callable with one argument inside an expression; FuncNode
# resolves each name with getattr(numpy, name).
_unary_math_ops = (
    "sin", "cos", "exp", "log", "expm1", "log1p", "sqrt",
    "sinh", "cosh", "tanh",
    "arcsin", "arccos", "arctan",
    "arccosh", "arcsinh", "arctanh",
    "abs", "log10", "floor", "ceil",
)
# Function names that take two arguments.
_binary_math_ops = ("arctan2",)

# Every supported function name: the unary names first, then the binary ones.
MATHOPS = (*_unary_math_ops, *_binary_math_ops)

# Prefix that marks a name as a local (``@variable``) reference.
LOCAL_TAG = "__pd_eval_local_"

68 

69 

class Term:
    """
    A named leaf of an eval expression.

    The name is resolved against ``env`` when the term is constructed.
    Instantiating ``Term`` with a non-string ``name`` produces a
    :class:`Constant` instead.
    """

    def __new__(cls, name, env, side=None, encoding=None):
        # Non-string names are literal values and become Constants.
        target_cls = cls if isinstance(name, str) else Constant
        # error: Argument 2 for "super" not an instance of argument 1
        parent_new = super(Term, target_cls).__new__  # type: ignore[misc]
        return parent_new(target_cls)

    is_local: bool

    def __init__(self, name, env, side=None, encoding=None) -> None:
        # ``name`` is a str for Term, but subclasses may store other objects.
        self._name = name
        self.env = env
        self.side = side
        as_text = str(name)
        # Names carrying LOCAL_TAG, or listed in DEFAULT_GLOBALS, are
        # flagged local.
        self.is_local = as_text.startswith(LOCAL_TAG) or as_text in DEFAULT_GLOBALS
        self._value = self._resolve_name()
        self.encoding = encoding

    @property
    def local_name(self) -> str:
        """The name with every occurrence of LOCAL_TAG removed."""
        return self.name.replace(LOCAL_TAG, "")

    def __repr__(self) -> str:
        return pprint_thing(self.name)

    def __call__(self, *args, **kwargs):
        # Calling a leaf simply yields its resolved value.
        return self.value

    def evaluate(self, *args, **kwargs) -> Term:
        # A leaf needs no further evaluation.
        return self

    def _resolve_name(self):
        """
        Look the name up in ``env`` and record the result via ``update``.

        Raises
        ------
        NotImplementedError
            If the resolved object has ``ndim > 2``.
        """
        key = str(self.local_name)
        scope = self.env.scope
        # A name bound to a class in the scope is never resolved as a local.
        bound_to_class = key in scope and isinstance(scope[key], type)
        resolved = self.env.resolve(
            key, is_local=False if bound_to_class else self.is_local
        )
        self.update(resolved)

        if hasattr(resolved, "ndim") and resolved.ndim > 2:
            raise NotImplementedError(
                "N-dimensional objects, where N > 2, are not supported with eval"
            )
        return resolved

    def update(self, value) -> None:
        """
        Replace the term's value.

        When the name is a string, ``env.swapkey`` is asked to rebind the
        entry found under ``local_name`` to ``name`` with the new value.

        Lookup order for local (i.e., @variable) variables, as
        (scope, key_variable):

        [('locals', 'local_name'),
         ('globals', 'local_name'),
         ('locals', 'key'),
         ('globals', 'key')]
        """
        key = self.name
        # Only variable names (strings) have a scope entry to swap.
        if isinstance(key, str):
            self.env.swapkey(self.local_name, key, new_value=value)
        self.value = value

    @property
    def is_scalar(self) -> bool:
        return is_scalar(self._value)

    @property
    def type(self):
        """
        ``value.values.dtype`` if available, else ``value.dtype``, else the
        Python type of the value.
        """
        val = self._value
        try:
            # potentially very slow for large, mixed dtype frames
            return val.values.dtype
        except AttributeError:
            pass
        try:
            # ndarray
            return val.dtype
        except AttributeError:
            # scalar
            return type(val)

    return_type = type

    @property
    def raw(self) -> str:
        return f"{type(self).__name__}(name={repr(self.name)}, type={self.type})"

    @property
    def is_datetime(self) -> bool:
        kind = self.type
        # numpy dtypes expose their scalar class as ``.type``
        scalar_cls = getattr(kind, "type", kind)
        return issubclass(scalar_cls, (datetime, np.datetime64))

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value) -> None:
        self._value = new_value

    @property
    def name(self):
        return self._name

    @property
    def ndim(self) -> int:
        return self._value.ndim

184 

185 

class Constant(Term):
    """
    A literal value appearing in an expression.

    The value doubles as the term's name, so no scope lookup is performed.
    """

    def __init__(self, value, env, side=None, encoding=None) -> None:
        super().__init__(value, env, side=side, encoding=encoding)

    def _resolve_name(self):
        # Nothing to look up: the stored "name" already is the value.
        return self._name

    @property
    def name(self):
        return self.value

    def __repr__(self) -> str:
        # repr() shows the literal as Python source (e.g. strings stay quoted).
        return repr(self.name)

201 

202 

# Keyword boolean operators and the symbolic tokens Op stores in their place.
_bool_op_map = dict(zip(("not", "and", "or"), ("~", "&", "|")))

204 

205 

class Op:
    """
    Hold an operator of arbitrary arity.

    Parameters
    ----------
    op : str
        Operator token. ``"not"``, ``"and"`` and ``"or"`` are stored as
        ``"~"``, ``"&"`` and ``"|"`` respectively.
    operands : iterable of Term or Op
        Stored as given. ``__iter__`` and the type properties each iterate
        it, so a re-iterable (e.g. a tuple) is presumably expected.
    encoding : optional
        Stored as ``self.encoding``.
    """

    op: str

    def __init__(self, op: str, operands: Iterable[Term | Op], encoding=None) -> None:
        self.op = _bool_op_map.get(op, op)
        self.operands = operands
        self.encoding = encoding

    def __iter__(self) -> Iterator:
        return iter(self.operands)

    def __repr__(self) -> str:
        """
        Print a generic n-ary operator and its operands using infix notation.
        """
        # recurse over the operands
        parened = (f"({pprint_thing(opr)})" for opr in self.operands)
        return pprint_thing(f" {self.op} ".join(parened))

    @property
    def return_type(self):
        # clobber types to bool if the op is a boolean operator
        if self.op in (CMP_OPS_SYMS + BOOL_OPS_SYMS):
            return np.bool_
        return result_type_many(*(term.type for term in com.flatten(self)))

    @property
    def has_invalid_return_type(self) -> bool:
        """
        True when the result type is ``object`` even though at least one
        operand has a non-object type.
        """
        types = self.operand_types
        obj_dtype_set = frozenset([np.dtype("object")])
        # bool(): without it ``and`` leaks the frozenset difference to
        # callers despite the ``-> bool`` annotation.
        return bool(self.return_type == object and types - obj_dtype_set)

    @property
    def operand_types(self):
        return frozenset(term.type for term in com.flatten(self))

    @property
    def is_scalar(self) -> bool:
        return all(operand.is_scalar for operand in self.operands)

    @property
    def is_datetime(self) -> bool:
        try:
            t = self.return_type.type
        except AttributeError:
            t = self.return_type

        return issubclass(t, (datetime, np.datetime64))

258 

259 

def _in(x, y):
    """
    Evaluate ``x in y``, vectorized when possible.

    Tries ``x.isin(y)`` first. If ``x`` has no usable ``isin`` but is
    list-like, tries ``y.isin(x)``. Otherwise falls back to Python's ``in``.
    """
    try:
        return x.isin(y)
    except AttributeError:
        pass

    if is_list_like(x):
        try:
            return y.isin(x)
        except AttributeError:
            pass

    return x in y

274 

275 

def _not_in(x, y):
    """
    Evaluate ``x not in y``, vectorized when possible.

    Mirrors ``_in``: negates ``x.isin(y)``, or ``y.isin(x)`` for list-like
    ``x``, and otherwise falls back to Python's ``not in``.
    """
    try:
        return ~x.isin(y)
    except AttributeError:
        pass

    if is_list_like(x):
        try:
            return ~y.isin(x)
        except AttributeError:
            pass

    return x not in y

290 

291 

# Comparison operators; "in" / "not in" dispatch to _in / _not_in.
CMP_OPS_SYMS = (">", "<", ">=", "<=", "==", "!=", "in", "not in")
_cmp_ops_funcs = (
    operator.gt,
    operator.lt,
    operator.ge,
    operator.le,
    operator.eq,
    operator.ne,
    _in,
    _not_in,
)
_cmp_ops_dict = dict(zip(CMP_OPS_SYMS, _cmp_ops_funcs))

# Boolean operators; the keyword spellings map to the bitwise functions.
BOOL_OPS_SYMS = ("&", "|", "and", "or")
_bool_ops_funcs = (operator.and_, operator.or_, operator.and_, operator.or_)
_bool_ops_dict = dict(zip(BOOL_OPS_SYMS, _bool_ops_funcs))

# Arithmetic operators.
ARITH_OPS_SYMS = ("+", "-", "*", "/", "**", "//", "%")
_arith_ops_funcs = (
    operator.add,
    operator.sub,
    operator.mul,
    operator.truediv,
    operator.pow,
    operator.floordiv,
    operator.mod,
)
_arith_ops_dict = dict(zip(ARITH_OPS_SYMS, _arith_ops_funcs))

# Arithmetic operators grouped as special cases (not referenced in this
# section of the module).
SPECIAL_CASE_ARITH_OPS_SYMS = ("**", "//", "%")
_special_case_arith_ops_funcs = (operator.pow, operator.floordiv, operator.mod)
_special_case_arith_ops_dict = dict(
    zip(SPECIAL_CASE_ARITH_OPS_SYMS, _special_case_arith_ops_funcs)
)

# Every binary operator token -> implementing function (comparison, then
# boolean, then arithmetic). Built with a single dict display rather than an
# update loop, so no loop variable leaks into the module namespace.
_binary_ops_dict = {**_cmp_ops_dict, **_bool_ops_dict, **_arith_ops_dict}

331 

332 

def _cast_inplace(terms, acceptable_dtypes, dtype) -> None:
    """
    Cast the values of ``terms`` to ``dtype`` in place.

    Parameters
    ----------
    terms : iterable of Term
        The terms (e.g. a flattened expression) whose values should be cast.
    acceptable_dtypes : list of acceptable numpy.dtype
        Terms whose ``type`` is in this list are left untouched.
    dtype : str or numpy.dtype
        The dtype to cast to.
    """
    target = np.dtype(dtype)
    for term in terms:
        if term.type not in acceptable_dtypes:
            value = term.value
            try:
                # objects with an ``astype`` method (arrays, Series, ...)
                converted = value.astype(target)
            except AttributeError:
                # plain scalars: build a numpy scalar of the target type
                converted = target.type(value)
            term.update(converted)

356 

357 

def is_term(obj) -> bool:
    """Return True when ``obj`` is a Term instance (Constant included)."""
    return isinstance(obj, Term)

360 

361 

class BinOp(Op):
    """
    Hold a binary operator and its operands.

    Parameters
    ----------
    op : str
    lhs : Term or Op
    rhs : Term or Op

    Raises
    ------
    NotImplementedError
        If a boolean operator has a scalar side and the two sides are not
        both boolean-typed.
    ValueError
        If ``op`` is not a known binary operator.
    """

    def __init__(self, op: str, lhs, rhs) -> None:
        super().__init__(op, (lhs, rhs))
        self.lhs = lhs
        self.rhs = rhs

        self._disallow_scalar_only_bool_ops()

        self.convert_values()

        try:
            self.func = _binary_ops_dict[op]
        except KeyError as err:
            # list() so the message shows a plain list, not dict_keys([...])
            keys = list(_binary_ops_dict.keys())
            raise ValueError(
                f"Invalid binary operator {repr(op)}, valid operators are {keys}"
            ) from err

    def __call__(self, env):
        """
        Recursively evaluate an expression in Python space.

        Parameters
        ----------
        env : Scope

        Returns
        -------
        object
            The result of an evaluated expression.
        """
        # recurse over the left/right nodes
        left = self.lhs(env)
        right = self.rhs(env)

        return self.func(left, right)

    def evaluate(self, env, engine: str, parser, term_type, eval_in_python):
        """
        Evaluate a binary operation *before* being passed to the engine.

        Parameters
        ----------
        env : Scope
        engine : str
        parser : str
        term_type : type
        eval_in_python : list

        Returns
        -------
        term_type
            The "pre-evaluated" expression as an instance of ``term_type``
        """
        if engine == "python":
            res = self(env)
        else:
            # recurse over the left/right nodes

            left = self.lhs.evaluate(
                env,
                engine=engine,
                parser=parser,
                term_type=term_type,
                eval_in_python=eval_in_python,
            )

            right = self.rhs.evaluate(
                env,
                engine=engine,
                parser=parser,
                term_type=term_type,
                eval_in_python=eval_in_python,
            )

            # base cases
            if self.op in eval_in_python:
                res = self.func(left.value, right.value)
            else:
                from pandas.core.computation.eval import eval

                res = eval(self, local_dict=env, engine=engine, parser=parser)

        # store the result as a temporary in env and wrap it as a new term
        name = env.add_tmp(res)
        return term_type(name, env=env)

    def convert_values(self) -> None:
        """
        Convert datetimes to a comparable value in an expression.

        When one side is a datetime-typed Term and the other is a scalar
        Term, the scalar side is replaced (via ``update``) by a
        ``Timestamp``. A timezone-aware Timestamp is converted to UTC.
        """

        def stringify(value):
            encoder: Callable
            if self.encoding is not None:
                encoder = partial(pprint_thing_encoded, encoding=self.encoding)
            else:
                encoder = pprint_thing
            return encoder(value)

        def to_timestamp(term) -> None:
            # Replace ``term``'s scalar value by a (UTC if tz-aware) Timestamp.
            v = term.value
            if isinstance(v, (int, float)):
                # numbers go through their string form before parsing
                v = stringify(v)
            v = Timestamp(ensure_decoded(v))
            if v.tz is not None:
                v = v.tz_convert("UTC")
            term.update(v)

        lhs, rhs = self.lhs, self.rhs

        if is_term(lhs) and lhs.is_datetime and is_term(rhs) and rhs.is_scalar:
            to_timestamp(rhs)

        # Checked after the branch above, which may already have updated rhs.
        if is_term(rhs) and rhs.is_datetime and is_term(lhs) and lhs.is_scalar:
            to_timestamp(lhs)

    def _disallow_scalar_only_bool_ops(self) -> None:
        """
        Raise if a boolean operator has a scalar side while the two sides
        are not both boolean-typed.
        """
        rhs = self.rhs
        lhs = self.lhs

        # GH#24883 unwrap dtype if necessary to ensure we have a type object
        rhs_rt = rhs.return_type
        rhs_rt = getattr(rhs_rt, "type", rhs_rt)
        lhs_rt = lhs.return_type
        lhs_rt = getattr(lhs_rt, "type", lhs_rt)
        if (
            (lhs.is_scalar or rhs.is_scalar)
            and self.op in _bool_ops_dict
            and not (
                issubclass(rhs_rt, (bool, np.bool_))
                and issubclass(lhs_rt, (bool, np.bool_))
            )
        ):
            raise NotImplementedError("cannot evaluate scalar only bool ops")

512 

513 

def isnumeric(dtype) -> bool:
    """
    Return True if ``dtype`` (anything ``np.dtype`` accepts) has a scalar
    type derived from ``np.number``.
    """
    scalar_type = np.dtype(dtype).type
    return issubclass(scalar_type, np.number)

516 

517 

class Div(BinOp):
    """
    Div operator to special case casting.

    Operands whose type is not float32/float64 are cast to float64 in place.

    Parameters
    ----------
    lhs, rhs : Term or Op
        The Terms or Ops in the ``/`` expression.

    Raises
    ------
    TypeError
        If either operand's return type is not numeric.
    """

    def __init__(self, lhs, rhs) -> None:
        super().__init__("/", lhs, rhs)

        if not isnumeric(lhs.return_type) or not isnumeric(rhs.return_type):
            raise TypeError(
                f"unsupported operand type(s) for {self.op}: "
                f"'{lhs.return_type}' and '{rhs.return_type}'"
            )

        # do not upcast float32s to float64 un-necessarily.
        # np.float64 is spelled out: the np.float_ alias was removed in NumPy 2.0.
        acceptable_dtypes = [np.float32, np.float64]
        _cast_inplace(com.flatten(self), acceptable_dtypes, np.float64)

540 

541 

# Unary operator token -> implementing function. "not" maps to bitwise
# inversion, like "~" (presumably so it applies elementwise to arrays).
_unary_ops_dict = {
    "+": operator.pos,
    "-": operator.neg,
    "~": operator.invert,
    "not": operator.invert,
}
UNARY_OPS_SYMS = tuple(_unary_ops_dict)
_unary_ops_funcs = tuple(_unary_ops_dict.values())

545 

546 

class UnaryOp(Op):
    """
    Hold a unary operator and its operands.

    Parameters
    ----------
    op : str
        The token used to represent the operator.
    operand : Term or Op
        The Term or Op operand to the operator.

    Raises
    ------
    ValueError
        * If no function associated with the passed operator token is found.
    """

    def __init__(self, op: Literal["+", "-", "~", "not"], operand) -> None:
        super().__init__(op, (operand,))
        self.operand = operand

        try:
            self.func = _unary_ops_dict[op]
        except KeyError as err:
            raise ValueError(
                f"Invalid unary operator {repr(op)}, "
                f"valid operators are {UNARY_OPS_SYMS}"
            ) from err

    def __call__(self, env):
        """
        Evaluate the operand in ``env`` and apply the operator to the result.

        Returns whatever the operator function returns for that value
        (for example a negated scalar or array); no MathCall is built here.
        """
        operand = self.operand(env)
        # error: Cannot call function of unknown type
        return self.func(operand)  # type: ignore[operator]

    def __repr__(self) -> str:
        return pprint_thing(f"{self.op}({self.operand})")

    @property
    def return_type(self) -> np.dtype:
        """
        ``bool`` if the operand is boolean-typed or is itself a comparison or
        boolean Op; otherwise ``int``.
        """
        operand = self.operand
        if operand.return_type == np.dtype("bool"):
            return np.dtype("bool")
        if isinstance(operand, Op) and (
            operand.op in _cmp_ops_dict or operand.op in _bool_ops_dict
        ):
            return np.dtype("bool")
        return np.dtype("int")

594 

595 

class MathCall(Op):
    """
    Application of a supported math function to its argument nodes.

    ``func`` must expose ``name`` (stored as the operator token) and ``func``
    (the callable applied at evaluation time), as FuncNode does.
    """

    def __init__(self, func, args) -> None:
        super().__init__(func.name, args)
        self.func = func

    def __call__(self, env):
        # error: "Op" not callable
        evaluated = [arg(env) for arg in self.operands]  # type: ignore[operator]
        # NumPy floating-point error reporting is disabled for the call.
        with np.errstate(all="ignore"):
            return self.func.func(*evaluated)

    def __repr__(self) -> str:
        joined = ",".join(str(arg) for arg in self.operands)
        return pprint_thing(f"{self.op}({joined})")

610 

611 

class FuncNode:
    """
    Reference to one of the NumPy functions listed in MATHOPS.

    Calling the node with argument nodes produces a MathCall.

    Raises
    ------
    ValueError
        If ``name`` is not in MATHOPS.
    """

    def __init__(self, name: str) -> None:
        if name in MATHOPS:
            self.name = name
            self.func = getattr(np, name)
        else:
            raise ValueError(f'"{name}" is not a supported function')

    def __call__(self, *args):
        return MathCall(self, args)