Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/libcst/codemod/visitors/_apply_type_annotations.py: 23%
566 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:43 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:43 +0000
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
6from collections import defaultdict
7from dataclasses import dataclass
8from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
10import libcst as cst
11import libcst.matchers as m
13from libcst.codemod._context import CodemodContext
14from libcst.codemod._visitor import ContextAwareTransformer
15from libcst.codemod.visitors._add_imports import AddImportsVisitor
16from libcst.codemod.visitors._gather_global_names import GatherGlobalNamesVisitor
17from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
18from libcst.codemod.visitors._imports import ImportItem
19from libcst.helpers import get_full_name_for_node
20from libcst.metadata import PositionProvider, QualifiedNameProvider
# A node that can spell a type name: a bare ``Name`` or a dotted ``Attribute``.
NameOrAttribute = Union[cst.Name, cst.Attribute]
# Runtime twin of ``NameOrAttribute``, usable as the second ``isinstance`` argument.
NAME_OR_ATTRIBUTE = (cst.Name, cst.Attribute)
# Every shape a ``*args`` / ``**kwargs`` slot of ``cst.Parameters`` can take.
StarParamType = Union[
    None,
    cst.MaybeSentinel,
    cst.Param,
    cst.ParamStar,
]
34def _module_and_target(qualified_name: str) -> Tuple[str, str]:
35 relative_prefix = ""
36 while qualified_name.startswith("."):
37 relative_prefix += "."
38 qualified_name = qualified_name[1:]
39 split = qualified_name.rsplit(".", 1)
40 if len(split) == 1:
41 qualifier, target = "", split[0]
42 else:
43 qualifier, target = split
44 return (relative_prefix + qualifier, target)
def _get_unique_qualified_name(
    visitor: m.MatcherDecoratableVisitor, node: cst.CSTNode
) -> str:
    """
    Resolve ``node`` to exactly one qualified name using ``visitor``'s metadata.

    Raises:
        ValueError: if no single string name can be determined; the message
            includes the node's position and the candidate names.
    """
    candidates = [
        qualified.name
        for qualified in visitor.get_metadata(QualifiedNameProvider, node)
    ]
    if not candidates:
        # we hit this branch if the stub is directly using a fully
        # qualified name, which is not technically valid python but is
        # convenient to allow.
        resolved = get_full_name_for_node(node)
    elif len(candidates) == 1 and isinstance(candidates[0], str):
        resolved = candidates[0]
    else:
        resolved = None

    if resolved is not None:
        return resolved

    start = visitor.get_metadata(PositionProvider, node).start
    raise ValueError(
        "Could not resolve a unique qualified name for type "
        + f"{get_full_name_for_node(node)} at {start.line}:{start.column}. "
        + f"Candidate names were: {candidates!r}"
    )
69def _get_import_alias_names(
70 import_aliases: Sequence[cst.ImportAlias],
71) -> Set[str]:
72 import_names = set()
73 for imported_name in import_aliases:
74 asname = imported_name.asname
75 if asname is not None:
76 import_names.add(get_full_name_for_node(asname.name))
77 else:
78 import_names.add(get_full_name_for_node(imported_name.name))
79 return import_names
82def _get_imported_names(
83 imports: Sequence[Union[cst.Import, cst.ImportFrom]],
84) -> Set[str]:
85 """
86 Given a series of import statements (both Import and ImportFrom),
87 determine all of the names that have been imported into the current
88 scope. For example:
89 - ``import foo.bar as bar, foo.baz`` produces ``{'bar', 'foo.baz'}``
90 - ``from foo import (Bar, Baz as B)`` produces ``{'Bar', 'B'}``
91 - ``from foo import *`` produces ``set()` because we cannot resolve names
92 """
93 import_names = set()
94 for _import in imports:
95 if isinstance(_import, cst.Import):
96 import_names.update(_get_import_alias_names(_import.names))
97 else:
98 names = _import.names
99 if not isinstance(names, cst.ImportStar):
100 import_names.update(_get_import_alias_names(names))
101 return import_names
104def _is_non_sentinel(
105 x: Union[None, cst.CSTNode, cst.MaybeSentinel],
106) -> bool:
107 return x is not None and x != cst.MaybeSentinel.DEFAULT
110def _get_string_value(
111 node: cst.SimpleString,
112) -> str:
113 s = node.value
114 c = s[-1]
115 return s[s.index(c) : -1]
118def _find_generic_base(
119 node: cst.ClassDef,
120) -> Optional[cst.Arg]:
121 for b in node.bases:
122 if m.matches(b.value, m.Subscript(value=m.Name("Generic"))):
123 return b
@dataclass(frozen=True)
class FunctionKey:
    """
    Class representing a function name and signature.

    This exists to ensure we do not attempt to apply stubs to functions whose
    definition is incompatible.
    """

    # Name of the function, as supplied by the caller of ``make``.
    name: str
    # Number of entries in ``params.params``.
    pos: int
    # Keyword-only parameter names, sorted and comma-joined.
    kwonly: str
    # Number of positional-only parameters.
    posonly: int
    # Whether the star argument / star-star argument slot holds a real value.
    star_arg: bool
    star_kwarg: bool

    @classmethod
    def make(
        cls,
        name: str,
        params: cst.Parameters,
    ) -> "FunctionKey":
        """Build the key for a function called ``name`` with parameters ``params``."""
        kwonly_names = [param.name.value for param in params.kwonly_params]
        kwonly_names.sort()
        return cls(
            name=name,
            pos=len(params.params),
            kwonly=",".join(kwonly_names),
            posonly=len(params.posonly_params),
            star_arg=_is_non_sentinel(params.star_arg),
            star_kwarg=_is_non_sentinel(params.star_kwarg),
        )
@dataclass(frozen=True)
class FunctionAnnotation:
    """Immutable pair of the annotation data recorded for one function."""

    # Parameter list carrying the parameter annotations.
    parameters: cst.Parameters
    # Return annotation, or None when there is none.
    returns: Optional[cst.Annotation]
