1"""
2MIT License
3
4Copyright (c) 2021 Alex Hall
5
6Permission is hereby granted, free of charge, to any person obtaining a copy
7of this software and associated documentation files (the "Software"), to deal
8in the Software without restriction, including without limitation the rights
9to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10copies of the Software, and to permit persons to whom the Software is
11furnished to do so, subject to the following conditions:
12
13The above copyright notice and this permission notice shall be included in all
14copies or substantial portions of the Software.
15
16THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22SOFTWARE.
23"""
24
25import __future__
26import ast
27import dis
28import inspect
29import io
30import linecache
31import re
32import sys
33import types
34from collections import defaultdict
35from copy import deepcopy
36from functools import lru_cache
37from itertools import islice
38from itertools import zip_longest
39from operator import attrgetter
40from pathlib import Path
41from threading import RLock
42from tokenize import detect_encoding
43from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Sized, Tuple, \
44 Type, TypeVar, Union, cast
45
46if TYPE_CHECKING: # pragma: no cover
47 from asttokens import ASTTokens, ASTText
48 from asttokens.asttokens import ASTTextBase
49
50
# AST node classes that define a (sync or async) function.
function_node_types = (ast.FunctionDef, ast.AsyncFunctionDef) # type: Tuple[Type, ...]

# Memoizing decorator with an unbounded cache (lru_cache without a size limit).
cache = lru_cache(maxsize=None)
54
# Type class used to expand out the definition of AST to include fields added by this library
# It's not actually used for anything other than type checking though!
class EnhancedAST(ast.AST):
    # Parent node in the tree; Source.__init__ assigns it on every child node it walks.
    parent = None # type: EnhancedAST
59
60
# Type class describing a dis.Instruction that also carries a `lineno` attribute.
class Instruction(dis.Instruction):
    # Line number assigned to the instruction by get_instructions().
    lineno = None # type: int
63
64
# Type class used to expand out the definition of Instruction to include fields added by this library
# It's not actually used for anything other than type checking though!
class EnhancedInstruction(Instruction):
    # Set to True by handle_jump() on the instruction copies it inlines;
    # handle_jumps() skips such copies when resolving jump targets.
    _copied = None # type: bool
69
70
71
def assert_(condition, message=""):
    # type: (Any, str) -> None
    """
    Raise AssertionError unless `condition` is truthy.

    Unlike the `assert` statement, this check is not removed by -O.

    :param condition: value that is expected to be truthy
    :param message: converted with str() and used as the exception message
    :raises AssertionError: if `condition` is falsy
    """
    if condition:
        return
    raise AssertionError(str(message))
81
82
def get_instructions(co):
    # type: (types.CodeType) -> Iterator[EnhancedInstruction]
    """
    Yield the instructions of `co` as produced by dis.get_instructions,
    each with an extra `lineno` attribute.

    `lineno` is the instruction's own `starts_line` when that is truthy,
    otherwise the value carried over from the previous instruction
    (initially `co.co_firstlineno`).
    """
    current_line = co.co_firstlineno
    for raw_inst in dis.get_instructions(co):
        enhanced = cast(EnhancedInstruction, raw_inst)
        if enhanced.starts_line:
            current_line = enhanced.starts_line
        assert_(current_line)
        enhanced.lineno = current_line
        yield enhanced
92
93
# Debug switch: when truthy, Source.executing() re-raises exceptions from node lookup
# instead of swallowing them. SentinelNodeFinder.matching_nodes() also consults it.
TESTING = 0
95
96
class NotOneValueFound(Exception):
    """
    Raised when exactly one value was expected but zero or several were found.

    Attributes:
        - values: the values that were found, when the raiser supplied them;
          otherwise a fresh empty list
    """

    def __init__(self, msg, values=None):
        # type: (str, Optional[Sequence]) -> None
        """
        :param msg: exception message
        :param values: optional sequence of the values that were found
        """
        # A literal `[]` default would be a single list shared (and mutable)
        # across every instance created without `values`.
        self.values = [] if values is None else values
        super(NotOneValueFound, self).__init__(msg)
102
# Generic type variable used in the type comments of only() and Source._class_local().
T = TypeVar('T')
104
105
def only(it):
    # type: (Iterable[T]) -> T
    """
    Return the single value contained in `it`.

    Sized inputs are checked with len(); other iterables are consumed
    for at most two items.

    :raises NotOneValueFound: if `it` does not contain exactly one value
    """
    if isinstance(it, Sized):
        count = len(it)
        if count == 1:
            # noinspection PyTypeChecker
            return list(it)[0]
        raise NotOneValueFound('Expected one value, found %s' % count)

    first_two = tuple(islice(it, 2))
    if not first_two:
        raise NotOneValueFound('Expected one value, found 0')
    if len(first_two) > 1:
        raise NotOneValueFound('Expected one value, found several',first_two)
    return first_two[0]
