1"""
2MIT License
3
4Copyright (c) 2021 Alex Hall
5
6Permission is hereby granted, free of charge, to any person obtaining a copy
7of this software and associated documentation files (the "Software"), to deal
8in the Software without restriction, including without limitation the rights
9to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10copies of the Software, and to permit persons to whom the Software is
11furnished to do so, subject to the following conditions:
12
13The above copyright notice and this permission notice shall be included in all
14copies or substantial portions of the Software.
15
16THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22SOFTWARE.
23"""
24
25import __future__
26import ast
27import dis
28import inspect
29import io
30import linecache
31import re
32import sys
33import types
34from collections import defaultdict
35from copy import deepcopy
36from functools import lru_cache
37from itertools import islice
38from itertools import zip_longest
39from operator import attrgetter
40from pathlib import Path
41from threading import RLock
42from tokenize import detect_encoding
43from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Sized, Tuple, Type, TypeVar, Union, cast
44from ._utils import mangled_name,assert_, EnhancedAST,EnhancedInstruction,Instruction,get_instructions
45
46if TYPE_CHECKING: # pragma: no cover
47 from asttokens import ASTTokens, ASTText
48 from asttokens.asttokens import ASTTextBase
49
50
# AST node types for function definitions (regular and async).
function_node_types = (ast.FunctionDef, ast.AsyncFunctionDef) # type: Tuple[Type, ...]

# Unbounded memoizing decorator (an `lru_cache` with no size limit).
cache = lru_cache(maxsize=None)

# When truthy, `Source.executing` re-raises exceptions raised while
# looking for the executing node instead of silently swallowing them.
TESTING = 0
56
class NotOneValueFound(Exception):
    """
    Raised when exactly one value was expected but a different number was found.

    Attributes:
        - values: the values that were found, when the raiser supplied them;
          otherwise an empty list owned by this instance alone
    """

    def __init__(self, msg, values=None):
        # type: (str, Optional[Sequence]) -> None
        # A literal `[]` default would be created once and then shared
        # (and mutable) across every exception raised without `values`,
        # so build a fresh list per instance instead.
        self.values = [] if values is None else values
        super(NotOneValueFound, self).__init__(msg)
62
# Generic type variable referenced by the type comments in this module.
T = TypeVar('T')
64
65
def only(it):
    # type: (Iterable[T]) -> T
    """
    Returns the single value contained in `it`.

    Raises NotOneValueFound if `it` holds no value or more than one.
    Iterables without a length have at most two items consumed.
    """
    if not isinstance(it, Sized):
        # Two items are enough to tell "none", "one" and "several" apart
        first_two = tuple(islice(it, 2))
        if not first_two:
            raise NotOneValueFound('Expected one value, found 0')
        if len(first_two) > 1:
            raise NotOneValueFound('Expected one value, found several', first_two)
        return first_two[0]

    count = len(it)
    if count != 1:
        raise NotOneValueFound('Expected one value, found %s' % count)
    # noinspection PyTypeChecker
    return list(it)[0]