@dataclass
class Annotations:
    """
    Represents all of the annotation information we might add to
    a class:
    - All data is keyed on the qualified name relative to the module root
    - The ``functions`` field also keys on the signature so that we
      do not apply stub types where the signature is incompatible.

    The idea is that
    - ``functions`` contains all function and method type
      information from the stub, and the qualifier for a method includes
      the containing class names (e.g. "Cat.meow")
    - ``attributes`` similarly contains all globals
      and class-level attribute type information.
    - The ``class_definitions`` field contains all of the classes
      defined in the stub. Most of these classes will be ignored in
      downstream logic (it is *not* used to annotate attributes or
      method), but there are some cases like TypedDict where a
      typing-only class needs to be injected.
    - The field ``typevars`` contains the assign statement for all
      type variables in the stub, and ``names`` tracks
      all of the names used in annotations; together these fields
      tell us which typevars should be included in the codemod
      (all typevars that appear in annotations.)
    """

    # TODO: consider simplifying this in a few ways:
    # - We could probably just inject all typevars, used or not.
    #   It doesn't seem to me that our codemod needs to act like
    #   a linter checking for unused names.
    # - We could probably decide which classes are typing-only
    #   in the visitor rather than the codemod, which would make
    #   it easier to reason locally about (and document) how the
    #   class_definitions field works.

    functions: Dict[FunctionKey, FunctionAnnotation]
    attributes: Dict[str, cst.Annotation]
    class_definitions: Dict[str, cst.ClassDef]
    typevars: Dict[str, cst.Assign]
    names: Set[str]

    @classmethod
    def empty(cls) -> "Annotations":
        """Create an instance whose collections are all empty."""
        return Annotations(
            functions={},
            attributes={},
            class_definitions={},
            typevars={},
            names=set(),
        )

    def update(self, other: "Annotations") -> None:
        """Merge ``other`` into this instance; entries from ``other`` win on key clashes."""
        merge_pairs = (
            (self.functions, other.functions),
            (self.attributes, other.attributes),
            (self.class_definitions, other.class_definitions),
            (self.typevars, other.typevars),
            (self.names, other.names),
        )
        for mine, theirs in merge_pairs:
            mine.update(theirs)

    def finish(self) -> None:
        """Drop every typevar whose name does not appear in ``names``."""
        kept = {}
        for typevar_name, assign in self.typevars.items():
            if typevar_name in self.names:
                kept[typevar_name] = assign
        self.typevars = kept
@dataclass(frozen=True)
class ImportedSymbol:
    """Import of foo.Bar, where both foo and Bar are potentially aliases."""

    module_name: str
    module_alias: Optional[str] = None
    target_name: Optional[str] = None
    target_alias: Optional[str] = None

    @property
    def symbol(self) -> Optional[str]:
        """Target alias when one is set (and non-empty), otherwise the target name."""
        if self.target_alias:
            return self.target_alias
        return self.target_name

    @property
    def module_symbol(self) -> str:
        """Module alias when one is set (and non-empty), otherwise the module name."""
        if self.module_alias:
            return self.module_alias
        return self.module_name
class ImportedSymbolCollector(m.MatcherDecoratableVisitor):
    """
    Collect imported symbols from a stub module.

    Names and attributes seen inside annotations, and name/attribute base
    classes, are resolved to qualified names and grouped in
    ``imported_symbols`` under the symbol as it is spelled in the stub.
    """

    METADATA_DEPENDENCIES = (
        PositionProvider,
        QualifiedNameProvider,
    )

    def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None:
        super().__init__()
        # Names already imported by the target module; qualified names found
        # in this set are not recorded.
        self.existing_imports: Set[str] = existing_imports
        self.imported_symbols: Dict[str, Set[ImportedSymbol]] = defaultdict(set)
        # True while the traversal is inside an Annotation node.
        self.in_annotation: bool = False

    def visit_Annotation(self, node: cst.Annotation) -> None:
        self.in_annotation = True

    def leave_Annotation(self, original_node: cst.Annotation) -> None:
        self.in_annotation = False

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # Base classes are recorded even though they sit outside annotations.
        for base_value in (base.value for base in node.bases):
            if isinstance(base_value, NAME_OR_ATTRIBUTE):
                self._handle_NameOrAttribute(base_value)

    def visit_Name(self, node: cst.Name) -> None:
        if not self.in_annotation:
            return
        self._handle_NameOrAttribute(node)

    def visit_Attribute(self, node: cst.Attribute) -> None:
        if not self.in_annotation:
            return
        self._handle_NameOrAttribute(node)

    def visit_Subscript(self, node: cst.Subscript) -> bool:
        # Always descend when the subscripted value is a name or attribute;
        # otherwise descend unless the subscript resolves to ``Type``.
        return isinstance(
            node.value, NAME_OR_ATTRIBUTE
        ) or _get_unique_qualified_name(self, node) not in ("Type", "typing.Type")

    def _handle_NameOrAttribute(
        self,
        node: NameOrAttribute,
    ) -> None:
        # Adds the qualified name to the list of imported symbols
        if isinstance(node, cst.Name):
            owner, used_name = None, node.value
        elif isinstance(node, cst.Attribute):
            owner, used_name = node.value.value, node.attr.value  # pyre-ignore[16]
        else:
            owner = used_name = None  # keep pyre happy

        qualified_name = _get_unique_qualified_name(self, node)
        module, target = _module_and_target(qualified_name)
        if module in ("", "builtins") or qualified_name in self.existing_imports:
            return

        # An alias is only stored when the spelling in the stub differs from
        # the canonical module / target name.
        self.imported_symbols[used_name].add(
            ImportedSymbol(
                module_name=module,
                module_alias=owner if owner != module else None,
                target_name=target,
                target_alias=used_name if used_name != target else None,
            )
        )