120
121
class Source(object):
    """
    The source code of a single file and associated metadata.

    The main method of interest is the classmethod `executing(frame)`.

    If you want an instance of this class, don't construct it.
    Ideally use the classmethod `for_frame(frame)`.
    If you don't have a frame, use `for_filename(filename [, module_globals])`.
    These methods cache instances by filename together with the file's lines,
    so at most one instance exists per (filename, lines) combination.

    Attributes:
        - filename
        - text
        - lines
        - tree: AST parsed from text, or None if text is not valid Python
            All nodes in the tree have an extra `parent` attribute

    Other methods of interest:
        - statements_at_line
        - asttokens
        - code_qualname
    """

    def __init__(self, filename, lines):
        # type: (str, Sequence[str]) -> None
        """
        Don't call this constructor, see the class docstring.
        """

        self.filename = filename
        self.text = ''.join(lines)
        self.lines = [line.rstrip('\r\n') for line in lines]

        # Maps a line number to the AST nodes that node_linenos() reports for it
        self._nodes_by_line = defaultdict(list)
        self.tree = None
        # Maps (code name, first line number) to a qualified name, see code_qualname()
        self._qualnames = {}
        # Lazily created by asttokens() / asttext()
        self._asttokens = None # type: Optional[ASTTokens]
        self._asttext = None # type: Optional[ASTText]

        try:
            self.tree = ast.parse(self.text, filename=filename)
        except (SyntaxError, ValueError):
            # Text that can't be parsed leaves `tree` as None and the lookups empty
            pass
        else:
            for node in ast.walk(self.tree):
                # Give every node a link back to its parent
                for child in ast.iter_child_nodes(node):
                    cast(EnhancedAST, child).parent = cast(EnhancedAST, node)
                for lineno in node_linenos(node):
                    self._nodes_by_line[lineno].append(node)

            visitor = QualnameVisitor()
            visitor.visit(self.tree)
            self._qualnames = visitor.qualnames

    @classmethod
    def for_frame(cls, frame, use_cache=True):
        # type: (types.FrameType, bool) -> "Source"
        """
        Returns the `Source` object corresponding to the file the frame is executing in.
        """
        return cls.for_filename(frame.f_code.co_filename, frame.f_globals or {}, use_cache)

    @classmethod
    def for_filename(
        cls,
        filename,
        module_globals=None,
        use_cache=True,  # noqa no longer used
    ):
        # type: (Union[str, Path], Optional[Dict[str, Any]], bool) -> "Source"
        """
        Returns the `Source` object for `filename`, with lines obtained from `linecache`.

        :param filename: path of the file, as a string or `Path`
        :param module_globals: passed on to `linecache.getlines`
        :param use_cache: no longer used; still accepted so existing callers keep working
        """
        if isinstance(filename, Path):
            filename = str(filename)

        def get_lines():
            # type: () -> List[str]
            return linecache.getlines(cast(str, filename), module_globals)

        # Save the current linecache entry, then ensure the cache is up to date.
        entry = linecache.cache.get(filename) # type: ignore[attr-defined]
        linecache.checkcache(filename)
        lines = get_lines()
        if entry is not None and not lines:
            # There was an entry, checkcache removed it, and nothing replaced it.
            # This means the file wasn't simply changed (because the `lines` wouldn't be empty)
            # but rather the file was found not to exist, probably because `filename` was fake.
            # Restore the original entry so that we still have something.
            linecache.cache[filename] = entry # type: ignore[attr-defined]
            lines = get_lines()

        # A tuple is hashable, so the lines can be part of the cache key
        return cls._for_filename_and_lines(filename, tuple(lines))

    @classmethod
    def _for_filename_and_lines(cls, filename, lines):
        # type: (str, Sequence[str]) -> "Source"
        """
        Returns the cached instance for `(filename, lines)`, creating and caching it if needed.
        """
        source_cache = cls._class_local('__source_cache_with_lines', {}) # type: Dict[Tuple[str, Sequence[str]], Source]
        try:
            return source_cache[(filename, lines)]
        except KeyError:
            pass

        result = source_cache[(filename, lines)] = cls(filename, lines)
        return result

    @classmethod
    def lazycache(cls, frame):
        # type: (types.FrameType) -> None
        """
        Calls `linecache.lazycache` with the filename and globals of `frame`.
        """
        linecache.lazycache(frame.f_code.co_filename, frame.f_globals)

    @classmethod
    def executing(cls, frame_or_tb):
        # type: (Union[types.TracebackType, types.FrameType]) -> "Executing"
        """
        Returns an `Executing` object representing the operation
        currently executing in the given frame or traceback object.
        """
        if isinstance(frame_or_tb, types.TracebackType):
            # https://docs.python.org/3/reference/datamodel.html#traceback-objects
            # "tb_lineno gives the line number where the exception occurred;
            #  tb_lasti indicates the precise instruction.
            #  The line number and last instruction in the traceback may differ
            #  from the line number of its frame object
            #  if the exception occurred in a try statement with no matching except clause
            #  or with a finally clause."
            tb = frame_or_tb
            frame = tb.tb_frame
            lineno = tb.tb_lineno
            lasti = tb.tb_lasti
        else:
            frame = frame_or_tb
            lineno = frame.f_lineno
            lasti = frame.f_lasti



        code = frame.f_code
        # Results are cached per code object and instruction offset
        key = (code, id(code), lasti)
        executing_cache = cls._class_local('__executing_cache', {}) # type: Dict[Tuple[types.CodeType, int, int], Any]

        args = executing_cache.get(key)
        if not args:
            node = stmts = decorator = None
            source = cls.for_frame(frame)
            tree = source.tree
            if tree:
                try:
                    stmts = source.statements_at_line(lineno)
                    if stmts:
                        if is_ipython_cell_code(code):
                            decorator, node = find_node_ipython(frame, lasti, stmts, source)
                        else:
                            node_finder = NodeFinder(frame, stmts, tree, lasti, source)
                            node = node_finder.result
                            decorator = node_finder.decorator

                    if node:
                        # Narrow the statements down to the one containing the node
                        new_stmts = {statement_containing_node(node)}
                        assert_(new_stmts <= stmts)
                        stmts = new_stmts
                except Exception:
                    # Best effort: a failed lookup leaves `node` as None
                    # unless the TESTING switch asks for the error
                    if TESTING:
                        raise

            executing_cache[key] = args = source, node, stmts, decorator

        return Executing(frame, *args)

    @classmethod
    def _class_local(cls, name, default):
        # type: (str, T) -> T
        """
        Returns an attribute directly associated with this class
        (as opposed to subclasses), setting default if necessary
        """
        # classes have a mappingproxy preventing us from using setdefault
        result = cls.__dict__.get(name, default)
        setattr(cls, name, result)
        return result

    @cache
    def statements_at_line(self, lineno):
        # type: (int) -> Set[EnhancedAST]
        """
        Returns the statement nodes overlapping the given line.

        Returns at most one statement unless semicolons are present.

        If the `text` attribute is not valid python, meaning
        `tree` is None, returns an empty set.

        Otherwise, `Source.for_frame(frame).statements_at_line(frame.f_lineno)`
        should return at least one statement.
        """

        return {
            statement_containing_node(node)
            for node in
            self._nodes_by_line[lineno]
        }

    def asttext(self):
        # type: () -> ASTText
        """
        Returns an ASTText object for getting the source of specific AST nodes.

        The object is created on first use and then reused.

        See http://asttokens.readthedocs.io/en/latest/api-index.html
        """
        from asttokens import ASTText  # must be installed separately

        if self._asttext is None:
            self._asttext = ASTText(self.text, tree=self.tree, filename=self.filename)

        return self._asttext

    def asttokens(self):
        # type: () -> ASTTokens
        """
        Returns an ASTTokens object for getting the source of specific AST nodes.

        The object is created on first use and then reused.

        See http://asttokens.readthedocs.io/en/latest/api-index.html
        """
        import asttokens  # must be installed separately

        if self._asttokens is None:
            # Reuse the ASTText object when the installed asttokens provides that class
            if hasattr(asttokens, 'ASTText'):
                self._asttokens = self.asttext().asttokens
            else:  # pragma: no cover
                self._asttokens = asttokens.ASTTokens(self.text, tree=self.tree, filename=self.filename)
        return self._asttokens

    def _asttext_base(self):
        # type: () -> ASTTextBase
        """
        Returns `asttext()` when the installed asttokens provides `ASTText`,
        otherwise falls back to `asttokens()`.
        """
        import asttokens  # must be installed separately

        if hasattr(asttokens, 'ASTText'):
            return self.asttext()
        else:  # pragma: no cover
            return self.asttokens()

    @staticmethod
    def decode_source(source):
        # type: (Union[str, bytes]) -> str
        """
        Returns `source` as a string, decoding bytes with the encoding
        found by `detect_encoding`. Strings are returned unchanged.
        """
        if isinstance(source, bytes):
            encoding = Source.detect_encoding(source)
            return source.decode(encoding)
        else:
            return source

    @staticmethod
    def detect_encoding(source):
        # type: (bytes) -> str
        """
        Returns the encoding that `tokenize.detect_encoding` reports for the given bytes.
        """
        return detect_encoding(io.BytesIO(source).readline)[0]

    def code_qualname(self, code):
        # type: (types.CodeType) -> str
        """
        Imitates the __qualname__ attribute of functions for code objects.
        Given:

            - A function `func`
            - A frame `frame` for an execution of `func`, meaning:
                `frame.f_code is func.__code__`

        `Source.for_frame(frame).code_qualname(frame.f_code)`
        will be equal to `func.__qualname__`*. Works for Python 2 as well,
        where of course no `__qualname__` attribute exists.

        Falls back to `code.co_name` if there is no appropriate qualname.

        Based on https://github.com/wbolster/qualname

        (* unless `func` is a lambda
        nested inside another lambda on the same line, in which case
        the outer lambda's qualname will be returned for the codes
        of both lambdas)
        """
        # The code object must belong to this file
        assert_(code.co_filename == self.filename)
        return self._qualnames.get((code.co_name, code.co_firstlineno), code.co_name)