80
81
82class Source(object):
83 """
84 The source code of a single file and associated metadata.
85
86 The main method of interest is the classmethod `executing(frame)`.
87
88 If you want an instance of this class, don't construct it.
89 Ideally use the classmethod `for_frame(frame)`.
90 If you don't have a frame, use `for_filename(filename [, module_globals])`.
91 These methods cache instances by filename, so at most one instance exists per filename.
92
93 Attributes:
94 - filename
95 - text
96 - lines
97 - tree: AST parsed from text, or None if text is not valid Python
98 All nodes in the tree have an extra `parent` attribute
99
100 Other methods of interest:
101 - statements_at_line
102 - asttokens
103 - code_qualname
104 """
105
106 def __init__(self, filename, lines):
107 # type: (str, Sequence[str]) -> None
108 """
109 Don't call this constructor, see the class docstring.
110 """
111
112 self.filename = filename
113 self.text = ''.join(lines)
114 self.lines = [line.rstrip('\r\n') for line in lines]
115
116 self._nodes_by_line = defaultdict(list)
117 self.tree = None
118 self._qualnames = {}
119 self._asttokens = None # type: Optional[ASTTokens]
120 self._asttext = None # type: Optional[ASTText]
121
122 try:
123 self.tree = ast.parse(self.text, filename=filename)
124 except (SyntaxError, ValueError):
125 pass
126 else:
127 for node in ast.walk(self.tree):
128 for child in ast.iter_child_nodes(node):
129 cast(EnhancedAST, child).parent = cast(EnhancedAST, node)
130 for lineno in node_linenos(node):
131 self._nodes_by_line[lineno].append(node)
132
133 visitor = QualnameVisitor()
134 visitor.visit(self.tree)
135 self._qualnames = visitor.qualnames
136
137 @classmethod
138 def for_frame(cls, frame, use_cache=True):
139 # type: (types.FrameType, bool) -> "Source"
140 """
141 Returns the `Source` object corresponding to the file the frame is executing in.
142 """
143 return cls.for_filename(frame.f_code.co_filename, frame.f_globals or {}, use_cache)
144
145 @classmethod
146 def for_filename(
147 cls,
148 filename,
149 module_globals=None,
150 use_cache=True, # noqa no longer used
151 ):
152 # type: (Union[str, Path], Optional[Dict[str, Any]], bool) -> "Source"
153 if isinstance(filename, Path):
154 filename = str(filename)
155
156 def get_lines():
157 # type: () -> List[str]
158 return linecache.getlines(cast(str, filename), module_globals)
159
160 # Save the current linecache entry, then ensure the cache is up to date.
161 entry = linecache.cache.get(filename) # type: ignore[attr-defined]
162 linecache.checkcache(filename)
163 lines = get_lines()
164 if entry is not None and not lines:
165 # There was an entry, checkcache removed it, and nothing replaced it.
166 # This means the file wasn't simply changed (because the `lines` wouldn't be empty)
167 # but rather the file was found not to exist, probably because `filename` was fake.
168 # Restore the original entry so that we still have something.
169 linecache.cache[filename] = entry # type: ignore[attr-defined]
170 lines = get_lines()
171
172 return cls._for_filename_and_lines(filename, tuple(lines))
173
174 @classmethod
175 def _for_filename_and_lines(cls, filename, lines):
176 # type: (str, Sequence[str]) -> "Source"
177 source_cache = cls._class_local('__source_cache_with_lines', {}) # type: Dict[Tuple[str, Sequence[str]], Source]
178 try:
179 return source_cache[(filename, lines)]
180 except KeyError:
181 pass
182
183 result = source_cache[(filename, lines)] = cls(filename, lines)
184 return result
185
186 @classmethod
187 def lazycache(cls, frame):
188 # type: (types.FrameType) -> None
189 linecache.lazycache(frame.f_code.co_filename, frame.f_globals)
190
191 @classmethod
192 def executing(cls, frame_or_tb):
193 # type: (Union[types.TracebackType, types.FrameType]) -> "Executing"
194 """
195 Returns an `Executing` object representing the operation
196 currently executing in the given frame or traceback object.
197 """
198 if isinstance(frame_or_tb, types.TracebackType):
199 # https://docs.python.org/3/reference/datamodel.html#traceback-objects
200 # "tb_lineno gives the line number where the exception occurred;
201 # tb_lasti indicates the precise instruction.
202 # The line number and last instruction in the traceback may differ
203 # from the line number of its frame object
204 # if the exception occurred in a try statement with no matching except clause
205 # or with a finally clause."
206 tb = frame_or_tb
207 frame = tb.tb_frame
208 lineno = tb.tb_lineno
209 lasti = tb.tb_lasti
210 else:
211 frame = frame_or_tb
212 lineno = frame.f_lineno
213 lasti = frame.f_lasti
214
215
216
217 code = frame.f_code
218 key = (code, id(code), lasti)
219 executing_cache = cls._class_local('__executing_cache', {}) # type: Dict[Tuple[types.CodeType, int, int], Any]
220
221 args = executing_cache.get(key)
222 if not args:
223 node = stmts = decorator = None
224 source = cls.for_frame(frame)
225 tree = source.tree
226 if tree:
227 try:
228 stmts = source.statements_at_line(lineno)
229 if stmts:
230 if is_ipython_cell_code(code):
231 decorator, node = find_node_ipython(frame, lasti, stmts, source)
232 else:
233 node_finder = NodeFinder(frame, stmts, tree, lasti, source)
234 node = node_finder.result
235 decorator = node_finder.decorator
236
237 if node:
238 new_stmts = {statement_containing_node(node)}
239 assert_(new_stmts <= stmts)
240 stmts = new_stmts
241 except Exception:
242 if TESTING:
243 raise
244
245 executing_cache[key] = args = source, node, stmts, decorator
246
247 return Executing(frame, *args)
248
249 @classmethod
250 def _class_local(cls, name, default):
251 # type: (str, T) -> T
252 """
253 Returns an attribute directly associated with this class
254 (as opposed to subclasses), setting default if necessary
255 """
256 # classes have a mappingproxy preventing us from using setdefault
257 result = cls.__dict__.get(name, default)
258 setattr(cls, name, result)
259 return result
260
261 @cache
262 def statements_at_line(self, lineno):
263 # type: (int) -> Set[EnhancedAST]
264 """
265 Returns the statement nodes overlapping the given line.
266
267 Returns at most one statement unless semicolons are present.
268
269 If the `text` attribute is not valid python, meaning
270 `tree` is None, returns an empty set.
271
272 Otherwise, `Source.for_frame(frame).statements_at_line(frame.f_lineno)`
273 should return at least one statement.
274 """
275
276 return {
277 statement_containing_node(node)
278 for node in
279 self._nodes_by_line[lineno]
280 }
281
282 def asttext(self):
283 # type: () -> ASTText
284 """
285 Returns an ASTText object for getting the source of specific AST nodes.
286
287 See http://asttokens.readthedocs.io/en/latest/api-index.html
288 """
289 from asttokens import ASTText # must be installed separately
290
291 if self._asttext is None:
292 self._asttext = ASTText(self.text, tree=self.tree, filename=self.filename)
293
294 return self._asttext
295
296 def asttokens(self):
297 # type: () -> ASTTokens
298 """
299 Returns an ASTTokens object for getting the source of specific AST nodes.
300
301 See http://asttokens.readthedocs.io/en/latest/api-index.html
302 """
303 import asttokens # must be installed separately
304
305 if self._asttokens is None:
306 if hasattr(asttokens, 'ASTText'):
307 self._asttokens = self.asttext().asttokens
308 else: # pragma: no cover
309 self._asttokens = asttokens.ASTTokens(self.text, tree=self.tree, filename=self.filename)
310 return self._asttokens
311
312 def _asttext_base(self):
313 # type: () -> ASTTextBase
314 import asttokens # must be installed separately
315
316 if hasattr(asttokens, 'ASTText'):
317 return self.asttext()
318 else: # pragma: no cover
319 return self.asttokens()
320
321 @staticmethod
322 def decode_source(source):
323 # type: (Union[str, bytes]) -> str
324 if isinstance(source, bytes):
325 encoding = Source.detect_encoding(source)
326 return source.decode(encoding)
327 else:
328 return source
329
330 @staticmethod
331 def detect_encoding(source):
332 # type: (bytes) -> str
333 return detect_encoding(io.BytesIO(source).readline)[0]
334
335 def code_qualname(self, code):
336 # type: (types.CodeType) -> str
337 """
338 Imitates the __qualname__ attribute of functions for code objects.
339 Given:
340
341 - A function `func`
342 - A frame `frame` for an execution of `func`, meaning:
343 `frame.f_code is func.__code__`
344
345 `Source.for_frame(frame).code_qualname(frame.f_code)`
346 will be equal to `func.__qualname__`*. Works for Python 2 as well,
347 where of course no `__qualname__` attribute exists.
348
349 Falls back to `code.co_name` if there is no appropriate qualname.
350
351 Based on https://github.com/wbolster/qualname
352
353 (* unless `func` is a lambda
354 nested inside another lambda on the same line, in which case
355 the outer lambda's qualname will be returned for the codes
356 of both lambdas)
357 """
358 assert_(code.co_filename == self.filename)
359 return self._qualnames.get((code.co_name, code.co_firstlineno), code.co_name)
360
361
class Executing(object):
    """
    Describes the operation that a frame is currently executing.

    Usually only `node` is of interest: the AST node being executed,
    or None when it could not be determined.

    While a decorator is being called:
        - `node` is a function or class definition
        - `decorator` is the expression in `node.decorator_list` being called
        - `statements == {node}`
    """

    def __init__(self, frame, source, node, stmts, decorator):
        # type: (types.FrameType, Source, EnhancedAST, Set[ast.stmt], Optional[EnhancedAST]) -> None
        self.frame, self.source = frame, source
        self.node, self.decorator = node, decorator
        self.statements = stmts

    def code_qualname(self):
        # type: () -> str
        """
        Returns the qualified name that `source` gives for the frame's code object.
        """
        code = self.frame.f_code
        return self.source.code_qualname(code)

    def text(self):
        # type: () -> str
        """
        Returns the source text of `node`.
        """
        text_provider = self.source._asttext_base()
        return text_provider.get_text(self.node)

    def text_range(self):
        # type: () -> Tuple[int, int]
        """
        Returns the text range of `node`, as reported by `get_text_range`.
        """
        text_provider = self.source._asttext_base()
        return text_provider.get_text_range(self.node)