class TypeCollector(m.MatcherDecoratableVisitor):
    """
    Collect type annotations from a stub module.

    Results accumulate in ``self.annotations``. While collecting, needed
    imports are registered on ``context`` through ``AddImportsVisitor``.
    """

    METADATA_DEPENDENCIES = (
        PositionProvider,
        QualifiedNameProvider,
    )

    annotations: Annotations

    def __init__(
        self,
        existing_imports: Set[str],
        module_imports: Dict[str, ImportItem],
        context: CodemodContext,
    ) -> None:
        super().__init__()
        self.context = context
        # Existing imports, determined by looking at the target module.
        # Used to help us determine when a type in a stub will require new imports.
        #
        # The contents of this are fully-qualified names of types in scope
        # as well as module names, although downstream we effectively ignore
        # the module names as of the current implementation.
        self.existing_imports: Set[str] = existing_imports
        # Module imports, gathered by prescanning the stub file to determine
        # which modules need to be imported directly to qualify their symbols.
        self.module_imports: Dict[str, ImportItem] = module_imports
        # Fields that help us track temporary state as we recurse
        self.qualifier: List[str] = []
        self.current_assign: Optional[cst.Assign] = None  # used to collect typevars
        # Store the annotations.
        self.annotations = Annotations.empty()

    def visit_ClassDef(
        self,
        node: cst.ClassDef,
    ) -> None:
        """
        Record the class definition with its base classes dequalified.

        Raises:
            ValueError: if a base class is not a name, attribute or subscript.
        """
        self.qualifier.append(node.name.value)
        new_bases = []
        for base in node.bases:
            value = base.value
            if not isinstance(value, (*NAME_OR_ATTRIBUTE, cst.Subscript)):
                start = self.get_metadata(PositionProvider, node).start
                raise ValueError(
                    "Invalid type used as base class in stub file at "
                    + f"{start.line}:{start.column}. Only subscripts, names, and "
                    + "attributes are valid base classes for static typing."
                )
            new_value = value.visit(_TypeCollectorDequalifier(self))
            new_bases.append(base.with_changes(value=new_value))

        self.annotations.class_definitions[node.name.value] = node.with_changes(
            bases=new_bases
        )

    def leave_ClassDef(
        self,
        original_node: cst.ClassDef,
    ) -> None:
        self.qualifier.pop()

    def visit_FunctionDef(
        self,
        node: cst.FunctionDef,
    ) -> bool:
        """Record the function's parameter and return annotations under its key."""
        self.qualifier.append(node.name.value)
        returns = node.returns
        return_annotation = (
            returns.visit(_TypeCollectorDequalifier(self))
            if returns is not None
            else None
        )
        assert return_annotation is None or isinstance(
            return_annotation, cst.Annotation
        )
        parameter_annotations = self._handle_Parameters(node.params)
        name = ".".join(self.qualifier)
        key = FunctionKey.make(name, node.params)
        self.annotations.functions[key] = FunctionAnnotation(
            parameters=parameter_annotations, returns=return_annotation
        )

        # pyi files don't support inner functions, return False to stop the traversal.
        return False

    def leave_FunctionDef(
        self,
        original_node: cst.FunctionDef,
    ) -> None:
        self.qualifier.pop()

    def visit_AnnAssign(
        self,
        node: cst.AnnAssign,
    ) -> bool:
        """Record the annotation of a named annotated assignment."""
        name = get_full_name_for_node(node.target)
        if name is None:
            # Without a name there is no key to store the annotation under.
            # Nothing is pushed here, and leave_AnnAssign pops nothing.
            return True
        self.qualifier.append(name)
        annotation_value = node.annotation.visit(_TypeCollectorDequalifier(self))
        assert isinstance(annotation_value, cst.Annotation)
        self.annotations.attributes[".".join(self.qualifier)] = annotation_value
        return True

    def leave_AnnAssign(
        self,
        original_node: cst.AnnAssign,
    ) -> None:
        # Mirror visit_AnnAssign: only named targets pushed a qualifier.
        if get_full_name_for_node(original_node.target) is not None:
            self.qualifier.pop()

    def visit_Assign(
        self,
        node: cst.Assign,
    ) -> None:
        self.current_assign = node

    def leave_Assign(
        self,
        original_node: cst.Assign,
    ) -> None:
        self.current_assign = None

    @m.call_if_inside(m.Assign())
    @m.visit(m.Call(func=m.Name("TypeVar")))
    def record_typevar(
        self,
        node: cst.Call,
    ) -> None:
        """Remember the assignment that defines a type variable."""
        assign = self.current_assign
        if assign is None:
            # Already consumed by an earlier TypeVar call in this assignment.
            return
        name = get_full_name_for_node(assign.targets[0].target)
        if name is not None:
            self.annotations.typevars[name] = assign
            self._handle_qualification_and_should_qualify("typing.TypeVar")
            self.current_assign = None

    def leave_Module(
        self,
        original_node: cst.Module,
    ) -> None:
        self.annotations.finish()

    def _module_and_target(
        self,
        qualified_name: str,
    ) -> Tuple[str, str]:
        """Split ``qualified_name`` into ``(module, target)``."""
        # Inside a method body this name resolves to the module-level helper,
        # so both entry points share one implementation.
        return _module_and_target(qualified_name)

    def _handle_qualification_and_should_qualify(
        self, qualified_name: str, node: Optional[cst.CSTNode] = None
    ) -> bool:
        """
        Based on a qualified name and the existing module imports, record that
        we need to add an import if necessary and return whether or not we
        should use the qualified name due to a preexisting import.
        """
        module, target = self._module_and_target(qualified_name)
        if module in ("", "builtins"):
            return False
        if qualified_name in self.existing_imports:
            return False
        if module in self.existing_imports:
            return True
        if module in self.module_imports:
            import_item = self.module_imports[module]
            # The alias is only passed on for an item without an object name.
            asname = import_item.alias if import_item.obj_name is None else None
            AddImportsVisitor.add_needed_import(
                self.context, import_item.module_name, asname=asname
            )
            return True

        # From-import the target, keeping the stub's spelling as an alias when
        # it differs from the canonical target name.
        if isinstance(node, cst.Name) and node.value != target:
            asname = node.value
        else:
            asname = None
        AddImportsVisitor.add_needed_import(
            self.context,
            module,
            target,
            asname=asname,
        )
        return False

    # Handler functions

    def _handle_Parameters(
        self,
        parameters: cst.Parameters,
    ) -> cst.Parameters:
        """Return ``parameters`` with the annotations in ``params`` dequalified."""

        def dequalified(parameter: cst.Param) -> cst.Param:
            annotation = parameter.annotation
            if annotation is None:
                return parameter
            return parameter.with_changes(
                annotation=annotation.visit(_TypeCollectorDequalifier(self))
            )

        # Only ``params`` is rewritten here; the other parameter groups are
        # passed through unchanged.
        return parameters.with_changes(
            params=[dequalified(parameter) for parameter in parameters.params]
        )