400
401
class Executing(object):
    """
    Information about the operation a frame is currently executing.

    Generally you will just want `node`, which is the AST node being executed,
    or None if it's unknown.

    If a decorator is currently being called, then:
        - `node` is a function or class definition
        - `decorator` is the expression in `node.decorator_list` being called
        - `statements == {node}`
    """

    def __init__(self, frame, source, node, stmts, decorator):
        # type: (types.FrameType, Source, EnhancedAST, Set[ast.stmt], Optional[EnhancedAST]) -> None
        """
        Stores the given values as the attributes
        `frame`, `source`, `node`, `statements` and `decorator`.
        """
        self.frame = frame
        self.source = source
        self.node = node
        self.statements = stmts
        self.decorator = decorator

    def code_qualname(self):
        # type: () -> str
        """
        Returns `source.code_qualname(...)` for the code object of `frame`.
        """
        return self.source.code_qualname(self.frame.f_code)

    def text(self):
        # type: () -> str
        """
        Returns the result of `get_text(node)` from the source's asttokens helper
        (requires asttokens, which must be installed separately).
        """
        return self.source._asttext_base().get_text(self.node)

    def text_range(self):
        # type: () -> Tuple[int, int]
        """
        Returns the result of `get_text_range(node)` from the source's asttokens helper
        (requires asttokens, which must be installed separately).
        """
        return self.source._asttext_base().get_text_range(self.node)
434
435
class QualnameVisitor(ast.NodeVisitor):
    """
    Walks a tree and records a dotted qualified name for every
    class, function and lambda definition.

    Attributes:
        - stack: name components of the scopes currently being visited
        - qualnames: maps (name, line number) to the qualified name,
            where the line number is that of the first decorator if there is one
    """

    def __init__(self):
        # type: () -> None
        super(QualnameVisitor, self).__init__()
        self.stack = []  # type: List[str]
        self.qualnames = {}  # type: Dict[Tuple[str, int], str]

    def add_qualname(self, node, name=None):
        # type: (ast.AST, Optional[str]) -> None
        """
        Pushes the name onto the stack and records the resulting qualified name,
        keeping an already recorded name for the same (name, line number).
        """
        if not name:
            name = node.name  # type: ignore[attr-defined]
        self.stack.append(name)

        # A decorated definition is keyed by the line of its first decorator
        decorators = getattr(node, 'decorator_list', ())
        if decorators:
            lineno = decorators[0].lineno
        else:
            lineno = node.lineno  # type: ignore[attr-defined]

        key = (name, lineno)
        qualified = ".".join(self.stack)
        if key not in self.qualnames:
            self.qualnames[key] = qualified

    def visit_FunctionDef(self, node, name=None):
        # type: (ast.AST, Optional[str]) -> None
        assert isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)), node
        self.add_qualname(node, name)
        self.stack.append('<locals>')

        # A lambda body is a single expression, a function body is a list of statements
        body = [node.body] if isinstance(node, ast.Lambda) else node.body
        for statement in body:
            self.visit(statement)

        # Drop '<locals>' and the name pushed by add_qualname
        del self.stack[-2:]

        # Find lambdas in the function definition outside the body,
        # e.g. decorators or default arguments
        # Based on iter_child_nodes
        for field_name, value in ast.iter_fields(node):
            if field_name == 'body':
                continue
            candidates = value if isinstance(value, list) else [value]
            for candidate in candidates:
                if isinstance(candidate, ast.AST):
                    self.visit(candidate)

    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Lambda(self, node):
        # type: (ast.AST) -> None
        assert isinstance(node, ast.Lambda)
        self.visit_FunctionDef(node, '<lambda>')

    def visit_ClassDef(self, node):
        # type: (ast.AST) -> None
        assert isinstance(node, ast.ClassDef)
        self.add_qualname(node)
        self.generic_visit(node)
        self.stack.pop()