394
395
class QualnameVisitor(ast.NodeVisitor):
    """
    Walks a tree and records a qualified name for every function,
    lambda and class definition in `qualnames`, keyed by
    `(name, first line number)`.
    """

    def __init__(self):
        # type: () -> None
        super(QualnameVisitor, self).__init__()
        # Names of the enclosing scopes, outermost first
        self.stack = []  # type: List[str]
        self.qualnames = {}  # type: Dict[Tuple[str, int], str]

    def add_qualname(self, node, name=None):
        # type: (ast.AST, Optional[str]) -> None
        """
        Pushes the name of `node` onto the stack and records its qualified name,
        unless one is already recorded for the same name and line.
        The caller is responsible for popping the name again.
        """
        name = name or node.name  # type: ignore[attr-defined]
        self.stack.append(name)
        decorators = getattr(node, 'decorator_list', ())
        # A decorated definition is keyed by the line of its first decorator
        first_line_node = decorators[0] if decorators else node
        lineno = first_line_node.lineno  # type: ignore[attr-defined]
        self.qualnames.setdefault((name, lineno), ".".join(self.stack))

    def visit_FunctionDef(self, node, name=None):
        # type: (ast.AST, Optional[str]) -> None
        assert isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)), node
        self.add_qualname(node, name)
        self.stack.append('<locals>')
        if isinstance(node, ast.Lambda):
            body = [node.body]  # type: Sequence[ast.AST]
        else:
            body = node.body
        for body_node in body:
            self.visit(body_node)
        # Drop '<locals>' and the function name again
        del self.stack[-2:]

        # Find lambdas in the function definition outside the body,
        # e.g. decorators or default arguments
        # Based on iter_child_nodes
        for field, value in ast.iter_fields(node):
            if field == 'body':
                continue
            candidates = value if isinstance(value, list) else [value]
            for candidate in candidates:
                if isinstance(candidate, ast.AST):
                    self.visit(candidate)

    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Lambda(self, node):
        # type: (ast.AST) -> None
        assert isinstance(node, ast.Lambda)
        self.visit_FunctionDef(node, '<lambda>')

    def visit_ClassDef(self, node):
        # type: (ast.AST) -> None
        assert isinstance(node, ast.ClassDef)
        self.add_qualname(node)
        self.generic_visit(node)
        self.stack.pop()
454
455
456
457
458
459future_flags = sum(
460 getattr(__future__, fname).compiler_flag for fname in __future__.all_feature_names
461)
462
463
def compile_similar_to(source, matching_code):
    # type: (ast.Module, types.CodeType) -> Any
    """
    Compiles `source` in 'exec' mode, using the filename of `matching_code`
    and those of its flags that are `__future__` compiler flags.
    Flags of the calling code are not inherited.
    """
    inherited_future_flags = future_flags & matching_code.co_flags
    return compile(
        source,
        matching_code.co_filename,
        'exec',
        flags=inherited_future_flags,
        dont_inherit=True,
    )
473
474
# Marker string constant: `SentinelNodeFinder.matching_nodes` wraps a candidate
# node in `<node> ** sentinel` and then looks for instructions whose argval
# equals this string in the recompiled bytecode.
sentinel = 'io8urthglkjdghvljusketgIYRFYUVGHFRTBGVHKGF78678957647698'
476
def is_rewritten_by_pytest(code):
    # type: (types.CodeType) -> bool
    """
    Returns True if some instruction of `code` other than a LOAD_CONST
    has a string argval starting with "@py"
    (presumably the names introduced by pytest's assertion rewriting).
    """
    for bc in get_instructions(code):
        if bc.opname == "LOAD_CONST":
            continue
        argval = bc.argval
        if isinstance(argval, str) and argval.startswith("@py"):
            return True
    return False