class _TypeCollectorDequalifier(cst.CSTTransformer):
    """
    Rewrite the names inside one annotation for a ``TypeCollector``.

    Each name or attribute is resolved to its qualified name and recorded in
    the collector's ``annotations.names``. The collector then decides whether
    the annotation should carry the qualified or the unqualified spelling.
    """

    def __init__(self, type_collector: "TypeCollector") -> None:
        self.type_collector = type_collector

    def leave_Name(
        self, original_node: cst.Name, updated_node: cst.Name
    ) -> cst.BaseExpression:
        qualified_name = _get_unique_qualified_name(self.type_collector, original_node)
        should_qualify = self.type_collector._handle_qualification_and_should_qualify(
            qualified_name, original_node
        )
        self.type_collector.annotations.names.add(qualified_name)
        if should_qualify:
            # Parse as an expression so the replacement is an expression node
            # and not a whole Module embedded inside the annotation.
            return cst.parse_expression(qualified_name)
        return original_node

    def visit_Attribute(self, node: cst.Attribute) -> bool:
        # The attribute is handled as a whole in leave_Attribute.
        return False

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        qualified_name = _get_unique_qualified_name(self.type_collector, original_node)
        should_qualify = self.type_collector._handle_qualification_and_should_qualify(
            qualified_name, original_node
        )
        self.type_collector.annotations.names.add(qualified_name)
        if should_qualify:
            return original_node
        # Unqualified spelling: keep only the final attribute name.
        return original_node.attr

    def leave_Index(
        self, original_node: cst.Index, updated_node: cst.Index
    ) -> cst.Index:
        # A string used as a subscript index is recorded as a used name.
        if isinstance(original_node.value, cst.SimpleString):
            self.type_collector.annotations.names.add(
                _get_string_value(original_node.value)
            )
        return updated_node

    def visit_Subscript(self, node: cst.Subscript) -> bool:
        # Do not descend into ``Type[...]``; see leave_Subscript.
        return _get_unique_qualified_name(self.type_collector, node) not in (
            "Type",
            "typing.Type",
        )

    def leave_Subscript(
        self, original_node: cst.Subscript, updated_node: cst.Subscript
    ) -> cst.Subscript:
        if _get_unique_qualified_name(self.type_collector, original_node) in (
            "Type",
            "typing.Type",
        ):
            # Note: we are intentionally not handling qualification of
            # anything inside `Type` because it's common to have nested
            # classes, which we cannot currently distinguish from classes
            # coming from other modules, appear here.
            return original_node.with_changes(value=original_node.value.visit(self))
        return updated_node
@dataclass
class AnnotationCounts:
    """Running tally of the changes applied, one counter per kind of change."""

    global_annotations: int = 0
    attribute_annotations: int = 0
    parameter_annotations: int = 0
    return_annotations: int = 0
    classes_added: int = 0
    typevars_and_generics_added: int = 0

    def any_changes_applied(self) -> bool:
        """Return True when the counters add up to more than zero."""
        total = sum(
            (
                self.global_annotations,
                self.attribute_annotations,
                self.parameter_annotations,
                self.return_annotations,
                self.classes_added,
                self.typevars_and_generics_added,
            )
        )
        return total > 0
615class ApplyTypeAnnotationsVisitor(ContextAwareTransformer):
616 """
616 Apply type annotations to a source module using the given stub modules.
618 You can also pass in explicit annotations for functions and attributes and
619 pass in new class definitions that need to be added to the source module.
621 This is one of the transforms that is available automatically to you when
622 running a codemod. To use it in this manner, import
623 :class:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor` and then call
624 the static
625 :meth:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor.store_stub_in_context`
626 method, giving it the current context (found as ``self.context`` for all
627 subclasses of :class:`~libcst.codemod.Codemod`), the stub module from which
628 you wish to add annotations.
630 For example, you can store the type annotation ``int`` for ``x`` using::
632 stub_module = parse_module("x: int = ...")
634 ApplyTypeAnnotationsVisitor.store_stub_in_context(self.context, stub_module)
636 You can apply the type annotation using::
638 source_module = parse_module("x = 1")
639 ApplyTypeAnnotationsVisitor.transform_module(source_module)
641 This will produce the following code::
643 x: int = 1
645 If the function or attribute already has a type annotation, it will not be
646 overwritten.
648 To overwrite existing annotations when applying annotations from a stub,
649 use the keyword argument ``overwrite_existing_annotations=True`` when
650 constructing the codemod or when calling ``store_stub_in_context``.
651 """
653 CONTEXT_KEY = "ApplyTypeAnnotationsVisitor"
def __init__(
    self,
    context: CodemodContext,
    annotations: Optional[Annotations] = None,
    overwrite_existing_annotations: bool = False,
    use_future_annotations: bool = False,
    strict_posargs_matching: bool = True,
    strict_annotation_matching: bool = False,
    always_qualify_annotations: bool = False,
) -> None:
    """
    Set up the transformer.

    ``annotations`` optionally seeds the annotations to apply; an empty set
    is used when it is omitted. The boolean options are stored unchanged on
    the instance.
    """
    super().__init__(context)

    # Options, stored exactly as given.
    self.overwrite_existing_annotations = overwrite_existing_annotations
    self.use_future_annotations = use_future_annotations
    self.strict_posargs_matching = strict_posargs_matching
    self.strict_annotation_matching = strict_annotation_matching
    self.always_qualify_annotations = always_qualify_annotations

    # Qualifier for storing the canonical name of the current function.
    self.qualifier: List[str] = []
    if annotations is None:
        annotations = Annotations.empty()
    self.annotations: Annotations = annotations
    self.toplevel_annotations: Dict[str, cst.Annotation] = {}
    self.visited_classes: Set[str] = set()

    # We use this to determine the end of the import block so that we can
    # insert top-level annotations.
    self.import_statements: List[cst.ImportFrom] = []

    # We use this to report annotations added, as well as to determine
    # whether to abandon the codemod in edge cases where we may have
    # only made changes to the imports.
    self.annotation_counts: AnnotationCounts = AnnotationCounts()

    # We use this to collect typevars, to avoid importing existing ones from the pyi file
    self.current_assign: Optional[cst.Assign] = None
    self.typevars: Dict[str, cst.Assign] = {}

    # Global variables and classes defined on the toplevel of the target module.
    # Used to help determine which names we need to check are in scope, and add
    # quotations to avoid undefined forward references in type annotations.
    self.global_names: Set[str] = set()

    # We use this to avoid annotating multiple assignments to the same
    # symbol in a given scope
    self.already_annotated: Set[str] = set()