494
495
496
497
498
# Sum of the compiler flags of every feature named in __future__.all_feature_names.
# compile_similar_to() uses it as a bitmask over a code object's co_flags.
future_flags = sum(
    getattr(__future__, fname).compiler_flag for fname in __future__.all_feature_names
)
502
503
def compile_similar_to(source, matching_code):
    # type: (ast.Module, types.CodeType) -> Any
    """
    Compile `source` in 'exec' mode under the filename of `matching_code`,
    enabling only the __future__ compiler flags that are set in
    `matching_code.co_flags` and inheriting none from the calling code.
    """
    shared_future_flags = matching_code.co_flags & future_flags
    return compile(
        source,
        matching_code.co_filename,
        'exec',
        flags=shared_future_flags,
        dont_inherit=True,
    )
513
514
# Marker string inserted into recompiled code (as the right operand of `**`),
# so that the instructions generated for a candidate node can be found by their argval.
sentinel = 'io8urthglkjdghvljusketgIYRFYUVGHFRTBGVHKGF78678957647698'
516
def is_rewritten_by_pytest(code):
    # type: (types.CodeType) -> bool
    """
    Returns True if any instruction of `code` other than LOAD_CONST
    has a string argument value starting with "@py".
    """
    for inst in get_instructions(code):
        if inst.opname == "LOAD_CONST":
            continue
        argval = inst.argval
        if isinstance(argval, str) and argval.startswith("@py"):
            return True
    return False