483
484
class SentinelNodeFinder(object):
    """
    Finds the AST node corresponding to the bytecode instruction
    that a frame is currently executing.

    Candidate nodes of a type matching the instruction are wrapped, one at a time,
    in a `<node> ** sentinel` operation. The tree is recompiled each time, and a
    candidate matches when the instruction just before its sentinel instructions
    sits at the same index as the current instruction in the original bytecode.

    All the work happens in the constructor. Afterwards:
        - `result` is the node that was found
        - `decorator` is the decorator expression being called when the
          instruction turned out to be a decorator call, otherwise None
    """

    result = None # type: EnhancedAST

    def __init__(self, frame, stmts, tree, lasti, source):
        # type: (types.FrameType, Set[EnhancedAST], ast.Module, int, Source) -> None
        """
        Arguments:
            - frame: the frame whose current instruction should be located
            - stmts: statements which may contain the node, must not be empty
            - tree: the module AST containing `stmts`; it is temporarily
              modified and recompiled during the search
            - lasti: bytecode offset of the current instruction
            - source: not used by this constructor

        Raises RuntimeError for an instruction that no node type is known for,
        and NotOneValueFound (from `only`) when there isn't exactly one match.
        """
        assert_(stmts)
        self.frame = frame
        self.tree = tree
        self.code = code = frame.f_code
        self.is_pytest = is_rewritten_by_pytest(code)

        # Instructions on these lines are dropped by clean_instructions
        if self.is_pytest:
            self.ignore_linenos = frozenset(assert_linenos(tree))
        else:
            self.ignore_linenos = frozenset()

        self.decorator = None

        self.instruction = instruction = self.get_actual_current_instruction(lasti)
        op_name = instruction.opname
        # Defaults: no extra condition, and nodes must have no `ctx` attribute
        # (getattr(node, "ctx", None) is then None, an instance of NoneType)
        extra_filter = lambda e: True # type: Callable[[Any], bool]
        ctx = type(None) # type: Type

        # Pick the node type (plus context and extra condition)
        # that candidate nodes must have for this kind of instruction
        typ = type(None) # type: Type
        if op_name.startswith('CALL_'):
            typ = ast.Call
        elif op_name.startswith(('BINARY_SUBSCR', 'SLICE+')):
            typ = ast.Subscript
            ctx = ast.Load
        elif op_name.startswith('BINARY_'):
            typ = ast.BinOp
            op_type = dict(
                BINARY_POWER=ast.Pow,
                BINARY_MULTIPLY=ast.Mult,
                # `()` never matches in isinstance if ast has no MatMult
                BINARY_MATRIX_MULTIPLY=getattr(ast, "MatMult", ()),
                BINARY_FLOOR_DIVIDE=ast.FloorDiv,
                BINARY_TRUE_DIVIDE=ast.Div,
                BINARY_MODULO=ast.Mod,
                BINARY_ADD=ast.Add,
                BINARY_SUBTRACT=ast.Sub,
                BINARY_LSHIFT=ast.LShift,
                BINARY_RSHIFT=ast.RShift,
                BINARY_AND=ast.BitAnd,
                BINARY_XOR=ast.BitXor,
                BINARY_OR=ast.BitOr,
            )[op_name]
            extra_filter = lambda e: isinstance(e.op, op_type)
        elif op_name.startswith('UNARY_'):
            typ = ast.UnaryOp
            op_type = dict(
                UNARY_POSITIVE=ast.UAdd,
                UNARY_NEGATIVE=ast.USub,
                UNARY_NOT=ast.Not,
                UNARY_INVERT=ast.Invert,
            )[op_name]
            extra_filter = lambda e: isinstance(e.op, op_type)
        elif op_name in ('LOAD_ATTR', 'LOAD_METHOD', 'LOOKUP_METHOD'):
            typ = ast.Attribute
            ctx = ast.Load
            extra_filter = lambda e:mangled_name(e) == instruction.argval
        elif op_name in ('LOAD_NAME', 'LOAD_GLOBAL', 'LOAD_FAST', 'LOAD_DEREF', 'LOAD_CLASSDEREF'):
            typ = ast.Name
            ctx = ast.Load
            extra_filter = lambda e:mangled_name(e) == instruction.argval
        elif op_name in ('COMPARE_OP', 'IS_OP', 'CONTAINS_OP'):
            typ = ast.Compare
            # Only comparisons with a single operator are considered
            extra_filter = lambda e: len(e.ops) == 1
        elif op_name.startswith(('STORE_SLICE', 'STORE_SUBSCR')):
            ctx = ast.Store
            typ = ast.Subscript
        elif op_name.startswith('STORE_ATTR'):
            ctx = ast.Store
            typ = ast.Attribute
            extra_filter = lambda e:mangled_name(e) == instruction.argval
        else:
            raise RuntimeError(op_name)


        # matching_nodes temporarily modifies the shared tree,
        # so the whole search runs while holding the module-level lock
        with lock:
            # Candidates: nodes of the right type, context and extra condition
            # whose containing statement is one of `stmts`
            exprs = {
                cast(EnhancedAST, node)
                for stmt in stmts
                for node in ast.walk(stmt)
                if isinstance(node, typ)
                if isinstance(getattr(node, "ctx", None), ctx)
                if extra_filter(node)
                if statement_containing_node(node) == stmt
            }

            if ctx == ast.Store:
                # No special bytecode tricks here.
                # We can handle multiple assigned attributes with different names,
                # but only one assigned subscript.
                self.result = only(exprs)
                return

            matching = list(self.matching_nodes(exprs))
            if not matching and typ == ast.Call:
                # A call instruction without a matching Call node
                # is treated as the call of a decorator
                self.find_decorator(stmts)
            else:
                self.result = only(matching)

    def find_decorator(self, stmts):
        # type: (Union[List[EnhancedAST], Set[EnhancedAST]]) -> None
        """
        Treats the current instruction as the call of a decorator of the
        single function or class definition in `stmts`.

        Sets `self.decorator` to the decorator expression being called
        and `self.result` to the definition.
        """
        stmt = only(stmts)
        assert_(isinstance(stmt, (ast.ClassDef, function_node_types)))
        decorators = stmt.decorator_list # type: ignore[attr-defined]
        assert_(decorators)
        line_instructions = [
            inst
            for inst in self.clean_instructions(self.code)
            if inst.lineno == self.frame.f_lineno
        ]
        # Index of the last CALL_FUNCTION on the line,
        # which must be directly followed by a STORE_* instruction
        last_decorator_instruction_index = [
            i
            for i, inst in enumerate(line_instructions)
            if inst.opname == "CALL_FUNCTION"
        ][-1]
        assert_(
            line_instructions[last_decorator_instruction_index + 1].opname.startswith(
                "STORE_"
            )
        )
        # One instruction per decorator, ending with that last CALL_FUNCTION
        decorator_instructions = line_instructions[
            last_decorator_instruction_index
            - len(decorators)
            + 1 : last_decorator_instruction_index
            + 1
        ]
        assert_({inst.opname for inst in decorator_instructions} == {"CALL_FUNCTION"})
        decorator_index = decorator_instructions.index(self.instruction)
        # The call instructions are matched against the decorators in reverse order
        decorator = decorators[::-1][decorator_index]
        self.decorator = decorator
        self.result = stmt

    def clean_instructions(self, code):
        # type: (types.CodeType) -> List[EnhancedInstruction]
        """
        Returns the instructions of `code`, leaving out EXTENDED_ARG and NOP
        instructions and instructions on lines in `self.ignore_linenos`.
        """
        return [
            inst
            for inst in get_instructions(code)
            if inst.opname not in ("EXTENDED_ARG", "NOP")
            if inst.lineno not in self.ignore_linenos
        ]

    def get_original_clean_instructions(self):
        # type: () -> List[EnhancedInstruction]
        """
        Returns the cleaned instructions of the frame's own code object,
        adjusted to be comparable to freshly compiled instructions.
        """
        result = self.clean_instructions(self.code)

        # pypy sometimes (when is not clear)
        # inserts JUMP_IF_NOT_DEBUG instructions in bytecode
        # If they're not present in our compiled instructions,
        # ignore them in the original bytecode
        if not any(
                inst.opname == "JUMP_IF_NOT_DEBUG"
                for inst in self.compile_instructions()
        ):
            result = [
                inst for inst in result
                if inst.opname != "JUMP_IF_NOT_DEBUG"
            ]

        return result

    def matching_nodes(self, exprs):
        # type: (Set[EnhancedAST]) -> Iterator[EnhancedAST]
        """
        Yields the nodes in `exprs` that compile to the current instruction.

        Each candidate is temporarily replaced in the tree by
        `<candidate> ** sentinel` and the tree is recompiled. The candidate is
        yielded if, once the sentinel instructions are removed again, the
        instruction preceding them has the same index as the current
        instruction has in the original instructions.
        """
        original_instructions = self.get_original_clean_instructions()
        original_index = only(
            i
            for i, inst in enumerate(original_instructions)
            if inst == self.instruction
        )
        for expr_index, expr in enumerate(exprs):
            setter = get_setter(expr)
            assert setter is not None
            # noinspection PyArgumentList
            replacement = ast.BinOp(
                left=expr,
                op=ast.Pow(),
                right=ast.Str(s=sentinel),
            )
            ast.fix_missing_locations(replacement)
            setter(replacement)
            try:
                instructions = self.compile_instructions()
            finally:
                # Always put the original node back, even if compiling failed
                setter(expr)

            if sys.version_info >= (3, 10):
                try:
                    handle_jumps(instructions, original_instructions)
                except Exception:
                    # Give other candidates a chance
                    if TESTING or expr_index < len(exprs) - 1:
                        continue
                    raise

            # Positions of the instructions loading the sentinel constant
            indices = [
                i
                for i, instruction in enumerate(instructions)
                if instruction.argval == sentinel
            ]

            # There can be several indices when the bytecode is duplicated,
            # as happens in a finally block in 3.9+
            # First we remove the opcodes caused by our modifications
            for index_num, sentinel_index in enumerate(indices):
                # Adjustment for removing sentinel instructions below
                # in past iterations
                sentinel_index -= index_num * 2

                # After the first pop, the BINARY_POWER has moved to the same index
                assert_(instructions.pop(sentinel_index).opname == 'LOAD_CONST')
                assert_(instructions.pop(sentinel_index).opname == 'BINARY_POWER')

            # Then we see if any of the instruction indices match
            for index_num, sentinel_index in enumerate(indices):
                sentinel_index -= index_num * 2
                # The instruction of the candidate itself directly precedes
                # the place where its sentinel instructions were
                new_index = sentinel_index - 1

                if new_index != original_index:
                    continue

                original_inst = original_instructions[original_index]
                new_inst = instructions[new_index]

                # In Python 3.9+, changing 'not x in y' to 'not sentinel_transformation(x in y)'
                # changes a CONTAINS_OP(invert=1) to CONTAINS_OP(invert=0),<sentinel stuff>,UNARY_NOT
                if (
                        original_inst.opname == new_inst.opname in ('CONTAINS_OP', 'IS_OP')
                        and original_inst.arg != new_inst.arg # type: ignore[attr-defined]
                        and (
                        original_instructions[original_index + 1].opname
                        != instructions[new_index + 1].opname == 'UNARY_NOT'
                )):
                    # Remove the difference for the upcoming assert
                    instructions.pop(new_index + 1)

                # Check that the modified instructions don't have anything unexpected
                # 3.10 is a bit too weird to assert this in all cases but things still work
                if sys.version_info < (3, 10):
                    for inst1, inst2 in zip_longest(
                        original_instructions, instructions
                    ):
                        assert_(inst1 and inst2 and opnames_match(inst1, inst2))

                yield expr

    def compile_instructions(self):
        # type: () -> List[EnhancedInstruction]
        """
        Compiles `self.tree` in its current state and returns the cleaned
        instructions of the one resulting code object that matches `self.code`.
        """
        module_code = compile_similar_to(self.tree, self.code)
        code = only(self.find_codes(module_code))
        return self.clean_instructions(code)

    def find_codes(self, root_code):
        # type: (types.CodeType) -> list
        """
        Returns the code objects among `root_code` and the code objects nested
        in its constants (recursively) which match `self.code` on every check.
        """
        checks = [
            attrgetter('co_firstlineno'),
            attrgetter('co_freevars'),
            attrgetter('co_cellvars'),
            # Names accepted by is_ipython_cell_code_name compare equal to each other
            lambda c: is_ipython_cell_code_name(c.co_name) or c.co_name,
        ] # type: List[Callable]
        # Names and variable names are not compared for code rewritten by pytest
        if not self.is_pytest:
            checks += [
                attrgetter('co_names'),
                attrgetter('co_varnames'),
            ]

        def matches(c):
            # type: (types.CodeType) -> bool
            return all(
                f(c) == f(self.code)
                for f in checks
            )

        code_options = []
        if matches(root_code):
            code_options.append(root_code)

        def finder(code):
            # type: (types.CodeType) -> None
            # Depth-first search through the code objects found in co_consts
            for const in code.co_consts:
                if not inspect.iscode(const):
                    continue

                if matches(const):
                    code_options.append(const)
                finder(const)

        finder(root_code)
        return code_options

    def get_actual_current_instruction(self, lasti):
        # type: (int) -> EnhancedInstruction
        """
        Get the instruction corresponding to the current
        frame offset, skipping EXTENDED_ARG instructions
        """
        # Don't use get_original_clean_instructions
        # because we need the actual instructions including
        # EXTENDED_ARG
        instructions = list(get_instructions(self.code))
        index = only(
            i
            for i, inst in enumerate(instructions)
            if inst.offset == lasti
        )

        # Move forward to the first instruction that isn't an EXTENDED_ARG
        while True:
            instruction = instructions[index]
            if instruction.opname != "EXTENDED_ARG":
                return instruction
            index += 1