@staticmethod
def store_stub_in_context(
    context: CodemodContext,
    stub: cst.Module,
    overwrite_existing_annotations: bool = False,
    use_future_annotations: bool = False,
    strict_posargs_matching: bool = True,
    strict_annotation_matching: bool = False,
    always_qualify_annotations: bool = False,
) -> None:
    """
    Store a stub module in the :class:`~libcst.codemod.CodemodContext` so
    that type annotations from the stub can be applied in a later
    invocation of this class.

    If the ``overwrite_existing_annotations`` flag is ``True``, the
    codemod will overwrite any existing annotations.

    The stub and all options are stored together under a single key, so
    calling this function again replaces everything stored by the
    previous call.
    """
    stored_settings = (
        stub,
        overwrite_existing_annotations,
        use_future_annotations,
        strict_posargs_matching,
        strict_annotation_matching,
        always_qualify_annotations,
    )
    context.scratch[ApplyTypeAnnotationsVisitor.CONTEXT_KEY] = stored_settings
def transform_module_impl(
    self,
    tree: cst.Module,
) -> cst.Module:
    """
    Collect type annotations from all stubs and apply them to ``tree``.

    Gather existing imports from ``tree`` so that we don't add duplicate imports.

    Gather global names from ``tree`` so forward references are quoted.

    Returns the transformed module, or ``tree`` itself when no annotation
    was applied, so that import-only changes are discarded.
    """
    import_gatherer = GatherImportsVisitor(CodemodContext())
    tree.visit(import_gatherer)
    existing_import_names = _get_imported_names(import_gatherer.all_imports)

    global_names_gatherer = GatherGlobalNamesVisitor(CodemodContext())
    tree.visit(global_names_gatherer)
    self.global_names = global_names_gatherer.global_names.union(
        global_names_gatherer.class_names
    )

    context_contents = self.context.scratch.get(
        ApplyTypeAnnotationsVisitor.CONTEXT_KEY
    )
    if context_contents is not None:
        (
            stub,
            overwrite_existing_annotations,
            use_future_annotations,
            strict_posargs_matching,
            strict_annotation_matching,
            always_qualify_annotations,
        ) = context_contents
        # Merge the stored options with the constructor's. Enabling flags
        # combine with ``or``; strict_posargs_matching combines with ``and``.
        self.overwrite_existing_annotations = (
            self.overwrite_existing_annotations or overwrite_existing_annotations
        )
        self.use_future_annotations = (
            self.use_future_annotations or use_future_annotations
        )
        self.strict_posargs_matching = (
            self.strict_posargs_matching and strict_posargs_matching
        )
        self.strict_annotation_matching = (
            self.strict_annotation_matching or strict_annotation_matching
        )
        self.always_qualify_annotations = (
            self.always_qualify_annotations or always_qualify_annotations
        )
        module_imports = self._get_module_imports(stub, import_gatherer)
        visitor = TypeCollector(existing_import_names, module_imports, self.context)
        cst.MetadataWrapper(stub).visit(visitor)
        self.annotations.update(visitor.annotations)

    # Done outside the stub branch so ``tree_with_imports`` is always bound,
    # including when annotations came only from the constructor.
    if self.use_future_annotations:
        AddImportsVisitor.add_needed_import(
            self.context, "__future__", "annotations"
        )
    tree_with_imports = AddImportsVisitor(self.context).transform_module(tree)

    tree_with_changes = tree_with_imports.visit(self)

    # don't modify the imports if we didn't actually add any type information
    if self.annotation_counts.any_changes_applied():
        return tree_with_changes
    return tree
798 # helpers for collecting type information from the stub files
def _get_module_imports(  # noqa: C901: too complex
    self, stub: cst.Module, existing_import_gatherer: GatherImportsVisitor
) -> Dict[str, ImportItem]:
    """Returns a dict of modules that need to be imported to qualify symbols."""
    # We correlate all imported symbols, e.g. foo.bar.Baz, with a list of module
    # and from imports. If the same unqualified symbol is used from different
    # modules, we give preference to an explicit from-import if any, and qualify
    # everything else by importing the module.
    #
    # e.g. the following stub:
    #   import foo as quux
    #   from bar import Baz as X
    #   def f(x: X) -> quux.X: ...
    # will return {'foo': ImportItem("foo", "quux")}. When the apply type
    # annotation visitor hits `quux.X` it will retrieve the canonical name
    # `foo.X` and then note that `foo` is in the module imports map, so it will
    # leave the symbol qualified.
    stub_import_gatherer = GatherImportsVisitor(CodemodContext())
    stub.visit(stub_import_gatherer)
    stub_symbols = stub_import_gatherer.symbol_mapping

    target_import_names = _get_imported_names(existing_import_gatherer.all_imports)
    collector = ImportedSymbolCollector(target_import_names, self.context)
    cst.MetadataWrapper(stub).visit(collector)

    qualifying_imports = {}
    for sym, occurrences in collector.imported_symbols.items():
        in_target = existing_import_gatherer.symbol_mapping.get(sym)
        if in_target and any(
            occurrence.module_name != in_target.module_name
            for occurrence in occurrences
        ):
            # If a symbol is imported in the main file, we have to qualify
            # it when imported from a different module in the stub file.
            from_import_taken = True
        elif len(occurrences) == 1 and not self.always_qualify_annotations:
            # If we have a single use of a new symbol we can from-import it
            continue
        else:
            # There are multiple occurrences in the stub file and none in
            # the main file. At least one can be from-imported.
            from_import_taken = False

        for occurrence in occurrences:
            if not occurrence.symbol:
                continue
            direct = stub_symbols.get(occurrence.symbol)
            if self.always_qualify_annotations and sym not in target_import_names:
                # Override 'always qualify' if this is a typing import, or
                # the main file explicitly from-imports a symbol.
                if direct and direct.module_name != "typing":
                    qualifying_imports[direct.module_name] = direct
                else:
                    via_module = stub_symbols.get(occurrence.module_symbol)
                    if via_module:
                        qualifying_imports[via_module.module_name] = via_module
            elif (
                not from_import_taken
                and direct
                and direct.module_name == occurrence.module_name
            ):
                # We can only import a symbol directly once.
                from_import_taken = True
            elif sym in target_import_names:
                if direct:
                    qualifying_imports[direct.module_name] = direct
            else:
                via_module = stub_symbols.get(occurrence.module_symbol)
                if via_module:
                    # via_module will be None in corner cases like
                    #   import foo.bar as Baz
                    #   x: Baz
                    # which is technically valid python but nonsensical as a
                    # type annotation. Dropping it on the floor for now.
                    qualifying_imports[via_module.module_name] = via_module
    return qualifying_imports