523
524
class SentinelNodeFinder(object):
    """
    Finds the AST node that corresponds to the bytecode instruction at `lasti`.

    Candidate nodes are chosen by the kind of instruction. For most kinds, each
    candidate is temporarily replaced by `<candidate> ** sentinel`, the tree is
    recompiled, and the candidate is accepted if the sentinel instructions appear
    directly after the instruction at the same index as the current one.

    Attributes set by the constructor:
        - result: the node that was found
        - decorator: the decorator expression being called, or None
    """

    result = None # type: EnhancedAST

    def __init__(self, frame, stmts, tree, lasti, source):
        # type: (types.FrameType, Set[EnhancedAST], ast.Module, int, Source) -> None
        assert_(stmts)
        self.frame = frame
        self.tree = tree
        self.code = code = frame.f_code
        self.is_pytest = is_rewritten_by_pytest(code)

        # Instructions on the lines given by assert_linenos() are ignored for such code
        if self.is_pytest:
            self.ignore_linenos = frozenset(assert_linenos(tree))
        else:
            self.ignore_linenos = frozenset()

        self.decorator = None

        self.instruction = instruction = self.get_actual_current_instruction(lasti)
        op_name = instruction.opname
        # Defaults accept any node; `ctx` of NoneType matches nodes without a `ctx` attribute
        extra_filter = lambda e: True # type: Callable[[Any], bool]
        ctx = type(None) # type: Type

        # Choose the node type, context and extra condition matching the instruction
        typ = type(None) # type: Type
        if op_name.startswith('CALL_'):
            typ = ast.Call
        elif op_name.startswith(('BINARY_SUBSCR', 'SLICE+')):
            typ = ast.Subscript
            ctx = ast.Load
        elif op_name.startswith('BINARY_'):
            typ = ast.BinOp
            op_type = dict(
                BINARY_POWER=ast.Pow,
                BINARY_MULTIPLY=ast.Mult,
                BINARY_MATRIX_MULTIPLY=getattr(ast, "MatMult", ()),
                BINARY_FLOOR_DIVIDE=ast.FloorDiv,
                BINARY_TRUE_DIVIDE=ast.Div,
                BINARY_MODULO=ast.Mod,
                BINARY_ADD=ast.Add,
                BINARY_SUBTRACT=ast.Sub,
                BINARY_LSHIFT=ast.LShift,
                BINARY_RSHIFT=ast.RShift,
                BINARY_AND=ast.BitAnd,
                BINARY_XOR=ast.BitXor,
                BINARY_OR=ast.BitOr,
            )[op_name]
            extra_filter = lambda e: isinstance(e.op, op_type)
        elif op_name.startswith('UNARY_'):
            typ = ast.UnaryOp
            op_type = dict(
                UNARY_POSITIVE=ast.UAdd,
                UNARY_NEGATIVE=ast.USub,
                UNARY_NOT=ast.Not,
                UNARY_INVERT=ast.Invert,
            )[op_name]
            extra_filter = lambda e: isinstance(e.op, op_type)
        elif op_name in ('LOAD_ATTR', 'LOAD_METHOD', 'LOOKUP_METHOD'):
            typ = ast.Attribute
            ctx = ast.Load
            extra_filter = lambda e: attr_names_match(e.attr, instruction.argval)
        elif op_name in ('LOAD_NAME', 'LOAD_GLOBAL', 'LOAD_FAST', 'LOAD_DEREF', 'LOAD_CLASSDEREF'):
            typ = ast.Name
            ctx = ast.Load
            extra_filter = lambda e: e.id == instruction.argval
        elif op_name in ('COMPARE_OP', 'IS_OP', 'CONTAINS_OP'):
            typ = ast.Compare
            extra_filter = lambda e: len(e.ops) == 1
        elif op_name.startswith(('STORE_SLICE', 'STORE_SUBSCR')):
            ctx = ast.Store
            typ = ast.Subscript
        elif op_name.startswith('STORE_ATTR'):
            ctx = ast.Store
            typ = ast.Attribute
            extra_filter = lambda e: attr_names_match(e.attr, instruction.argval)
        else:
            raise RuntimeError(op_name)

        # matching_nodes() temporarily modifies the shared tree, hence the lock
        with lock:
            exprs = {
                cast(EnhancedAST, node)
                for stmt in stmts
                for node in ast.walk(stmt)
                if isinstance(node, typ)
                if isinstance(getattr(node, "ctx", None), ctx)
                if extra_filter(node)
                if statement_containing_node(node) == stmt
            }

            if ctx == ast.Store:
                # No special bytecode tricks here.
                # We can handle multiple assigned attributes with different names,
                # but only one assigned subscript.
                self.result = only(exprs)
                return

            matching = list(self.matching_nodes(exprs))
            if not matching and typ == ast.Call:
                # A call without a matching Call node is treated as a decorator call
                self.find_decorator(stmts)
            else:
                self.result = only(matching)

    def find_decorator(self, stmts):
        # type: (Union[List[EnhancedAST], Set[EnhancedAST]]) -> None
        """
        Handles a current instruction that is the call of a decorator.

        Sets `self.decorator` to the decorator expression being called and
        `self.result` to the decorated function or class definition.
        Raises if the statement or the instructions don't have the expected shape.
        """
        stmt = only(stmts)
        assert_(isinstance(stmt, (ast.ClassDef, function_node_types)))
        decorators = stmt.decorator_list # type: ignore[attr-defined]
        assert_(decorators)
        line_instructions = [
            inst
            for inst in self.clean_instructions(self.code)
            if inst.lineno == self.frame.f_lineno
        ]
        # The last CALL_FUNCTION on the line must be followed by the store of the definition
        last_decorator_instruction_index = [
            i
            for i, inst in enumerate(line_instructions)
            if inst.opname == "CALL_FUNCTION"
        ][-1]
        assert_(
            line_instructions[last_decorator_instruction_index + 1].opname.startswith(
                "STORE_"
            )
        )
        # One call per decorator, ending at the last CALL_FUNCTION
        decorator_instructions = line_instructions[
            last_decorator_instruction_index
            - len(decorators)
            + 1 : last_decorator_instruction_index
            + 1
        ]
        assert_({inst.opname for inst in decorator_instructions} == {"CALL_FUNCTION"})
        decorator_index = decorator_instructions.index(self.instruction)
        # The calls are indexed in the reverse order of `decorator_list`
        decorator = decorators[::-1][decorator_index]
        self.decorator = decorator
        self.result = stmt

    def clean_instructions(self, code):
        # type: (types.CodeType) -> List[EnhancedInstruction]
        """
        Returns the instructions of `code` without EXTENDED_ARG and NOP instructions
        and without instructions on ignored lines.
        """
        return [
            inst
            for inst in get_instructions(code)
            if inst.opname not in ("EXTENDED_ARG", "NOP")
            if inst.lineno not in self.ignore_linenos
        ]

    def get_original_clean_instructions(self):
        # type: () -> List[EnhancedInstruction]
        """
        Returns the cleaned instructions of the frame's own code object.
        """
        result = self.clean_instructions(self.code)

        # pypy sometimes (when is not clear)
        # inserts JUMP_IF_NOT_DEBUG instructions in bytecode
        # If they're not present in our compiled instructions,
        # ignore them in the original bytecode
        if not any(
                inst.opname == "JUMP_IF_NOT_DEBUG"
                for inst in self.compile_instructions()
        ):
            result = [
                inst for inst in result
                if inst.opname != "JUMP_IF_NOT_DEBUG"
            ]

        return result

    def matching_nodes(self, exprs):
        # type: (Set[EnhancedAST]) -> Iterator[EnhancedAST]
        """
        Yields the candidates in `exprs` whose instruction has the same index
        as the current instruction once the tree is recompiled with the sentinel.
        """
        original_instructions = self.get_original_clean_instructions()
        original_index = only(
            i
            for i, inst in enumerate(original_instructions)
            if inst == self.instruction
        )
        for expr_index, expr in enumerate(exprs):
            setter = get_setter(expr)
            assert setter is not None
            # Temporarily replace the candidate by `<candidate> ** sentinel`
            # NOTE(review): ast.Str is a legacy node class, deprecated in newer
            # Python versions - confirm which versions this finder has to support
            # noinspection PyArgumentList
            replacement = ast.BinOp(
                left=expr,
                op=ast.Pow(),
                right=ast.Str(s=sentinel),
            )
            ast.fix_missing_locations(replacement)
            setter(replacement)
            try:
                instructions = self.compile_instructions()
            finally:
                # Always put the original candidate back into the tree
                setter(expr)

            if sys.version_info >= (3, 10):
                try:
                    handle_jumps(instructions, original_instructions)
                except Exception:
                    # Give other candidates a chance
                    if TESTING or expr_index < len(exprs) - 1:
                        continue
                    raise

            indices = [
                i
                for i, instruction in enumerate(instructions)
                if instruction.argval == sentinel
            ]

            # There can be several indices when the bytecode is duplicated,
            # as happens in a finally block in 3.9+
            # First we remove the opcodes caused by our modifications
            for index_num, sentinel_index in enumerate(indices):
                # Adjustment for removing sentinel instructions below
                # in past iterations
                sentinel_index -= index_num * 2

                assert_(instructions.pop(sentinel_index).opname == 'LOAD_CONST')
                assert_(instructions.pop(sentinel_index).opname == 'BINARY_POWER')

            # Then we see if any of the instruction indices match
            for index_num, sentinel_index in enumerate(indices):
                sentinel_index -= index_num * 2
                # The candidate's own instruction directly precedes the removed sentinel pair
                new_index = sentinel_index - 1

                if new_index != original_index:
                    continue

                original_inst = original_instructions[original_index]
                new_inst = instructions[new_index]

                # In Python 3.9+, changing 'not x in y' to 'not sentinel_transformation(x in y)'
                # changes a CONTAINS_OP(invert=1) to CONTAINS_OP(invert=0),<sentinel stuff>,UNARY_NOT
                if (
                        original_inst.opname == new_inst.opname in ('CONTAINS_OP', 'IS_OP')
                        and original_inst.arg != new_inst.arg  # type: ignore[attr-defined]
                        and (
                        original_instructions[original_index + 1].opname
                        != instructions[new_index + 1].opname == 'UNARY_NOT'
                )):
                    # Remove the difference for the upcoming assert
                    instructions.pop(new_index + 1)

                # Check that the modified instructions don't have anything unexpected
                # 3.10 is a bit too weird to assert this in all cases but things still work
                if sys.version_info < (3, 10):
                    for inst1, inst2 in zip_longest(
                        original_instructions, instructions
                    ):
                        assert_(inst1 and inst2 and opnames_match(inst1, inst2))

                yield expr

    def compile_instructions(self):
        # type: () -> List[EnhancedInstruction]
        """
        Compiles the tree in its current state and returns the cleaned instructions
        of the single compiled code object that matches the frame's code.
        """
        module_code = compile_similar_to(self.tree, self.code)
        code = only(self.find_codes(module_code))
        return self.clean_instructions(code)

    def find_codes(self, root_code):
        # type: (types.CodeType) -> list
        """
        Returns the code objects among `root_code` and the code objects nested in its
        constants that agree with the frame's code on every attribute in `checks`.
        """
        checks = [
            attrgetter('co_firstlineno'),
            attrgetter('co_freevars'),
            attrgetter('co_cellvars'),
            # Two names that is_ipython_cell_code_name() both accepts compare equal here
            lambda c: is_ipython_cell_code_name(c.co_name) or c.co_name,
        ] # type: List[Callable]
        # Names are not compared for code that looks rewritten by pytest
        if not self.is_pytest:
            checks += [
                attrgetter('co_names'),
                attrgetter('co_varnames'),
            ]

        def matches(c):
            # type: (types.CodeType) -> bool
            return all(
                f(c) == f(self.code)
                for f in checks
            )

        code_options = []
        if matches(root_code):
            code_options.append(root_code)

        def finder(code):
            # type: (types.CodeType) -> None
            # Recursively collect matching code objects from the constants
            for const in code.co_consts:
                if not inspect.iscode(const):
                    continue

                if matches(const):
                    code_options.append(const)
                finder(const)

        finder(root_code)
        return code_options

    def get_actual_current_instruction(self, lasti):
        # type: (int) -> EnhancedInstruction
        """
        Get the instruction corresponding to the current
        frame offset, skipping EXTENDED_ARG instructions
        """
        # Don't use get_original_clean_instructions
        # because we need the actual instructions including
        # EXTENDED_ARG
        instructions = list(get_instructions(self.code))
        index = only(
            i
            for i, inst in enumerate(instructions)
            if inst.offset == lasti
        )

        # Move forward to the first instruction that isn't EXTENDED_ARG
        while True:
            instruction = instructions[index]
            if instruction.opname != "EXTENDED_ARG":
                return instruction
            index += 1