796
797
798
def non_sentinel_instructions(instructions, start):
    # type: (List[EnhancedInstruction], int) -> Iterator[Tuple[int, EnhancedInstruction]]
    """
    Yields (index, instruction) pairs from `instructions`, beginning at
    index `start`, leaving out each instruction that loads the sentinel
    together with the BINARY_POWER instruction following it.
    """
    power_pending = False
    for position, inst in islice(enumerate(instructions), start, None):
        if inst.argval == sentinel:
            assert_(inst.opname == "LOAD_CONST")
            # The next instruction belongs to the sentinel transformation too
            power_pending = True
        elif power_pending:
            assert_(inst.opname == "BINARY_POWER")
            power_pending = False
        else:
            yield position, inst
816
817
def walk_both_instructions(original_instructions, original_start, instructions, start):
    # type: (List[EnhancedInstruction], int, List[EnhancedInstruction], int) -> Iterator[Tuple[int, EnhancedInstruction, int, EnhancedInstruction]]
    """
    Walks the original and the new instructions side by side, from
    `original_start` and `start` respectively, and yields
    (original index, original instruction, new index, new instruction).
    Instructions added by the sentinel transformation are left out.
    Stops as soon as either side runs out.
    """
    original_pairs = islice(enumerate(original_instructions), original_start, None)
    new_pairs = non_sentinel_instructions(instructions, start)
    previous_was_inverted = False
    for original_i, original_inst in original_pairs:
        try:
            new_i, new_inst = next(new_pairs)
        except StopIteration:
            return
        if (
            previous_was_inverted
            and new_inst.opname == "UNARY_NOT"
            and original_inst.opname != "UNARY_NOT"
        ):
            # Skip a UNARY_NOT which only the new instructions have
            # right after a comparison whose argument differed
            new_i, new_inst = next(new_pairs)
        previous_was_inverted = (
            new_inst.opname in ("CONTAINS_OP", "IS_OP")
            and original_inst.opname == new_inst.opname
            and original_inst.arg != new_inst.arg  # type: ignore[attr-defined]
        )
        yield original_i, original_inst, new_i, new_inst