871 # helpers for processing annotation nodes
def _quote_future_annotations(self, annotation: cst.Annotation) -> cst.Annotation:
    """
    Return ``annotation`` with a bare name turned into a quoted string when
    that name is a global name of the module but not yet a visited class;
    any other annotation is returned unchanged.
    """
    # TODO: We probably want to make sure references to classes defined in the current
    # module come to us fully qualified - so we can do the dequalification here and
    # know to look for what is in-scope without also catching builtins like "None" in the
    # quoting. This should probably also be extended to handle what imports are in scope,
    # as well as subscriptable types.
    # Note: We are collecting all imports and passing this to the type collector grabbing
    # annotations from the stub file; should consolidate import handling somewhere too.
    node = annotation.annotation
    if not isinstance(node, cst.Name):
        return annotation
    referenced = node.value
    if referenced not in self.global_names or referenced in self.visited_classes:
        return annotation
    return annotation.with_changes(
        annotation=cst.SimpleString(value=f'"{referenced}"')
    )
891 # smart constructors: all applied annotations happen via one of these
def _apply_annotation_to_attribute_or_global(
    self,
    name: str,
    annotation: cst.Annotation,
    value: Optional[cst.BaseExpression],
) -> cst.AnnAssign:
    """
    Build the annotated assignment ``name: annotation = value`` and count it.

    With an empty qualifier stack it counts as a global annotation,
    otherwise as an attribute annotation.
    """
    counts = self.annotation_counts
    if self.qualifier:
        counts.attribute_annotations += 1
    else:
        counts.global_annotations += 1
    target = cst.Name(name)
    quoted = self._quote_future_annotations(annotation)
    return cst.AnnAssign(target, quoted, value)
909 def _apply_annotation_to_parameter(
910 self,
911 parameter: cst.Param,
912 annotation: cst.Annotation,
913 ) -> cst.Param:
914 self.annotation_counts.parameter_annotations += 1
915 return parameter.with_changes(
916 annotation=self._quote_future_annotations(annotation),
917 )
919 def _apply_annotation_to_return(
920 self,
921 function_def: cst.FunctionDef,
922 annotation: cst.Annotation,
923 ) -> cst.FunctionDef:
924 self.annotation_counts.return_annotations += 1
925 return function_def.with_changes(
926 returns=self._quote_future_annotations(annotation),
927 )
929 # private methods used in the visit and leave methods
931 def _qualifier_name(self) -> str:
932 return ".".join(self.qualifier)
def _annotate_single_target(
    self,
    node: cst.Assign,
    updated_node: cst.Assign,
) -> Union[cst.Assign, cst.AnnAssign]:
    """
    Annotate an assignment that has a single target, when possible.

    - Tuple / list targets: each named element other than ``_`` is passed to
      ``_add_to_toplevel_annotations`` and ``updated_node`` is returned.
    - Attribute / subscript targets are never annotated.
    - Any other target is replaced by an annotated assignment when an
      annotation is known for its qualified name and that name has not been
      annotated before; otherwise ``updated_node`` is returned.

    The qualifier stack is the same on return as on entry for every path.
    """
    only_target = node.targets[0].target
    if isinstance(only_target, (cst.Tuple, cst.List)):
        for element in only_target.elements:
            name = get_full_name_for_node(element.value)
            if name is not None and name != "_":
                self._add_to_toplevel_annotations(name)
        return updated_node
    if isinstance(only_target, (cst.Attribute, cst.Subscript)):
        return updated_node

    name = get_full_name_for_node(only_target)
    if name is None:
        return updated_node

    # Push only long enough to compute the qualified name, then pop at once
    # so no branch below can leave the stack unbalanced.
    self.qualifier.append(name)
    qualifier_name = self._qualifier_name()
    self.qualifier.pop()

    if (
        qualifier_name not in self.annotations.attributes
        or qualifier_name in self.already_annotated
    ):
        return updated_node

    self.already_annotated.add(qualifier_name)
    return self._apply_annotation_to_attribute_or_global(
        name=name,
        annotation=self.annotations.attributes[qualifier_name],
        value=node.value,
    )
def _split_module(
    self,
    module: cst.Module,
    updated_module: cst.Module,
) -> Tuple[
    List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]],
    List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]],
]:
    """Split the updated module body just after the last recorded import.

    Returns ``(statements_up_to_last_import, remaining_statements)``. When
    no recorded import is found, the first list is empty.
    """
    # This works under the principle that while we might modify node contents,
    # we have yet to modify the number of statements. So we can match on the
    # original tree but break up the statements of the modified tree. If we
    # change this assumption in this visitor, we will have to change this code.
    split_at = 0
    for index, statement in enumerate(module.body):
        if not isinstance(statement, cst.SimpleStatementLine):
            continue
        for candidate in statement.body:
            # Identity, not equality: we are looking for the very nodes that
            # were recorded while visiting the original tree.
            if any(candidate is recorded for recorded in self.import_statements):
                split_at = index + 1

    updated_body = updated_module.body
    return (
        list(updated_body[:split_at]),
        list(updated_body[split_at:]),
    )