835
836
837
def non_sentinel_instructions(instructions, start):
    # type: (List[EnhancedInstruction], int) -> Iterator[Tuple[int, EnhancedInstruction]]
    """
    Yields (index, instruction) pairs from index `start` onwards, leaving out
    the two instructions added by each sentinel transformation:
    the LOAD_CONST of the sentinel and the BINARY_POWER that follows it.
    """
    after_sentinel = False
    for index, inst in islice(enumerate(instructions), start, None):
        if inst.argval == sentinel:
            assert_(inst.opname == "LOAD_CONST")
            after_sentinel = True
        elif after_sentinel:
            assert_(inst.opname == "BINARY_POWER")
            after_sentinel = False
        else:
            yield index, inst
855
856
def walk_both_instructions(original_instructions, original_start, instructions, start):
    # type: (List[EnhancedInstruction], int, List[EnhancedInstruction], int) -> Iterator[Tuple[int, EnhancedInstruction, int, EnhancedInstruction]]
    """
    Walks the original and the new instructions in parallel, beginning at
    `original_start` and `start`, and yields
    (original index, original instruction, new index, new instruction).

    Instructions added by the sentinel transformation are left out, as is a
    UNARY_NOT that directly follows a CONTAINS_OP/IS_OP whose `arg` differs
    from the original one. Stops when either side runs out.
    """
    original_pairs = islice(enumerate(original_instructions), original_start, None)
    new_pairs = non_sentinel_instructions(instructions, start)
    previous_was_inverted = False
    # zip advances the original side first and ends as soon as either side is exhausted
    for (original_i, original_inst), (new_i, new_inst) in zip(original_pairs, new_pairs):
        if (
            previous_was_inverted
            and original_inst.opname != new_inst.opname
            and new_inst.opname == "UNARY_NOT"
        ):
            # Skip the extra UNARY_NOT on the new side
            new_i, new_inst = next(new_pairs)
        same_comparison = (
            original_inst.opname == new_inst.opname
            and new_inst.opname in ("CONTAINS_OP", "IS_OP")
        )
        previous_was_inverted = (
            same_comparison
            and original_inst.arg != new_inst.arg  # type: ignore[attr-defined]
        )
        yield original_i, original_inst, new_i, new_inst
882
883
def handle_jumps(instructions, original_instructions):
    # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> None
    """
    Transforms instructions in place until it looks more like original_instructions.
    This is only needed in 3.10+ where optimisations lead to more drastic changes
    after the sentinel transformation.
    Replaces JUMP instructions that aren't also present in original_instructions
    with the sections that they jump to until a raise or return.
    In some other cases duplication found in `original_instructions`
    is replicated in `instructions`.

    Raises if a mismatch is found that can't be resolved this way.
    """
    # Each pass fixes the first mismatch and then starts over
    while True:
        for original_i, original_inst, new_i, new_inst in walk_both_instructions(
            original_instructions, 0, instructions, 0
        ):
            if opnames_match(original_inst, new_inst):
                continue

            if "JUMP" in new_inst.opname and "JUMP" not in original_inst.opname:
                # Find where the new instruction is jumping to, ignoring
                # instructions which have been copied in previous iterations
                start = only(
                    i
                    for i, inst in enumerate(instructions)
                    if inst.offset == new_inst.argval
                    and not getattr(inst, "_copied", False)
                )
                # Replace the jump instruction with the jumped to section of instructions
                # That section may also be deleted if it's not similarly duplicated
                # in original_instructions
                new_instructions = handle_jump(
                    original_instructions, original_i, instructions, start
                )
                assert new_instructions is not None
                instructions[new_i : new_i + 1] = new_instructions
            else:
                # Extract a section of original_instructions from original_i to return/raise
                orig_section = []
                for section_inst in original_instructions[original_i:]:
                    orig_section.append(section_inst)
                    if section_inst.opname in ("RETURN_VALUE", "RAISE_VARARGS"):
                        break
                else:
                    # No return/raise - this is just a mismatch we can't handle
                    raise AssertionError

                # Insert (without replacing anything) the one section of
                # `instructions` that matches the original section
                instructions[new_i:new_i] = only(find_new_matching(orig_section, instructions))

            # instructions has been modified, the for loop can't sensibly continue
            # Restart it from the beginning, checking for other issues
            break

        else:  # No mismatched jumps found, we're done
            return