843
844
def handle_jumps(instructions, original_instructions):
    # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> None
    """
    Transforms instructions in place until it looks more like original_instructions.
    This is only needed in 3.10+ where optimisations lead to more drastic changes
    after the sentinel transformation.
    Replaces JUMP instructions that aren't also present in original_instructions
    with the sections that they jump to until a raise or return.
    In some other cases duplication found in `original_instructions`
    is replicated in `instructions`.

    Raises AssertionError for a mismatch that can't be handled this way.
    """
    # Each pass fixes the first mismatch found and then starts over,
    # until a full walk finds no mismatch at all
    while True:
        for original_i, original_inst, new_i, new_inst in walk_both_instructions(
            original_instructions, 0, instructions, 0
        ):
            if opnames_match(original_inst, new_inst):
                continue

            if "JUMP" in new_inst.opname and "JUMP" not in original_inst.opname:
                # Find where the new instruction is jumping to, ignoring
                # instructions which have been copied in previous iterations
                start = only(
                    i
                    for i, inst in enumerate(instructions)
                    if inst.offset == new_inst.argval
                    and not getattr(inst, "_copied", False)
                )
                # Replace the jump instruction with the jumped to section of instructions
                # That section may also be deleted if it's not similarly duplicated
                # in original_instructions
                new_instructions = handle_jump(
                    original_instructions, original_i, instructions, start
                )
                assert new_instructions is not None
                instructions[new_i : new_i + 1] = new_instructions
            else:
                # Extract a section of original_instructions from original_i to return/raise
                orig_section = []
                for section_inst in original_instructions[original_i:]:
                    orig_section.append(section_inst)
                    if section_inst.opname in ("RETURN_VALUE", "RAISE_VARARGS"):
                        break
                else:
                    # No return/raise - this is just a mismatch we can't handle
                    raise AssertionError

                # Insert the single matching section of the new instructions
                # at the position of the mismatch
                instructions[new_i:new_i] = only(find_new_matching(orig_section, instructions))

            # instructions has been modified, the for loop can't sensibly continue
            # Restart it from the beginning, checking for other issues
            break

        else: # No mismatched jumps found, we're done
            return