995 def _add_to_toplevel_annotations(
996 self,
997 name: str,
998 ) -> None:
999 self.qualifier.append(name)
1000 if self._qualifier_name() in self.annotations.attributes:
1001 annotation = self.annotations.attributes[self._qualifier_name()]
1002 self.toplevel_annotations[name] = annotation
1003 self.qualifier.pop()
def _update_parameters(
    self,
    annotations: FunctionAnnotation,
    updated_node: cst.FunctionDef,
) -> cst.Parameters:
    """Return ``updated_node``'s parameters with the given annotations applied.

    Regular, keyword-only and positional-only parameters are updated.
    Existing annotations are kept unless ``overwrite_existing_annotations``
    is set; only the ``annotation`` field of a parameter is ever changed,
    so default values are preserved.
    """

    def update_annotation(
        parameters: Sequence[cst.Param],
        annotations: Sequence[cst.Param],
        positional: bool,
    ) -> List[cst.Param]:
        # Positional parameters are paired by index unless strict matching
        # is requested; everything else is paired by name.
        match_by_index = positional and not self.strict_posargs_matching

        def key_for(index: int, param: cst.Param) -> Union[int, str]:
            return index if match_by_index else param.name.value

        available = {
            key_for(index, source_param): source_param.annotation.with_changes(
                whitespace_before_indicator=cst.SimpleWhitespace(value="")
            )
            for index, source_param in enumerate(annotations)
            if source_param.annotation
        }

        result = []
        for index, param in enumerate(parameters):
            key = key_for(index, param)
            if key in available and (
                self.overwrite_existing_annotations or not param.annotation
            ):
                param = self._apply_annotation_to_parameter(
                    parameter=param,
                    annotation=available[key],
                )
            result.append(param)
        return result

    current = updated_node.params
    wanted = annotations.parameters
    new_params = update_annotation(current.params, wanted.params, positional=True)
    new_kwonly_params = update_annotation(
        current.kwonly_params, wanted.kwonly_params, positional=False
    )
    new_posonly_params = update_annotation(
        current.posonly_params, wanted.posonly_params, positional=True
    )
    return current.with_changes(
        params=new_params,
        kwonly_params=new_kwonly_params,
        posonly_params=new_posonly_params,
    )
def _insert_empty_line(
    self,
    statements: List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]],
) -> List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]]:
    """Make sure the first statement is preceded by an empty line.

    The input list is returned as-is when it is empty or when the first
    statement's first leading line carries no comment. Otherwise a new list
    is returned in which the first statement has an extra ``EmptyLine``
    prepended to its leading lines.
    """
    if not statements:
        # No statements, nothing to add to
        return statements

    first, rest = statements[0], statements[1:]
    leading = first.leading_lines
    if leading and leading[0].comment is None:
        # First line is empty, so it is safe to leave as-is
        return statements

    # Either there are no leading lines at all, or the first one is a
    # comment: in both cases put one empty line in front of what is there.
    padded = first.with_changes(leading_lines=(cst.EmptyLine(), *leading))
    return [padded, *rest]
def _match_signatures(  # noqa: C901: Too complex
    self,
    function: cst.FunctionDef,
    annotations: FunctionAnnotation,
) -> bool:
    """Check that function annotations on both signatures are compatible.

    The parameter lists must have the same shape (same number of positional
    and positional-only parameters, same keyword-only names, same presence
    of ``*args`` / ``**kwargs``). Annotations are only compared when both
    sides have one, overwriting is disabled and strict annotation matching
    is enabled.
    """

    def compatible(
        p: Optional[cst.Annotation],
        q: Optional[cst.Annotation],
    ) -> bool:
        must_compare = (
            not self.overwrite_existing_annotations
            and _is_non_sentinel(p)
            and _is_non_sentinel(q)
            # Without strict matching we will not overwrite clashing
            # annotations, but the signature as a whole is still marked
            # compatible so that holes can be filled in.
            and self.strict_annotation_matching
        )
        if not must_compare:
            return True
        return p.annotation.deep_equals(q.annotation)  # pyre-ignore[16]

    def match_posargs(
        ps: Sequence[cst.Param],
        qs: Sequence[cst.Param],
    ) -> bool:
        if len(ps) != len(qs):
            return False
        return all(
            not (self.strict_posargs_matching and p.name.value != q.name.value)
            and compatible(p.annotation, q.annotation)
            for p, q in zip(ps, qs)
        )

    def match_kwargs(
        ps: Sequence[cst.Param],
        qs: Sequence[cst.Param],
    ) -> bool:
        ps_by_name = {x.name.value: x for x in ps}
        qs_by_name = {x.name.value: x for x in qs}
        if ps_by_name.keys() != qs_by_name.keys():
            return False
        return all(
            compatible(param.annotation, qs_by_name[name].annotation)
            for name, param in ps_by_name.items()
        )

    def match_star(
        p: StarParamType,
        q: StarParamType,
    ) -> bool:
        return _is_non_sentinel(p) == _is_non_sentinel(q)

    own, wanted = function.params, annotations.parameters
    return (
        match_posargs(own.params, wanted.params)
        and match_posargs(own.posonly_params, wanted.posonly_params)
        and match_kwargs(own.kwonly_params, wanted.kwonly_params)
        and match_star(own.star_arg, wanted.star_arg)
        and match_star(own.star_kwarg, wanted.star_kwarg)
        and compatible(function.returns, annotations.returns)
    )
1159 # transform API methods
def visit_ClassDef(
    self,
    node: cst.ClassDef,
) -> None:
    """Push the class name onto the qualifier stack when entering a class."""
    class_name = node.name.value
    self.qualifier.append(class_name)
def leave_ClassDef(
    self,
    original_node: cst.ClassDef,
    updated_node: cst.ClassDef,
) -> cst.ClassDef:
    """Pop the class from the qualifier stack and add a generic base if needed.

    The class is recorded in ``visited_classes``. If the collected
    annotations hold a definition for this qualified class name that has a
    generic base while the visited class has none, that base is appended to
    the class's bases.
    """
    self.visited_classes.add(original_node.name.value)
    cls_name = self._qualifier_name()
    self.qualifier.pop()

    definition = self.annotations.class_definitions.get(cls_name)
    if not definition:
        return updated_node

    wanted_base = _find_generic_base(definition)
    existing_base = _find_generic_base(updated_node)
    if not wanted_base or existing_base:
        return updated_node

    self.annotation_counts.typevars_and_generics_added += 1
    return updated_node.with_changes(bases=[*updated_node.bases, wanted_base])