938
939
def find_new_matching(orig_section, instructions):
    # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> Iterator[List[EnhancedInstruction]]
    """
    Yields sections of `instructions` which match `orig_section`.
    The yielded sections include sentinel instructions, but these
    are ignored when checking for matches.
    """
    # NOTE(review): range() stops before start == len(instructions) - len(orig_section),
    # so a section ending exactly at the end of `instructions` is never tried
    # - confirm this is intended.
    for start in range(len(instructions) - len(orig_section)):
        # Take as many non-sentinel instructions from `start` as the original section has,
        # split into their indices and the instructions themselves
        # NOTE(review): if no non-sentinel instruction is left from `start`, zip() yields
        # nothing and this unpacking raises ValueError - presumably unreachable; verify.
        indices, dup_section = zip(
            *islice(
                non_sentinel_instructions(instructions, start),
                len(orig_section),
            )
        )
        # Too few instructions left: later starts can't match either
        if len(dup_section) < len(orig_section):
            return
        if sections_match(orig_section, dup_section):
            # Slice by index so that sentinel instructions inside the section are included
            yield instructions[start:indices[-1] + 1]
958
959
def handle_jump(original_instructions, original_start, instructions, start):
    # type: (List[EnhancedInstruction], int, List[EnhancedInstruction], int) -> Optional[List[EnhancedInstruction]]
    """
    Returns the section of instructions starting at `start` and ending
    with a RETURN_VALUE or RAISE_VARARGS instruction.
    There should be a matching section in original_instructions starting at original_start.
    If that section doesn't appear elsewhere in original_instructions,
    then also delete the returned section of instructions.

    The returned instructions are deep copies marked with `_copied = True`.
    Returns None if no RETURN_VALUE or RAISE_VARARGS is reached.
    Raises AssertionError if the two sections stop matching before that.
    """
    for original_j, original_inst, new_j, new_inst in walk_both_instructions(
        original_instructions, original_start, instructions, start
    ):
        assert_(opnames_match(original_inst, new_inst))
        if original_inst.opname in ("RETURN_VALUE", "RAISE_VARARGS"):
            # Copy the section so that the copies can be marked
            # without touching the instructions they were copied from
            inlined = deepcopy(instructions[start : new_j + 1])
            for inl in inlined:
                inl._copied = True
            orig_section = original_instructions[original_start : original_j + 1]
            # Only keep the section at its old place if the original has it twice as well
            if not check_duplicates(
                original_start, orig_section, original_instructions
            ):
                instructions[start : new_j + 1] = []
            return inlined

    return None
985
986
def check_duplicates(original_i, orig_section, original_instructions):
    # type: (int, List[EnhancedInstruction], List[EnhancedInstruction]) -> bool
    """
    Reports whether `orig_section` is duplicated within `original_instructions`.

    Each start index other than `original_i` is tried in increasing order;
    True is returned for the first slice of the same length that
    `sections_match` accepts. Otherwise the result is False.
    """
    section_len = len(orig_section)
    for candidate_start in range(len(original_instructions)):
        if candidate_start == original_i:
            continue
        candidate_stop = candidate_start + section_len
        candidate = original_instructions[candidate_start:candidate_stop]
        if len(candidate) < section_len:
            # Slices starting further right can only be shorter still.
            break
        if sections_match(orig_section, candidate):
            return True

    return False
1003
def sections_match(orig_section, dup_section):
    # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> bool
    """
    Compares two instruction sequences pairwise (stopping at the shorter one)
    and returns True only if every pair has matching linenos and opnames.
    """
    for orig_inst, dup_inst in zip(orig_section, dup_section):
        if orig_inst.lineno != dup_inst.lineno:
            # POP_BLOCKs have been found to have differing linenos in innocent
            # cases, so a lineno mismatch is tolerated for a pair of them.
            both_pop_block = "POP_BLOCK" == orig_inst.opname == dup_inst.opname
            if not both_pop_block:
                return False
        if not opnames_match(orig_inst, dup_inst):
            return False
    return True
1018
1019
def opnames_match(inst1, inst2):
    # type: (Instruction, Instruction) -> bool
    """
    Returns True if the opnames of the two instructions are considered
    equivalent: they are equal, they both contain "JUMP", or they form one
    of the accepted (inst1, inst2) substitution pairs listed below.
    The substitution pairs are directional.
    """
    first = inst1.opname
    second = inst2.opname

    if first == second:
        return True
    if "JUMP" in first and "JUMP" in second:
        return True

    # (opname of inst1, opname of inst2) pairs accepted despite differing.
    accepted_pairs = (
        ("PRINT_EXPR", "POP_TOP"),
        ("LOAD_METHOD", "LOAD_ATTR"),
        ("LOOKUP_METHOD", "LOAD_ATTR"),
        ("CALL_METHOD", "CALL_FUNCTION"),
    )
    return (first, second) in accepted_pairs
1033
1034
def get_setter(node):
    # type: (EnhancedAST) -> Optional[Callable[[ast.AST], None]]
    """
    Returns a function which replaces `node` inside `node.parent` with the
    node it is given, or None if `node` is not found among the parent's fields.

    The parent's fields are searched in order: a field that is `node` itself
    is replaced via `setattr`; a list field containing `node` has the
    corresponding item overwritten.
    """
    parent = node.parent
    for field_name, value in ast.iter_fields(parent):
        if value is node:
            def replace_field(new_node):
                # type: (ast.AST) -> None
                setattr(parent, field_name, new_node)

            return replace_field

        if not isinstance(value, list):
            continue

        for index, child in enumerate(value):
            if child is node:
                def replace_item(new_node):
                    # type: (ast.AST) -> None
                    value[index] = new_node

                return replace_item

    return None