899
900
def find_new_matching(orig_section, instructions):
    # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> Iterator[List[EnhancedInstruction]]
    """
    Yields the sections of `instructions` that match `orig_section`.
    Sentinel instructions are skipped while comparing,
    but they are part of the yielded sections.
    """
    section_length = len(orig_section)
    for start in range(len(instructions) - section_length):
        window = islice(non_sentinel_instructions(instructions, start), section_length)
        indices, dup_section = zip(*window)
        if len(dup_section) < section_length:
            # Too few instructions are left for any further match
            return
        if sections_match(orig_section, dup_section):
            last_index = indices[-1]
            yield instructions[start:last_index + 1]
919
920
def handle_jump(original_instructions, original_start, instructions, start):
    # type: (List[EnhancedInstruction], int, List[EnhancedInstruction], int) -> Optional[List[EnhancedInstruction]]
    """
    Returns a copy of the section of `instructions` that begins at `start`
    and ends with a RETURN_VALUE or RAISE_VARARGS instruction, with every
    copied instruction marked by a `_copied` attribute.
    The instructions of original_instructions from original_start on
    are expected to match that section.
    Unless the matching original section also appears somewhere else in
    original_instructions, the section is deleted from `instructions` as well.
    Returns None if no RETURN_VALUE or RAISE_VARARGS is reached.
    """
    section_ends = ("RETURN_VALUE", "RAISE_VARARGS")
    for original_j, original_inst, new_j, new_inst in walk_both_instructions(
        original_instructions, original_start, instructions, start
    ):
        assert_(opnames_match(original_inst, new_inst))
        if original_inst.opname not in section_ends:
            continue

        inlined = deepcopy(instructions[start : new_j + 1])
        for copied_inst in inlined:
            copied_inst._copied = True
        orig_section = original_instructions[original_start : original_j + 1]
        duplicated = check_duplicates(
            original_start, orig_section, original_instructions
        )
        if not duplicated:
            del instructions[start : new_j + 1]
        return inlined

    return None
946
947
def check_duplicates(original_i, orig_section, original_instructions):
    # type: (int, List[EnhancedInstruction], List[EnhancedInstruction]) -> bool
    """
    Returns True if orig_section is duplicated, i.e. a section of
    original_instructions which matches orig_section starts at
    an index other than original_i.
    """
    section_length = len(orig_section)
    for dup_start in range(len(original_instructions)):
        if dup_start == original_i:
            continue
        candidate = original_instructions[dup_start : dup_start + section_length]
        if len(candidate) < section_length:
            # Sections starting here or later are too short to match
            return False
        if sections_match(orig_section, candidate):
            return True

    return False
964
def sections_match(orig_section, dup_section):
    # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> bool
    """
    Returns True if the instructions of both sections, compared pairwise,
    have matching linenos and opnames.
    """
    for orig_inst, dup_inst in zip(orig_section, dup_section):
        same_line = orig_inst.lineno == dup_inst.lineno
        # POP_BLOCKs have been found to have differing linenos in innocent cases
        if not same_line and not ("POP_BLOCK" == orig_inst.opname == dup_inst.opname):
            return False
        if not opnames_match(orig_inst, dup_inst):
            return False
    return True
979
980
def opnames_match(inst1, inst2):
    # type: (Instruction, Instruction) -> bool
    """
    Decide whether two instructions have equivalent opnames.

    The opnames are equivalent when they are equal, when both contain
    "JUMP", or when (inst1.opname, inst2.opname) is one of the tolerated
    substitutions listed below. The substitutions are one-directional:
    swapping the two arguments does not match.
    """
    first, second = inst1.opname, inst2.opname
    if first == second:
        return True
    if "JUMP" in first and "JUMP" in second:
        return True
    tolerated = (
        ("PRINT_EXPR", "POP_TOP"),
        ("LOAD_METHOD", "LOAD_ATTR"),
        ("LOOKUP_METHOD", "LOAD_ATTR"),
        ("CALL_METHOD", "CALL_FUNCTION"),
    )
    return (first, second) in tolerated
994
995
def get_setter(node):
    # type: (EnhancedAST) -> Optional[Callable[[ast.AST], None]]
    """
    Build a function that replaces node inside its parent.

    The fields of node.parent are scanned in order. If a field is node
    itself, the returned function assigns its argument to that attribute of
    the parent. If a field is a list containing node, the returned function
    overwrites the corresponding list slot. Nodes are located by identity.
    Returns None when node is not found among the parent's fields.
    """
    owner = node.parent
    for attr_name, value in ast.iter_fields(owner):
        if value is node:
            # node is stored directly as an attribute of its parent.
            def setter(new_node):
                # type: (ast.AST) -> None
                return setattr(owner, attr_name, new_node)

            return setter

        if not isinstance(value, list):
            continue

        # node may be one element of a list-valued field (e.g. a body).
        for position, element in enumerate(value):
            if element is node:
                def setter(new_node):
                    # type: (ast.AST) -> None
                    value[position] = new_node

                return setter

    return None