def visit_FunctionDef(
    self,
    node: cst.FunctionDef,
) -> bool:
    """Push the function name onto the qualifier stack; do not descend."""
    function_name = node.name.value
    self.qualifier.append(function_name)
    # pyi files don't support inner functions, so stop the traversal here.
    visit_children = False
    return visit_children
def leave_FunctionDef(
    self,
    original_node: cst.FunctionDef,
    updated_node: cst.FunctionDef,
) -> cst.FunctionDef:
    """Pop the function from the qualifier stack and annotate it if possible.

    The function is looked up by a ``FunctionKey`` built from its qualified
    name and parameters. New annotations are only added when the signatures
    match; an existing return annotation is replaced only when
    ``overwrite_existing_annotations`` is set.
    """
    key = FunctionKey.make(self._qualifier_name(), updated_node.params)
    self.qualifier.pop()

    known_functions = self.annotations.functions
    if key not in known_functions:
        return updated_node
    function_annotation = known_functions[key]

    if not self._match_signatures(updated_node, function_annotation):
        return updated_node

    wanted_returns = function_annotation.returns
    may_set_return = (
        self.overwrite_existing_annotations or updated_node.returns is None
    )
    if may_set_return and wanted_returns is not None:
        updated_node = self._apply_annotation_to_return(
            function_def=updated_node,
            annotation=wanted_returns,
        )

    # Don't override default values when annotating functions
    new_parameters = self._update_parameters(function_annotation, updated_node)
    return updated_node.with_changes(params=new_parameters)
def visit_Assign(
    self,
    node: cst.Assign,
) -> None:
    """Remember the assignment being visited so nested visitors can see it."""
    entered_assignment = node
    self.current_assign = entered_assignment
@m.call_if_inside(m.Assign())
@m.visit(m.Call(func=m.Name("TypeVar")))
def record_typevar(
    self,
    node: cst.Call,
) -> None:
    """Record the enclosing assignment of a ``TypeVar(...)`` call.

    The assignment is stored in ``self.typevars`` under the full name of
    its first target, and ``self.current_assign`` is then cleared.
    """
    assign = self.current_assign
    if assign is None:
        # An earlier TypeVar(...) call in the same assignment has already
        # consumed (and cleared) the current assignment; nothing to record.
        return
    name = get_full_name_for_node(assign.targets[0].target)
    if name is not None:
        # Preserve the whole node, even though we currently just use the
        # name, so that we can match bounds and variance at some point and
        # determine if two typevars with the same name are indeed the same.
        self.typevars[name] = assign
    self.current_assign = None
def leave_Assign(
    self,
    original_node: cst.Assign,
    updated_node: cst.Assign,
) -> Union[cst.Assign, cst.AnnAssign]:
    """Annotate an assignment, or queue top-level annotations for it.

    A single-target assignment is delegated to ``_annotate_single_target``.
    A chained assignment is returned unchanged; each of its name/attribute
    targets (other than ``_``) is offered to
    ``_add_to_toplevel_annotations`` instead.
    """
    self.current_assign = None

    targets = original_node.targets
    if len(targets) <= 1:
        return self._annotate_single_target(original_node, updated_node)

    for assign_target in targets:
        target = assign_target.target
        if not isinstance(target, (cst.Name, cst.Attribute)):
            continue
        name = get_full_name_for_node(target)
        if name is None or name == "_":
            continue
        # Add separate top-level annotations for `a = b = 1`
        # as `a: int` and `b: int`.
        self._add_to_toplevel_annotations(name)
    return updated_node
def leave_ImportFrom(
    self,
    original_node: cst.ImportFrom,
    updated_node: cst.ImportFrom,
) -> cst.ImportFrom:
    """Record the original ``from ... import`` node; leave the tree unchanged."""
    recorded_imports = self.import_statements
    recorded_imports.append(original_node)
    return updated_node
def leave_Module(
    self,
    original_node: cst.Module,
    updated_node: cst.Module,
) -> cst.Module:
    """Insert queued top-level annotations, typevars and classes after the imports.

    The new statements are placed right after the last recorded import, in
    this order: annotated names from ``toplevel_annotations``, typevar
    statements not already recorded in this module, then class definitions
    whose names were never visited.
    """
    visited = self.visited_classes
    fresh_class_definitions = [
        definition
        for class_name, definition in self.annotations.class_definitions.items()
        if class_name not in visited
    ]

    # NOTE: The entire change will also be abandoned if
    # self.annotation_counts is all 0s, so if adding any new category make
    # sure to record it there.
    nothing_to_add = (
        not self.toplevel_annotations
        and not fresh_class_definitions
        and not self.annotations.typevars
    )
    if nothing_to_add:
        return updated_node

    # First, find the insertion point for imports
    before_imports, after_imports = self._split_module(original_node, updated_node)

    # Make sure there's at least one empty line before the first non-import
    after_imports = self._insert_empty_line(after_imports)

    toplevel_statements = [
        cst.SimpleStatementLine(
            [
                self._apply_annotation_to_attribute_or_global(
                    name=name,
                    annotation=annotation,
                    value=None,
                )
            ]
        )
        for name, annotation in self.toplevel_annotations.items()
    ]

    # TypeVar definitions could be scattered through the file, so do not
    # attempt to put new ones with existing ones, just add them at the top.
    missing_typevar_statements = [
        stmt
        for typevar_name, stmt in self.annotations.typevars.items()
        if typevar_name not in self.typevars
    ]
    for stmt in missing_typevar_statements:
        toplevel_statements.append(cst.Newline())
        toplevel_statements.append(stmt)
        self.annotation_counts.typevars_and_generics_added += 1
    if missing_typevar_statements:
        toplevel_statements.append(cst.Newline())

    self.annotation_counts.classes_added = len(fresh_class_definitions)
    toplevel_statements.extend(fresh_class_definitions)

    new_body = [*before_imports, *toplevel_statements, *after_imports]
    return updated_node.with_changes(body=new_body)