1053
# Module-level re-entrant lock (threading.RLock).
# NOTE(review): the code that acquires it is not visible in this part of the file.
lock = RLock()
1055
1056
@cache
def statement_containing_node(node):
    # type: (ast.AST) -> EnhancedAST
    """
    Returns the nearest `ast.stmt` reached by following `.parent` links
    upwards from `node`; `node` itself is returned if it is already a statement.
    Results are memoised by the module-level `cache` decorator.
    """
    current = node
    while True:
        if isinstance(current, ast.stmt):
            return cast(EnhancedAST, current)
        current = cast(EnhancedAST, current).parent
1063
1064
def assert_linenos(tree):
    # type: (ast.AST) -> Iterator[int]
    """
    Yields the line numbers (as given by `node_linenos`) of every node in
    `tree` that has a `parent` attribute and whose containing statement
    is an `assert`.
    """
    for node in ast.walk(tree):
        # Nodes without a parent link cannot be traced back to a statement.
        if not hasattr(node, 'parent'):
            continue
        if isinstance(statement_containing_node(node), ast.Assert):
            yield from node_linenos(node)
1074
1075
def _extract_ipython_statement(stmt):
    # type: (EnhancedAST) -> ast.Module
    """
    Returns a fresh `ast.Module` whose body is the single top-level
    statement that contains `stmt` (found by following `.parent` links
    until the parent is an `ast.Module`).
    """
    # IPython separates each statement in a cell to be executed separately
    # So NodeFinder should only compile one statement at a time or it
    # will find a code mismatch.
    top_level = stmt
    while True:
        if isinstance(top_level.parent, ast.Module):
            break
        top_level = top_level.parent

    # use `ast.parse` instead of `ast.Module` for better portability
    # python3.8 changes the signature of `ast.Module`
    # Inspired by https://github.com/pallets/werkzeug/pull/1552/files
    module = ast.parse("")
    module.body = [cast(ast.stmt, top_level)]
    ast.copy_location(module, top_level)
    return module
1090
1091
def is_ipython_cell_code_name(code_name):
    # type: (str) -> bool
    """
    Returns True if `code_name` is `<module>` or has the form
    `<cell line: N>` with N being one or more digits. The pattern is
    anchored at the start by `re.match` and at the end by `$`.
    """
    pattern = r"(<module>|<cell line: \d+>)$"
    return re.match(pattern, code_name) is not None
1095
1096
def is_ipython_cell_filename(filename):
    # type: (str) -> bool
    """
    Returns True if `filename` contains `<ipython-input-` anywhere, or a
    path component of the form `ipykernel_<digits>` with a `/` or `\\`
    on both sides.
    """
    pattern = r"<ipython-input-|[/\\]ipykernel_\d+[/\\]"
    return re.search(pattern, filename) is not None
1100
1101
def is_ipython_cell_code(code_obj):
    # type: (types.CodeType) -> bool
    """
    Returns True if both the filename and the name of `code_obj` pass the
    checks in `is_ipython_cell_filename` and `is_ipython_cell_code_name`.
    The name is only examined when the filename check succeeds.
    """
    if not is_ipython_cell_filename(code_obj.co_filename):
        return False
    return is_ipython_cell_code_name(code_obj.co_name)
1108
1109
def find_node_ipython(frame, lasti, stmts, source):
    # type: (types.FrameType, int, Set[EnhancedAST], Source) -> Tuple[Optional[Any], Optional[Any]]
    """
    Runs a `NodeFinder` separately for each statement in `stmts`, each time
    on a module containing only that statement's top-level statement.

    Returns `(decorator, node)` taken from the most recent finder that was
    constructed without raising. Statements whose finder raises are ignored.
    If a second finder yields a result or decorator after one was already
    recorded, `(None, None)` is returned.
    """
    found_node = found_decorator = None
    for statement in stmts:
        module = _extract_ipython_statement(statement)
        try:
            finder = NodeFinder(frame, stmts, module, lasti, source)
            already_found = found_node or found_decorator
            if already_found and (finder.result or finder.decorator):
                # Found potential nodes in separate statements,
                # cannot resolve ambiguity, give up here
                return None, None

            found_node = finder.result
            found_decorator = finder.decorator
        except Exception:
            # Best effort: a failure for one statement must not stop the search.
            continue
    return found_decorator, found_node
1127
1128
def attr_names_match(attr, argval):
    # type: (str, str) -> bool
    """
    Checks that the user-visible attr (from ast) can correspond to
    the argval in the bytecode, i.e. the real attribute fetched internally,
    which may be mangled for private attributes.

    Equal names always match. Otherwise `attr` must start with a double
    underscore and `argval` must consist of an underscore, one or more
    word characters, and then `attr`.
    """
    if attr == argval:
        return True
    if attr.startswith("__"):
        mangled = re.match(r"^_\w+%s$" % attr, argval)
        return mangled is not None
    return False
1141
1142
def node_linenos(node):
    # type: (ast.AST) -> Iterator[int]
    """
    Yields the line numbers associated with `node`.

    Nodes without a `lineno` attribute yield nothing. Expression nodes that
    have an `end_lineno` yield every line from `lineno` to `end_lineno`
    inclusive; any other node yields just its `lineno`.
    """
    if not hasattr(node, "lineno"):
        return

    if hasattr(node, "end_lineno") and isinstance(node, ast.expr):
        assert node.end_lineno is not None  # type: ignore[attr-defined]
        yield from range(node.lineno, node.end_lineno + 1)  # type: ignore[attr-defined]
    else:
        yield node.lineno  # type: ignore[attr-defined]
1154
1155
# Select the NodeFinder implementation for the running interpreter:
# PositionNodeFinder (from the sibling `_position_node_finder` module) on
# Python 3.11 and later, SentinelNodeFinder on earlier versions.
if sys.version_info >= (3, 11):
    from ._position_node_finder import PositionNodeFinder as NodeFinder
else:
    NodeFinder = SentinelNodeFinder
1160