1014
# Module-level re-entrant lock (threading.RLock).
# NOTE(review): no user of this lock is visible in this part of the file --
# confirm what shared state it protects before changing it.
lock = RLock()
1016
1017
@cache
def statement_containing_node(node):
    # type: (ast.AST) -> EnhancedAST
    """
    Return the nearest ast.stmt at or above node.

    If node is already a statement it is returned unchanged; otherwise the
    `parent` links are followed upwards until a statement is reached.
    The function is wrapped by the `cache` decorator defined elsewhere in
    this file.
    """
    current = node
    while True:
        if isinstance(current, ast.stmt):
            return cast(EnhancedAST, current)
        current = cast(EnhancedAST, current).parent
1024
1025
def assert_linenos(tree):
    # type: (ast.AST) -> Iterator[int]
    """
    Yield the line numbers of nodes in tree that sit inside `assert` statements.

    Each node reached by ast.walk is considered. Nodes without a `parent`
    attribute are skipped; for the others, the line numbers given by
    node_linenos are yielded when the statement containing the node is an
    ast.Assert. The same line number may be yielded more than once.
    """
    for node in ast.walk(tree):
        if not hasattr(node, 'parent'):
            continue
        if isinstance(statement_containing_node(node), ast.Assert):
            yield from node_linenos(node)
1035
1036
def _extract_ipython_statement(stmt):
    # type: (EnhancedAST) -> ast.Module
    """
    Wrap the top-level statement containing stmt in a fresh ast.Module.

    The `parent` links are followed from stmt until a node whose parent is
    an ast.Module is reached; the returned module's body holds exactly that
    node (the same object, not a copy).
    """
    # IPython separates each statement in a cell to be executed separately
    # So NodeFinder should only compile one statement at a time or it
    # will find a code mismatch.
    top_level = stmt
    while not isinstance(top_level.parent, ast.Module):
        top_level = top_level.parent
    # use `ast.parse` instead of `ast.Module` for better portability
    # python3.8 changes the signature of `ast.Module`
    # Inspired by https://github.com/pallets/werkzeug/pull/1552/files
    module = ast.parse("")
    module.body = [cast(ast.stmt, top_level)]
    ast.copy_location(module, top_level)
    return module
1051
1052
def is_ipython_cell_code_name(code_name):
    # type: (str) -> bool
    """
    Return True if code_name is `<module>` or has the form `<cell line: N>`.

    The whole name must match; N is one or more digits.
    """
    matched = re.match(r"(<module>|<cell line: \d+>)$", code_name)
    return matched is not None
1056
1057
def is_ipython_cell_filename(filename):
    # type: (str) -> bool
    """
    Return True if filename looks like the filename of an IPython cell.

    That is the case when filename contains `<ipython-input-` anywhere, or a
    path component `ipykernel_N` (N being digits) with a `/` or `\\`
    separator on both sides.
    """
    found = re.search(r"<ipython-input-|[/\\]ipykernel_\d+[/\\]", filename)
    return found is not None
1061
1062
def is_ipython_cell_code(code_obj):
    # type: (types.CodeType) -> bool
    """
    Return True if code_obj is the code of an IPython cell.

    Both the code object's co_filename (checked first, with
    is_ipython_cell_filename) and its co_name (checked with
    is_ipython_cell_code_name) have to qualify.
    """
    if not is_ipython_cell_filename(code_obj.co_filename):
        return False
    return is_ipython_cell_code_name(code_obj.co_name)
1069
1070
def find_node_ipython(frame, lasti, stmts, source):
    # type: (types.FrameType, int, Set[EnhancedAST], Source) -> Tuple[Optional[Any], Optional[Any]]
    """
    Look for the executing node among candidate statements of an IPython cell.

    Each statement in stmts is wrapped in its own module by
    _extract_ipython_statement and handed to NodeFinder. Any Exception
    raised while running NodeFinder on a statement is swallowed and that
    statement is skipped.

    Returns a (decorator, node) tuple taken from the NodeFinder that
    succeeded. Returns (None, None) if two different statements both
    produce a result, because the ambiguity cannot be resolved, and also
    when stmts is empty or nothing is found.
    """
    found_node = found_decorator = None
    for candidate in stmts:
        module = _extract_ipython_statement(candidate)
        try:
            finder = NodeFinder(frame, stmts, module, lasti, source)
            already_found = found_node or found_decorator
            if already_found and (finder.result or finder.decorator):
                # Found potential nodes in separate statements,
                # cannot resolve ambiguity, give up here
                return None, None

            found_node = finder.result
            found_decorator = finder.decorator
        except Exception:
            # Best effort: a statement NodeFinder cannot handle is ignored.
            pass
    return found_decorator, found_node
1088
1089
1090
def node_linenos(node):
    # type: (ast.AST) -> Iterator[int]
    """
    Yield the line numbers associated with node.

    Nothing is yielded for a node without a `lineno` attribute. An
    expression node that has `end_lineno` yields every line from `lineno`
    through `end_lineno` inclusive; any other node yields only its `lineno`.
    """
    if not hasattr(node, "lineno"):
        return

    if isinstance(node, ast.expr) and hasattr(node, "end_lineno"):
        assert node.end_lineno is not None  # type: ignore[attr-defined]
        yield from range(node.lineno, node.end_lineno + 1)  # type: ignore[attr-defined]
    else:
        yield node.lineno  # type: ignore[attr-defined]
1102
1103
# Select the NodeFinder implementation for the running interpreter:
# on Python 3.11 and newer it is PositionNodeFinder, imported from
# ._position_node_finder; on older versions it is SentinelNodeFinder.
if sys.version_info >= (3, 11):
    from ._position_node_finder import PositionNodeFinder as NodeFinder
else:
    NodeFinder = SentinelNodeFinder
1108