Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/libcst/codemod/visitors/_apply_type_annotations.py: 23%

566 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:43 +0000

1# Copyright (c) Meta Platforms, Inc. and affiliates. 

2# 

3# This source code is licensed under the MIT license found in the 

4# LICENSE file in the root directory of this source tree. 

5 

6from collections import defaultdict 

7from dataclasses import dataclass 

8from typing import Dict, List, Optional, Sequence, Set, Tuple, Union 

9 

10import libcst as cst 

11import libcst.matchers as m 

12 

13from libcst.codemod._context import CodemodContext 

14from libcst.codemod._visitor import ContextAwareTransformer 

15from libcst.codemod.visitors._add_imports import AddImportsVisitor 

16from libcst.codemod.visitors._gather_global_names import GatherGlobalNamesVisitor 

17from libcst.codemod.visitors._gather_imports import GatherImportsVisitor 

18from libcst.codemod.visitors._imports import ImportItem 

19from libcst.helpers import get_full_name_for_node 

20from libcst.metadata import PositionProvider, QualifiedNameProvider 

21 

22 

# A node that names something: either a bare ``Name`` or a dotted ``Attribute``.
NameOrAttribute = Union[cst.Name, cst.Attribute]
# Runtime counterpart of ``NameOrAttribute`` for use with ``isinstance``.
NAME_OR_ATTRIBUTE = (cst.Name, cst.Attribute)
# Union type for the ``*args`` / ``**kwargs`` slots of ``cst.Parameters``.
# Uses libcst's public exports instead of its private ``_maybe_sentinel`` and
# ``_nodes.expression`` modules (they are the very same classes).
StarParamType = Union[
    None,
    cst.MaybeSentinel,
    cst.Param,
    cst.ParamStar,
]

32 

33 

34def _module_and_target(qualified_name: str) -> Tuple[str, str]: 

35 relative_prefix = "" 

36 while qualified_name.startswith("."): 

37 relative_prefix += "." 

38 qualified_name = qualified_name[1:] 

39 split = qualified_name.rsplit(".", 1) 

40 if len(split) == 1: 

41 qualifier, target = "", split[0] 

42 else: 

43 qualifier, target = split 

44 return (relative_prefix + qualifier, target) 

45 

46 

def _get_unique_qualified_name(
    visitor: m.MatcherDecoratableVisitor, node: cst.CSTNode
) -> str:
    """
    Resolve ``node`` to exactly one qualified name using the visitor's
    ``QualifiedNameProvider`` metadata.

    If the provider knows no names for the node, fall back to the node's
    dotted spelling. Raises ``ValueError`` when no name, or more than one
    candidate name, is found.
    """
    candidates = [
        qualified.name
        for qualified in visitor.get_metadata(QualifiedNameProvider, node)
    ]
    resolved: Optional[str] = None
    if not candidates:
        # Stubs sometimes spell out a fully qualified name directly; that is
        # not technically valid Python but is convenient to accept.
        resolved = get_full_name_for_node(node)
    elif len(candidates) == 1 and isinstance(candidates[0], str):
        resolved = candidates[0]

    if resolved is None:
        position = visitor.get_metadata(PositionProvider, node).start
        raise ValueError(
            "Could not resolve a unique qualified name for type "
            + f"{get_full_name_for_node(node)} at {position.line}:{position.column}. "
            + f"Candidate names were: {candidates!r}"
        )
    return resolved

67 

68 

def _get_import_alias_names(
    import_aliases: Sequence[cst.ImportAlias],
) -> Set[str]:
    """
    Return the local names bound by a sequence of import aliases: the ``as``
    name when one is given, otherwise the (possibly dotted) imported name.
    """
    return {
        get_full_name_for_node(
            alias.name if alias.asname is None else alias.asname.name
        )
        for alias in import_aliases
    }

80 

81 

def _get_imported_names(
    imports: Sequence[Union[cst.Import, cst.ImportFrom]],
) -> Set[str]:
    """
    Given a series of import statements (both Import and ImportFrom),
    determine all of the names that have been imported into the current
    scope. For example:

    - ``import foo.bar as bar, foo.baz`` produces ``{'bar', 'foo.baz'}``
    - ``from foo import (Bar, Baz as B)`` produces ``{'Bar', 'B'}``
    - ``from foo import *`` produces ``set()`` because we cannot resolve names
    """
    import_names: Set[str] = set()
    for _import in imports:
        names = _import.names
        # Only ``from x import *`` has an ImportStar here; its names are unknown.
        if isinstance(names, cst.ImportStar):
            continue
        import_names.update(_get_import_alias_names(names))
    return import_names

102 

103 

def _is_non_sentinel(
    x: Union[None, cst.CSTNode, cst.MaybeSentinel],
) -> bool:
    """
    Return True when ``x`` holds an actual node, i.e. it is neither ``None``
    nor the ``MaybeSentinel.DEFAULT`` placeholder.
    """
    # MaybeSentinel.DEFAULT is an enum singleton, so compare by identity.
    return x is not None and x is not cst.MaybeSentinel.DEFAULT

108 

109 

110def _get_string_value( 

111 node: cst.SimpleString, 

112) -> str: 

113 s = node.value 

114 c = s[-1] 

115 return s[s.index(c) : -1] 

116 

117 

def _find_generic_base(
    node: cst.ClassDef,
) -> Optional[cst.Arg]:
    """Return the first ``Generic[...]`` base of ``node``, or ``None`` if it has none."""
    generic_pattern = m.Subscript(value=m.Name("Generic"))
    return next(
        (base for base in node.bases if m.matches(base.value, generic_pattern)),
        None,
    )

124 

125 

@dataclass(frozen=True)
class FunctionKey:
    """
    Class representing a function name and signature.

    This exists to ensure we do not attempt to apply stubs to functions whose
    definition is incompatible.
    """

    # Dotted name of the function (a method includes its class names).
    name: str
    # Number of entries in ``Parameters.params``.
    pos: int
    # Comma-joined, sorted names of the keyword-only parameters.
    kwonly: str
    # Number of positional-only parameters.
    posonly: int
    # Whether ``Parameters.star_arg`` holds a node (``*args`` or a bare ``*``).
    star_arg: bool
    # Whether ``Parameters.star_kwarg`` holds a node (``**kwargs``).
    star_kwarg: bool

    @classmethod
    def make(
        cls,
        name: str,
        params: cst.Parameters,
    ) -> "FunctionKey":
        """Build the signature key for function ``name`` with parameters ``params``."""
        pos = len(params.params)
        # Sorting makes the key independent of keyword-only parameter order.
        kwonly = ",".join(sorted(x.name.value for x in params.kwonly_params))
        posonly = len(params.posonly_params)
        star_arg = _is_non_sentinel(params.star_arg)
        star_kwarg = _is_non_sentinel(params.star_kwarg)
        return cls(
            name,
            pos,
            kwonly,
            posonly,
            star_arg,
            star_kwarg,
        )

161 

162 

@dataclass(frozen=True)
class FunctionAnnotation:
    """Annotations collected from a stub for a single function or method."""

    # The stub's parameter list, carrying the per-parameter annotations.
    parameters: cst.Parameters
    # The stub's return annotation, or None when the stub has none.
    returns: Optional[cst.Annotation]

167 

168 

@dataclass
class Annotations:
    """
    Represents all of the annotation information we might add to
    a module:

    - All data is keyed on the qualified name relative to the module root
    - The ``functions`` field also keys on the signature so that we
      do not apply stub types where the signature is incompatible.

    The idea is that

    - ``functions`` contains all function and method type
      information from the stub, and the qualifier for a method includes
      the containing class names (e.g. "Cat.meow")
    - ``attributes`` similarly contains all globals
      and class-level attribute type information.
    - The ``class_definitions`` field contains all of the classes
      defined in the stub. Most of these classes will be ignored in
      downstream logic (it is *not* used to annotate attributes or
      methods), but there are some cases like TypedDict where a
      typing-only class needs to be injected.
    - The field ``typevars`` contains the assign statement for all
      type variables in the stub, and ``names`` tracks
      all of the names used in annotations; together these fields
      tell us which typevars should be included in the codemod
      (all typevars that appear in annotations.)
    """

    # TODO: consider simplifying this in a few ways:
    # - We could probably just inject all typevars, used or not.
    #   It doesn't seem to me that our codemod needs to act like
    #   a linter checking for unused names.
    # - We could probably decide which classes are typing-only
    #   in the visitor rather than the codemod, which would make
    #   it easier to reason locally about (and document) how the
    #   class_definitions field works.

    functions: Dict[FunctionKey, FunctionAnnotation]
    attributes: Dict[str, cst.Annotation]
    class_definitions: Dict[str, cst.ClassDef]
    typevars: Dict[str, cst.Assign]
    names: Set[str]

    @classmethod
    def empty(cls) -> "Annotations":
        """Return a new instance (of ``cls``) with every collection empty."""
        return cls({}, {}, {}, {}, set())

    def update(self, other: "Annotations") -> None:
        """Merge ``other`` into ``self``; entries from ``other`` win on key collisions."""
        self.functions.update(other.functions)
        self.attributes.update(other.attributes)
        self.class_definitions.update(other.class_definitions)
        self.typevars.update(other.typevars)
        self.names.update(other.names)

    def finish(self) -> None:
        """Drop every typevar whose name is not recorded in ``names``."""
        self.typevars = {k: v for k, v in self.typevars.items() if k in self.names}

224 

225 

@dataclass(frozen=True)
class ImportedSymbol:
    """Import of foo.Bar, where both foo and Bar are potentially aliases."""

    # Module the symbol lives in, e.g. "foo".
    module_name: str
    # Local alias of the module, when it differs from ``module_name``.
    module_alias: Optional[str] = None
    # Name of the symbol inside the module, e.g. "Bar".
    target_name: Optional[str] = None
    # Local alias of the symbol, when it differs from ``target_name``.
    target_alias: Optional[str] = None

    @property
    def symbol(self) -> Optional[str]:
        """The local spelling of the symbol: its alias when set, else its name."""
        if self.target_alias:
            return self.target_alias
        return self.target_name

    @property
    def module_symbol(self) -> str:
        """The local spelling of the module: its alias when set, else its name."""
        if self.module_alias:
            return self.module_alias
        return self.module_name

242 

243 

class ImportedSymbolCollector(m.MatcherDecoratableVisitor):
    """
    Collect imported symbols from a stub module.

    Names and attributes that appear inside annotations, or as class bases,
    are resolved to their qualified names. Symbols from a non-builtin module
    whose qualified name is not in ``existing_imports`` are recorded in
    ``imported_symbols``, keyed by the name as spelled in the stub.
    """

    METADATA_DEPENDENCIES = (
        PositionProvider,
        QualifiedNameProvider,
    )

    def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None:
        super().__init__()
        # Names already imported by the target module.
        self.existing_imports: Set[str] = existing_imports
        # Stub spelling of a symbol -> every distinct import it resolved to.
        self.imported_symbols: Dict[str, Set[ImportedSymbol]] = defaultdict(set)
        # True while traversing the inside of an ``Annotation`` node.
        self.in_annotation: bool = False
        # NOTE(review): ``context`` is accepted but not used by this visitor.

    def visit_Annotation(self, node: cst.Annotation) -> None:
        self.in_annotation = True

    def leave_Annotation(self, original_node: cst.Annotation) -> None:
        self.in_annotation = False

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # Base classes are not wrapped in an Annotation, so handle them here.
        for base in node.bases:
            value = base.value
            if isinstance(value, NAME_OR_ATTRIBUTE):
                self._handle_NameOrAttribute(value)

    def visit_Name(self, node: cst.Name) -> None:
        if self.in_annotation:
            self._handle_NameOrAttribute(node)

    def visit_Attribute(self, node: cst.Attribute) -> None:
        if self.in_annotation:
            self._handle_NameOrAttribute(node)

    def visit_Subscript(self, node: cst.Subscript) -> bool:
        if isinstance(node.value, NAME_OR_ATTRIBUTE):
            return True
        return _get_unique_qualified_name(self, node) not in ("Type", "typing.Type")

    def _handle_NameOrAttribute(
        self,
        node: NameOrAttribute,
    ) -> None:
        """Record the import that ``node`` refers to, unless it needs none."""
        obj: Optional[str] = None
        sym: Optional[str] = None
        if isinstance(node, cst.Name):
            sym = node.value
        elif isinstance(node, cst.Attribute):
            # Take the full dotted prefix ("a.b" for ``a.b.C``). Reading
            # ``node.value.value`` yields a CST node, not a string, whenever
            # the prefix is itself an Attribute.
            obj = get_full_name_for_node(node.value)
            sym = node.attr.value
        qualified_name = _get_unique_qualified_name(self, node)
        module, target = _module_and_target(qualified_name)
        if module in ("", "builtins"):
            return
        elif qualified_name not in self.existing_imports:
            mod = ImportedSymbol(
                module_name=module,
                module_alias=obj if obj != module else None,
                target_name=target,
                target_alias=sym if sym != target else None,
            )
            self.imported_symbols[sym].add(mod)

309 

310 

class TypeCollector(m.MatcherDecoratableVisitor):
    """
    Collect type annotations from a stub module.

    Results accumulate in ``self.annotations``: function signatures,
    annotated globals and class attributes, class definitions, and
    ``TypeVar`` assignments. Annotation expressions are rewritten with
    ``_TypeCollectorDequalifier``. Imports they need are registered with
    ``AddImportsVisitor`` on the codemod context.
    """

    METADATA_DEPENDENCIES = (
        PositionProvider,
        QualifiedNameProvider,
    )

    annotations: Annotations

    def __init__(
        self,
        existing_imports: Set[str],
        module_imports: Dict[str, ImportItem],
        context: CodemodContext,
    ) -> None:
        super().__init__()
        self.context = context
        # Existing imports, determined by looking at the target module.
        # Used to help us determine when a type in a stub will require new imports.
        #
        # The contents of this are fully-qualified names of types in scope
        # as well as module names, although downstream we effectively ignore
        # the module names as of the current implementation.
        self.existing_imports: Set[str] = existing_imports
        # Module imports, gathered by prescanning the stub file to determine
        # which modules need to be imported directly to qualify their symbols.
        self.module_imports: Dict[str, ImportItem] = module_imports
        # Fields that help us track temporary state as we recurse
        self.qualifier: List[str] = []
        self.current_assign: Optional[cst.Assign] = None  # used to collect typevars
        # Store the annotations.
        self.annotations = Annotations.empty()

    def visit_ClassDef(
        self,
        node: cst.ClassDef,
    ) -> None:
        """Enter the class scope and record the class with rewritten bases."""
        self.qualifier.append(node.name.value)
        new_bases = []
        for base in node.bases:
            value = base.value
            if not isinstance(value, (*NAME_OR_ATTRIBUTE, cst.Subscript)):
                start = self.get_metadata(PositionProvider, node).start
                raise ValueError(
                    "Invalid type used as base class in stub file at "
                    + f"{start.line}:{start.column}. Only subscripts, names, and "
                    + "attributes are valid base classes for static typing."
                )
            new_value = value.visit(_TypeCollectorDequalifier(self))
            new_bases.append(base.with_changes(value=new_value))

        self.annotations.class_definitions[node.name.value] = node.with_changes(
            bases=new_bases
        )

    def leave_ClassDef(
        self,
        original_node: cst.ClassDef,
    ) -> None:
        self.qualifier.pop()

    def visit_FunctionDef(
        self,
        node: cst.FunctionDef,
    ) -> bool:
        """Record the parameter and return annotations of a function."""
        self.qualifier.append(node.name.value)
        returns = node.returns
        return_annotation = (
            returns.visit(_TypeCollectorDequalifier(self))
            if returns is not None
            else None
        )
        assert return_annotation is None or isinstance(
            return_annotation, cst.Annotation
        )
        parameter_annotations = self._handle_Parameters(node.params)
        name = ".".join(self.qualifier)
        key = FunctionKey.make(name, node.params)
        self.annotations.functions[key] = FunctionAnnotation(
            parameters=parameter_annotations, returns=return_annotation
        )

        # pyi files don't support inner functions, return False to stop the traversal.
        return False

    def leave_FunctionDef(
        self,
        original_node: cst.FunctionDef,
    ) -> None:
        self.qualifier.pop()

    def visit_AnnAssign(
        self,
        node: cst.AnnAssign,
    ) -> bool:
        """Record the annotation of a named global or class attribute."""
        name = get_full_name_for_node(node.target)
        if name is not None:
            self.qualifier.append(name)
            annotation_value = node.annotation.visit(_TypeCollectorDequalifier(self))
            assert isinstance(annotation_value, cst.Annotation)
            self.annotations.attributes[".".join(self.qualifier)] = annotation_value
        return True

    def leave_AnnAssign(
        self,
        original_node: cst.AnnAssign,
    ) -> None:
        # Pop only what visit_AnnAssign pushed; popping unconditionally would
        # remove the enclosing class name when the target has no name.
        if get_full_name_for_node(original_node.target) is not None:
            self.qualifier.pop()

    def visit_Assign(
        self,
        node: cst.Assign,
    ) -> None:
        self.current_assign = node

    def leave_Assign(
        self,
        original_node: cst.Assign,
    ) -> None:
        self.current_assign = None

    @m.call_if_inside(m.Assign())
    @m.visit(m.Call(func=m.Name("TypeVar")))
    def record_typevar(
        self,
        node: cst.Call,
    ) -> None:
        """Remember ``X = TypeVar(...)`` assignments so they can be injected."""
        # pyre-ignore current_assign is never None here
        name = get_full_name_for_node(self.current_assign.targets[0].target)
        if name is not None:
            # pyre-ignore current_assign is never None here
            self.annotations.typevars[name] = self.current_assign
            self._handle_qualification_and_should_qualify("typing.TypeVar")
            self.current_assign = None

    def leave_Module(
        self,
        original_node: cst.Module,
    ) -> None:
        self.annotations.finish()

    def _module_and_target(
        self,
        qualified_name: str,
    ) -> Tuple[str, str]:
        """Split ``qualified_name`` into ``(module, target)``; see the module-level helper."""
        return _module_and_target(qualified_name)

    def _handle_qualification_and_should_qualify(
        self, qualified_name: str, node: Optional[cst.CSTNode] = None
    ) -> bool:
        """
        Based on a qualified name and the existing module imports, record that
        we need to add an import if necessary and return whether or not we
        should use the qualified name due to a preexisting import.
        """
        module, target = self._module_and_target(qualified_name)
        if module in ("", "builtins"):
            return False
        elif qualified_name not in self.existing_imports:
            if module in self.existing_imports:
                return True
            elif module in self.module_imports:
                import_item = self.module_imports[module]
                if import_item.obj_name is None:
                    asname = import_item.alias
                else:
                    asname = None
                AddImportsVisitor.add_needed_import(
                    self.context, import_item.module_name, asname=asname
                )
                return True
            else:
                if node and isinstance(node, cst.Name) and node.value != target:
                    asname = node.value
                else:
                    asname = None
                AddImportsVisitor.add_needed_import(
                    self.context,
                    module,
                    target,
                    asname=asname,
                )
                return False
        return False

    # Handler functions

    def _handle_Parameters(
        self,
        parameters: cst.Parameters,
    ) -> cst.Parameters:
        """
        Return ``parameters`` with every annotation rewritten by the
        dequalifier: positional-only, regular, ``*args``, keyword-only and
        ``**kwargs`` parameters alike.
        """

        def update_annotation(parameter: cst.Param) -> cst.Param:
            annotation = parameter.annotation
            if annotation is None:
                return parameter
            return parameter.with_changes(
                annotation=annotation.visit(_TypeCollectorDequalifier(self))
            )

        def update_annotations(
            params: Sequence[cst.Param],
        ) -> List[cst.Param]:
            return [update_annotation(parameter) for parameter in params]

        # Visit in source order so imports are registered in that order.
        posonly_params = update_annotations(parameters.posonly_params)
        params = update_annotations(parameters.params)
        star_arg = parameters.star_arg
        if isinstance(star_arg, cst.Param):
            star_arg = update_annotation(star_arg)
        kwonly_params = update_annotations(parameters.kwonly_params)
        star_kwarg = parameters.star_kwarg
        if star_kwarg is not None:
            star_kwarg = update_annotation(star_kwarg)

        return parameters.with_changes(
            posonly_params=posonly_params,
            params=params,
            star_arg=star_arg,
            kwonly_params=kwonly_params,
            star_kwarg=star_kwarg,
        )

531 

532 

class _TypeCollectorDequalifier(cst.CSTTransformer):
    """
    Rewrite a stub annotation expression for use in the target module.

    Every name and attribute is resolved to its qualified name, which is
    recorded in the collector's ``annotations.names``. The collector then
    decides whether to keep the qualified spelling (and registers any import
    that requires) or to reduce it to the bare symbol name.
    """

    def __init__(self, type_collector: "TypeCollector") -> None:
        super().__init__()
        self.type_collector = type_collector

    def _is_type_subscript(self, node: cst.Subscript) -> bool:
        """Whether ``node`` is ``Type[...]`` / ``typing.Type[...]``."""
        return _get_unique_qualified_name(self.type_collector, node) in (
            "Type",
            "typing.Type",
        )

    def leave_Name(
        self, original_node: cst.Name, updated_node: cst.Name
    ) -> cst.BaseExpression:
        qualified_name = _get_unique_qualified_name(self.type_collector, original_node)
        should_qualify = self.type_collector._handle_qualification_and_should_qualify(
            qualified_name, original_node
        )
        self.type_collector.annotations.names.add(qualified_name)
        if should_qualify:
            # Parse as an expression (e.g. an Attribute ``foo.Bar``); parsing
            # as a module would splice a Module node into the annotation tree.
            return cst.parse_expression(qualified_name)
        return original_node

    def visit_Attribute(self, node: cst.Attribute) -> bool:
        # The whole dotted name is handled in leave_Attribute.
        return False

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.BaseExpression:
        qualified_name = _get_unique_qualified_name(self.type_collector, original_node)
        should_qualify = self.type_collector._handle_qualification_and_should_qualify(
            qualified_name, original_node
        )
        self.type_collector.annotations.names.add(qualified_name)
        if should_qualify:
            return original_node
        return original_node.attr

    def leave_Index(
        self, original_node: cst.Index, updated_node: cst.Index
    ) -> cst.Index:
        # String forward references (e.g. ``List["T"]``) may name typevars.
        if isinstance(original_node.value, cst.SimpleString):
            self.type_collector.annotations.names.add(
                _get_string_value(original_node.value)
            )
        return updated_node

    def visit_Subscript(self, node: cst.Subscript) -> bool:
        return not self._is_type_subscript(node)

    def leave_Subscript(
        self, original_node: cst.Subscript, updated_node: cst.Subscript
    ) -> cst.Subscript:
        if self._is_type_subscript(original_node):
            # Note: we are intentionally not handling qualification of
            # anything inside `Type` because it's common to have nested
            # classes, which we cannot currently distinguish from classes
            # coming from other modules, appear here.
            return original_node.with_changes(value=original_node.value.visit(self))
        return updated_node

593 

594 

@dataclass
class AnnotationCounts:
    """Tally of each kind of change the annotation visitor applied."""

    global_annotations: int = 0
    attribute_annotations: int = 0
    parameter_annotations: int = 0
    return_annotations: int = 0
    classes_added: int = 0
    typevars_and_generics_added: int = 0

    def any_changes_applied(self) -> bool:
        """Return True when the counters add up to a positive total."""
        tallies = (
            self.global_annotations,
            self.attribute_annotations,
            self.parameter_annotations,
            self.return_annotations,
            self.classes_added,
            self.typevars_and_generics_added,
        )
        total = sum(tallies)
        return total > 0

613 

614 

615class ApplyTypeAnnotationsVisitor(ContextAwareTransformer): 

616 """ 

617 Apply type annotations to a source module using the given stub modules.

618 You can also pass in explicit annotations for functions and attributes and 

619 pass in new class definitions that need to be added to the source module. 

620 

621 This is one of the transforms that is available automatically to you when 

622 running a codemod. To use it in this manner, import 

623 :class:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor` and then call 

624 the static 

625 :meth:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor.store_stub_in_context` 

626 method, giving it the current context (found as ``self.context`` for all 

627 subclasses of :class:`~libcst.codemod.Codemod`), the stub module from which 

628 you wish to add annotations. 

629 

630 For example, you can store the type annotation ``int`` for ``x`` using:: 

631 

632 stub_module = parse_module("x: int = ...") 

633 

634 ApplyTypeAnnotationsVisitor.store_stub_in_context(self.context, stub_module) 

635 

636 You can apply the type annotation using:: 

637 

638 source_module = parse_module("x = 1") 

639 ApplyTypeAnnotationsVisitor(self.context).transform_module(source_module)

640 

641 This will produce the following code:: 

642 

643 x: int = 1 

644 

645 If the function or attribute already has a type annotation, it will not be 

646 overwritten. 

647 

648 To overwrite existing annotations when applying annotations from a stub, 

649 use the keyword argument ``overwrite_existing_annotations=True`` when 

650 constructing the codemod or when calling ``store_stub_in_context``. 

651 """ 

652 

653 CONTEXT_KEY = "ApplyTypeAnnotationsVisitor" 

654 

def __init__(
    self,
    context: CodemodContext,
    annotations: Optional[Annotations] = None,
    overwrite_existing_annotations: bool = False,
    use_future_annotations: bool = False,
    strict_posargs_matching: bool = True,
    strict_annotation_matching: bool = False,
    always_qualify_annotations: bool = False,
) -> None:
    """
    Set up the visitor with optional pre-collected ``annotations`` and the
    behavior flags. ``transform_module_impl`` may later combine these flags
    with the values stored via ``store_stub_in_context``.
    """
    super().__init__(context)

    # Names of the enclosing definitions, used to build canonical function names.
    self.qualifier: List[str] = []
    self.annotations: Annotations = (
        annotations if annotations is not None else Annotations.empty()
    )
    self.toplevel_annotations: Dict[str, cst.Annotation] = {}
    self.visited_classes: Set[str] = set()

    # Behavior flags.
    self.overwrite_existing_annotations = overwrite_existing_annotations
    self.use_future_annotations = use_future_annotations
    self.strict_posargs_matching = strict_posargs_matching
    self.strict_annotation_matching = strict_annotation_matching
    self.always_qualify_annotations = always_qualify_annotations

    # Import statements seen in the target; marks where the import block
    # ends so top-level annotations can be inserted after it.
    self.import_statements: List[cst.ImportFrom] = []

    # Counts of what was applied; also decides whether to abandon the codemod
    # when only imports would otherwise have changed.
    self.annotation_counts: AnnotationCounts = AnnotationCounts()

    # Typevars already in the target, so existing ones are not re-imported
    # from the pyi file.
    self.current_assign: Optional[cst.Assign] = None
    self.typevars: Dict[str, cst.Assign] = {}

    # Top-level globals and classes of the target module; used to decide
    # which names need quoting to avoid undefined forward references.
    self.global_names: Set[str] = set()

    # Symbols annotated in the current scope, so repeated assignments to the
    # same symbol are annotated only once.
    self.already_annotated: Set[str] = set()

700 

@staticmethod
def store_stub_in_context(
    context: CodemodContext,
    stub: cst.Module,
    overwrite_existing_annotations: bool = False,
    use_future_annotations: bool = False,
    strict_posargs_matching: bool = True,
    strict_annotation_matching: bool = False,
    always_qualify_annotations: bool = False,
) -> None:
    """
    Store a stub module in the :class:`~libcst.codemod.CodemodContext` so
    that type annotations from the stub can be applied in a later
    invocation of this class.

    If the ``overwrite_existing_annotations`` flag is ``True``, the
    codemod will overwrite any existing annotations.

    The stub and all flags are stored together as one tuple under
    ``CONTEXT_KEY``. If you call this function multiple times, only the
    values from the last call take effect: ``stub`` as well as every flag
    (``overwrite_existing_annotations``, ``use_future_annotations``,
    ``strict_posargs_matching``, ``strict_annotation_matching`` and
    ``always_qualify_annotations``).
    """
    context.scratch[ApplyTypeAnnotationsVisitor.CONTEXT_KEY] = (
        stub,
        overwrite_existing_annotations,
        use_future_annotations,
        strict_posargs_matching,
        strict_annotation_matching,
        always_qualify_annotations,
    )

730 

def transform_module_impl(
    self,
    tree: cst.Module,
) -> cst.Module:
    """
    Collect type annotations from the stored stub and apply them to ``tree``.

    Imports already in ``tree`` are gathered so no duplicate imports are
    added, and its global and class names are gathered so forward references
    get quoted. If no annotation is applied, ``tree`` is returned unchanged.
    """
    target_imports = GatherImportsVisitor(CodemodContext())
    tree.visit(target_imports)
    target_import_names = _get_imported_names(target_imports.all_imports)

    names_gatherer = GatherGlobalNamesVisitor(CodemodContext())
    tree.visit(names_gatherer)
    self.global_names = names_gatherer.global_names.union(
        names_gatherer.class_names
    )

    stored = self.context.scratch.get(ApplyTypeAnnotationsVisitor.CONTEXT_KEY)
    if stored is not None:
        (
            stub,
            overwrite_flag,
            future_flag,
            posargs_flag,
            annotation_match_flag,
            qualify_flag,
        ) = stored
        # Opt-in flags are OR-ed together; strict_posargs_matching defaults to
        # on, so either source can switch it off (AND).
        self.overwrite_existing_annotations = (
            self.overwrite_existing_annotations or overwrite_flag
        )
        self.use_future_annotations = self.use_future_annotations or future_flag
        self.strict_posargs_matching = self.strict_posargs_matching and posargs_flag
        self.strict_annotation_matching = (
            self.strict_annotation_matching or annotation_match_flag
        )
        self.always_qualify_annotations = (
            self.always_qualify_annotations or qualify_flag
        )
        module_imports = self._get_module_imports(stub, target_imports)
        collector = TypeCollector(target_import_names, module_imports, self.context)
        cst.MetadataWrapper(stub).visit(collector)
        self.annotations.update(collector.annotations)

    if self.use_future_annotations:
        AddImportsVisitor.add_needed_import(
            self.context, "__future__", "annotations"
        )
    tree_with_imports = AddImportsVisitor(self.context).transform_module(tree)
    tree_with_changes = tree_with_imports.visit(self)

    # Leave the imports untouched unless some type information was applied.
    return tree_with_changes if self.annotation_counts.any_changes_applied() else tree

797 

798 # helpers for collecting type information from the stub files 

799 

def _get_module_imports(  # noqa: C901: too complex
    self, stub: cst.Module, existing_import_gatherer: GatherImportsVisitor
) -> Dict[str, ImportItem]:
    """
    Returns a dict of modules that need to be imported to qualify symbols.

    Every imported symbol used in the stub (e.g. ``foo.bar.Baz``) is matched
    against the stub's module and from-imports. When one unqualified symbol
    is used from different modules, an explicit from-import is preferred
    where possible and everything else is qualified by importing its module.

    For example, the stub::

        import foo as quux
        from bar import Baz as X
        def f(x: X) -> quux.X: ...

    yields ``{'foo': ImportItem("foo", "quux")}``. When the visitor later hits
    ``quux.X`` it resolves the canonical name ``foo.X``, finds ``foo`` in this
    map, and leaves the symbol qualified.
    """
    stub_imports = GatherImportsVisitor(CodemodContext())
    stub.visit(stub_imports)
    stub_symbols = stub_imports.symbol_mapping
    target_import_names = _get_imported_names(existing_import_gatherer.all_imports)
    collector = ImportedSymbolCollector(target_import_names, self.context)
    cst.MetadataWrapper(stub).visit(collector)

    required: Dict[str, ImportItem] = {}

    def remember(item: Optional[ImportItem]) -> None:
        # Missing items are dropped silently. This happens in corner cases
        # such as ``import foo.bar as Baz`` followed by ``x: Baz``, which is
        # valid Python but nonsensical as a type annotation.
        if item:
            required[item.module_name] = item

    for sym, candidates in collector.imported_symbols.items():
        already_imported = existing_import_gatherer.symbol_mapping.get(sym)
        clashes_with_target = already_imported and any(
            c.module_name != already_imported.module_name for c in candidates
        )
        if clashes_with_target:
            # The target already imports this name from elsewhere, so stub
            # uses from other modules must be qualified.
            from_import_taken = True
        elif len(candidates) == 1 and not self.always_qualify_annotations:
            # A single use of a new symbol can simply be from-imported.
            continue
        else:
            # Several uses in the stub, none in the target: one of them may
            # still be from-imported.
            from_import_taken = False

        for candidate in candidates:
            if not candidate.symbol:
                continue
            direct = stub_symbols.get(candidate.symbol)
            if self.always_qualify_annotations and sym not in target_import_names:
                # 'Always qualify' is overridden for typing imports and for
                # symbols the target file explicitly from-imports.
                if direct and direct.module_name != "typing":
                    remember(direct)
                else:
                    remember(stub_symbols.get(candidate.module_symbol))
            elif (
                not from_import_taken
                and direct
                and direct.module_name == candidate.module_name
            ):
                # A symbol can be imported directly only once.
                from_import_taken = True
            elif sym in target_import_names:
                remember(direct)
            else:
                remember(stub_symbols.get(candidate.module_symbol))
    return required

870 

871 # helpers for processing annotation nodes 

def _quote_future_annotations(self, annotation: cst.Annotation) -> cst.Annotation:
    """
    Quote a bare-name annotation when that name is a global or class of the
    target module that is not in ``self.visited_classes``.

    The quoted form avoids an undefined forward reference. Any other
    annotation is returned unchanged.
    """
    # TODO: references to classes defined in the current module should
    # probably arrive fully qualified, so we could dequalify here and look up
    # what is in scope without also matching builtins like "None". This
    # should also cover imports in scope and subscriptable types. Import
    # handling is currently split between this visitor and the type
    # collector and should be consolidated.
    inner = annotation.annotation
    if not isinstance(inner, cst.Name):
        return annotation
    if inner.value not in self.global_names or inner.value in self.visited_classes:
        return annotation
    return annotation.with_changes(
        annotation=cst.SimpleString(value=f'"{inner.value}"')
    )

890 

891 # smart constructors: all applied annotations happen via one of these 

892 

def _apply_annotation_to_attribute_or_global(
    self,
    name: str,
    annotation: cst.Annotation,
    value: Optional[cst.BaseExpression],
) -> cst.AnnAssign:
    """Build ``name: annotation = value`` and bump the matching counter.

    An empty ``self.qualifier`` counts as a global annotation. A non-empty
    one (e.g. inside a class body) counts as an attribute annotation. The
    annotation goes through ``_quote_future_annotations`` before use.
    """
    counts = self.annotation_counts
    if self.qualifier:
        counts.attribute_annotations += 1
    else:
        counts.global_annotations += 1
    return cst.AnnAssign(
        target=cst.Name(name),
        annotation=self._quote_future_annotations(annotation),
        value=value,
    )

908 

def _apply_annotation_to_parameter(
    self,
    parameter: cst.Param,
    annotation: cst.Annotation,
) -> cst.Param:
    """Return a copy of ``parameter`` annotated with ``annotation``.

    The annotation is passed through ``_quote_future_annotations``, and the
    parameter-annotation counter is incremented.
    """
    self.annotation_counts.parameter_annotations += 1
    quoted = self._quote_future_annotations(annotation)
    return parameter.with_changes(annotation=quoted)

918 

def _apply_annotation_to_return(
    self,
    function_def: cst.FunctionDef,
    annotation: cst.Annotation,
) -> cst.FunctionDef:
    """Return a copy of ``function_def`` whose return annotation is ``annotation``.

    The annotation is passed through ``_quote_future_annotations``, and the
    return-annotation counter is incremented.
    """
    self.annotation_counts.return_annotations += 1
    quoted = self._quote_future_annotations(annotation)
    return function_def.with_changes(returns=quoted)

928 

929 # private methods used in the visit and leave methods 

930 

def _qualifier_name(self) -> str:
    """Return the current scope path (``self.qualifier``) joined with dots, e.g. ``"Outer.method"``."""
    separator = "."
    return separator.join(self.qualifier)

933 

def _annotate_single_target(
    self,
    node: cst.Assign,
    updated_node: cst.Assign,
) -> Union[cst.Assign, cst.AnnAssign]:
    """Annotate an assignment that has exactly one target.

    * Tuple/list targets (``a, b = ...``): each element whose full name can be
      determined (and is not ``_``) is passed to
      ``_add_to_toplevel_annotations``. The assignment is returned unchanged.
    * Subscript targets (``x[i] = ...``) are never annotated.
    * Any other target is looked up under the current qualifier. The
      assignment is replaced by an ``AnnAssign`` keeping the original value
      only if all of the following hold:
        - the stub has an attribute annotation for that qualified name,
        - the target is not an attribute access,
        - the name has not been annotated already.

    ``self.qualifier`` is always left exactly as it was on entry.
    """
    only_target = node.targets[0].target

    if isinstance(only_target, (cst.Tuple, cst.List)):
        for element in only_target.elements:
            name = get_full_name_for_node(element.value)
            if name is not None and name != "_":
                self._add_to_toplevel_annotations(name)
        return updated_node

    if isinstance(only_target, cst.Subscript):
        return updated_node

    name = get_full_name_for_node(only_target)
    if name is None:
        return updated_node

    self.qualifier.append(name)
    qualifier_name = self._qualifier_name()
    # Pop unconditionally. Previously the name stayed pushed when the target
    # had already been annotated (e.g. a second `x = ...`). Every later
    # qualified lookup (function keys, class names, attributes) was then
    # built on a wrong scope path.
    self.qualifier.pop()

    if (
        qualifier_name in self.annotations.attributes
        and not isinstance(only_target, cst.Attribute)
        and qualifier_name not in self.already_annotated
    ):
        self.already_annotated.add(qualifier_name)
        return self._apply_annotation_to_attribute_or_global(
            name=name,
            annotation=self.annotations.attributes[qualifier_name],
            value=node.value,
        )
    return updated_node

968 

def _split_module(
    self,
    module: cst.Module,
    updated_module: cst.Module,
) -> Tuple[
    List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]],
    List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]],
]:
    """Split ``updated_module.body`` at the point where new top-level code goes.

    The split point is just after the last top-level ``SimpleStatementLine``
    of the ORIGINAL module that contains (by identity) one of the nodes in
    ``self.import_statements``. It is 0 if there is none. Returns the
    statements before and after that point, taken from the UPDATED module.

    NOTE(review): in this part of the file only ``leave_ImportFrom`` appends
    to ``self.import_statements``, so plain ``import x`` lines presumably do
    not move the split point. Confirm that this is intended.
    """
    # The original tree is used for matching and the updated tree for slicing.
    # This relies on this visitor never changing the number of top-level
    # statements. If that assumption ever changes, this code must change too.
    split_at = 0
    for index, statement in enumerate(module.body):
        if not isinstance(statement, cst.SimpleStatementLine):
            continue
        if any(
            small_statement is recorded
            for small_statement in statement.body
            for recorded in self.import_statements
        ):
            split_at = index + 1

    body = updated_module.body
    return (list(body[:split_at]), list(body[split_at:]))

994 

def _add_to_toplevel_annotations(
    self,
    name: str,
) -> None:
    """Queue a module-level annotation for ``name`` if the stub has one.

    ``name`` is qualified with the current scope path. If that qualified name
    exists in ``self.annotations.attributes``, its annotation is stored in
    ``self.toplevel_annotations[name]`` under the UNqualified name.
    ``self.qualifier`` is restored afterwards.
    """
    self.qualifier.append(name)
    qualified = self._qualifier_name()
    self.qualifier.pop()
    known_attributes = self.annotations.attributes
    if qualified in known_attributes:
        self.toplevel_annotations[name] = known_attributes[qualified]

1004 

def _update_parameters(
    self,
    annotations: FunctionAnnotation,
    updated_node: cst.FunctionDef,
) -> cst.Parameters:
    """Return ``updated_node.params`` with annotations copied from the stub.

    Regular and positional-only parameters are paired with the stub's by
    position, or by name when ``self.strict_posargs_matching`` is set.
    Keyword-only parameters are always paired by name. A stub annotation is
    applied when:
      - the matching stub parameter has an annotation, and
      - either ``self.overwrite_existing_annotations`` is set or the parameter
        currently has no annotation.
    Default values, ``*args`` and ``**kwargs`` are never modified.
    """

    def annotate_group(
        current: Sequence[cst.Param],
        stub: Sequence[cst.Param],
        positional: bool,
    ) -> List[cst.Param]:
        by_index = positional and not self.strict_posargs_matching

        def key_of(index: int, param: cst.Param) -> Union[int, str]:
            return index if by_index else param.name.value

        # Stub annotations keyed by position or name, without the whitespace
        # that preceded the ":" in the stub.
        stub_annotations = {
            key_of(index, stub_param): stub_param.annotation.with_changes(
                whitespace_before_indicator=cst.SimpleWhitespace(value="")
            )
            for index, stub_param in enumerate(stub)
            if stub_param.annotation
        }

        result = []
        for index, param in enumerate(current):
            key = key_of(index, param)
            wants_update = self.overwrite_existing_annotations or not param.annotation
            if key in stub_annotations and wants_update:
                param = self._apply_annotation_to_parameter(
                    parameter=param,
                    annotation=stub_annotations[key],
                )
            result.append(param)
        return result

    current_params = updated_node.params
    stub_params = annotations.parameters
    return current_params.with_changes(
        params=annotate_group(
            current_params.params, stub_params.params, positional=True
        ),
        kwonly_params=annotate_group(
            current_params.kwonly_params, stub_params.kwonly_params, positional=False
        ),
        posonly_params=annotate_group(
            current_params.posonly_params, stub_params.posonly_params, positional=True
        ),
    )

1056 

def _insert_empty_line(
    self,
    statements: List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]],
) -> List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]]:
    """Ensure the first statement is preceded by a blank line.

    Returns ``statements`` unchanged when:
      - the list is empty, or
      - the first statement already starts with a comment-free empty line.
    Otherwise returns a new list in which an ``EmptyLine`` has been prepended
    to the first statement's ``leading_lines``.
    """
    if not statements:
        return statements
    first, rest = statements[0], statements[1:]
    existing = first.leading_lines
    if existing and existing[0].comment is None:
        # Already separated by a genuinely blank line.
        return statements
    # Either no leading lines at all, or a comment sits right on top:
    # put one blank line in front.
    return [first.with_changes(leading_lines=(cst.EmptyLine(), *existing)), *rest]

1080 

def _match_signatures(  # noqa: C901: Too complex
    self,
    function: cst.FunctionDef,
    annotations: FunctionAnnotation,
) -> bool:
    """Check that function annotations on both signatures are compatible.

    The signatures match when all of the following hold:
      - regular and positional-only parameter lists have equal lengths (and
        equal names when ``self.strict_posargs_matching`` is set);
      - keyword-only parameters have the same set of names;
      - ``*args`` and ``**kwargs`` are present on both sides or on neither;
      - every paired annotation, including the return annotation, is
        compatible.

    Two annotations are compatible when:
      - ``self.overwrite_existing_annotations`` is set, or
      - either side is absent (per ``_is_non_sentinel``), or
      - ``self.strict_annotation_matching`` is off, or
      - they are ``deep_equals``.
    """

    def compatible(
        p: Optional[cst.Annotation],
        q: Optional[cst.Annotation],
    ) -> bool:
        if self.overwrite_existing_annotations:
            return True
        if not (_is_non_sentinel(p) and _is_non_sentinel(q)):
            # One side has nothing to clash with.
            return True
        if not self.strict_annotation_matching:
            # Clashing annotations will not be overwritten, but the signature
            # as a whole counts as compatible so that holes can be filled in.
            return True
        return p.annotation.deep_equals(q.annotation)  # pyre-ignore[16]

    def match_posargs(
        ps: Sequence[cst.Param],
        qs: Sequence[cst.Param],
    ) -> bool:
        if len(ps) != len(qs):
            return False
        return all(
            not (self.strict_posargs_matching and p.name.value != q.name.value)
            and compatible(p.annotation, q.annotation)
            for p, q in zip(ps, qs)
        )

    def match_kwargs(
        ps: Sequence[cst.Param],
        qs: Sequence[cst.Param],
    ) -> bool:
        ps_by_name = {param.name.value: param for param in ps}
        qs_by_name = {param.name.value: param for param in qs}
        # dict key views compare as sets.
        if ps_by_name.keys() != qs_by_name.keys():
            return False
        return all(
            compatible(ps_by_name[key].annotation, qs_by_name[key].annotation)
            for key in ps_by_name
        )

    def match_star(
        p: StarParamType,
        q: StarParamType,
    ) -> bool:
        return _is_non_sentinel(p) == _is_non_sentinel(q)

    ours, theirs = function.params, annotations.parameters
    return (
        match_posargs(ours.params, theirs.params)
        and match_posargs(ours.posonly_params, theirs.posonly_params)
        and match_kwargs(ours.kwonly_params, theirs.kwonly_params)
        and match_star(ours.star_arg, theirs.star_arg)
        and match_star(ours.star_kwarg, theirs.star_kwarg)
        and compatible(function.returns, annotations.returns)
    )

1158 

1159 # transform API methods 

1160 

def visit_ClassDef(
    self,
    node: cst.ClassDef,
) -> None:
    """Enter a class scope by pushing its name onto ``self.qualifier``."""
    class_name = node.name.value
    self.qualifier.append(class_name)

1166 

def leave_ClassDef(
    self,
    original_node: cst.ClassDef,
    updated_node: cst.ClassDef,
) -> cst.ClassDef:
    """Leave a class scope, copying a generic base from the stub when missing.

    The short class name is recorded in ``self.visited_classes`` and the
    scope is popped. If the stub has a definition under the dotted class
    path, ``_find_generic_base`` is applied to both classes. When only the
    stub yields one (presumably a ``Generic[...]`` base; the helper is
    defined elsewhere), it is appended to the bases and counted.
    """
    self.visited_classes.add(original_node.name.value)
    class_path = ".".join(self.qualifier)
    self.qualifier.pop()

    stub_class = self.annotations.class_definitions.get(class_path)
    if not stub_class:
        return updated_node

    stub_generic = _find_generic_base(stub_class)
    local_generic = _find_generic_base(updated_node)
    if not stub_generic or local_generic:
        return updated_node

    self.annotation_counts.typevars_and_generics_added += 1
    return updated_node.with_changes(bases=[*updated_node.bases, stub_generic])

1184 

def visit_FunctionDef(
    self,
    node: cst.FunctionDef,
) -> bool:
    """Push the function name onto ``self.qualifier``; never visit the body.

    Returning ``False`` stops traversal into the body, because pyi files
    don't support inner functions.
    """
    function_name = node.name.value
    self.qualifier.append(function_name)
    return False

1192 

def leave_FunctionDef(
    self,
    original_node: cst.FunctionDef,
    updated_node: cst.FunctionDef,
) -> cst.FunctionDef:
    """Apply the stub's annotations to a function whose signature matches.

    The lookup key is built from the dotted scope path and the parameters.
    Nothing changes when the stub has no entry for the key or when
    ``_match_signatures`` rejects the pair.

    Otherwise:
      - the return annotation is applied when the stub has one and either
        overwriting is enabled or the function has none;
      - parameters are then updated via ``_update_parameters``, which leaves
        default values alone.
    """
    key = FunctionKey.make(self._qualifier_name(), updated_node.params)
    self.qualifier.pop()

    if key not in self.annotations.functions:
        return updated_node
    function_annotation = self.annotations.functions[key]

    if not self._match_signatures(updated_node, function_annotation):
        return updated_node

    stub_returns = function_annotation.returns
    may_set_return = (
        self.overwrite_existing_annotations or updated_node.returns is None
    )
    if may_set_return and stub_returns is not None:
        updated_node = self._apply_annotation_to_return(
            function_def=updated_node,
            annotation=stub_returns,
        )

    annotated_params = self._update_parameters(function_annotation, updated_node)
    return updated_node.with_changes(params=annotated_params)

1220 

def visit_Assign(
    self,
    node: cst.Assign,
) -> None:
    """Remember the assignment being entered so ``record_typevar`` can read its target."""
    self.current_assign = node

1226 

@m.call_if_inside(m.Assign())
@m.visit(m.Call(func=m.Name("TypeVar")))
def record_typevar(
    self,
    node: cst.Call,
) -> None:
    """Record the enclosing assignment of a ``TypeVar(...)`` call.

    The whole ``Assign`` node is stored in ``self.typevars`` under the full
    name of its first target. Only the name is used today, but keeping the
    node lets bounds and variance be compared later to decide whether two
    typevars with the same name are really the same.

    ``self.current_assign`` is cleared after the first ``TypeVar`` call, so
    only one typevar per assignment is recorded.
    """
    assign = self.current_assign
    if assign is None:
        # A previous TypeVar call in the same assignment already consumed
        # `current_assign`, e.g. `T, U = TypeVar("T"), TypeVar("U")`.
        # Dereferencing it here used to raise AttributeError.
        return
    name = get_full_name_for_node(assign.targets[0].target)
    if name is not None:
        self.typevars[name] = assign
    self.current_assign = None

1243 

def leave_Assign(
    self,
    original_node: cst.Assign,
    updated_node: cst.Assign,
) -> Union[cst.Assign, cst.AnnAssign]:
    """Annotate a plain assignment, or queue annotations for chained targets.

    A single-target assignment is delegated to ``_annotate_single_target``.
    For chained assignments (``a = b = 1``), each ``Name``/``Attribute``
    target other than ``_`` is queued separately as a top-level annotation
    (``a: int`` and ``b: int``), and the assignment itself is left unchanged.
    """
    self.current_assign = None

    if len(original_node.targets) <= 1:
        return self._annotate_single_target(original_node, updated_node)

    for assign_target in original_node.targets:
        target = assign_target.target
        if not isinstance(target, (cst.Name, cst.Attribute)):
            continue
        name = get_full_name_for_node(target)
        if name is None or name == "_":
            continue
        self._add_to_toplevel_annotations(name)
    return updated_node

1263 

def leave_ImportFrom(
    self,
    original_node: cst.ImportFrom,
    updated_node: cst.ImportFrom,
) -> cst.ImportFrom:
    """Record the original ``from ... import`` node and return the updated one unchanged.

    ``_split_module`` later locates recorded nodes by identity.
    """
    recorded = self.import_statements
    recorded.append(original_node)
    return updated_node

1271 

def leave_Module(
    self,
    original_node: cst.Module,
    updated_node: cst.Module,
) -> cst.Module:
    """Insert queued top-level code right after the module's imports.

    The inserted code is, in order:
      1. bare ``name: annotation`` lines for ``self.toplevel_annotations``;
      2. stub typevar statements whose names the module does not already
         define (per ``self.typevars``);
      3. stub class definitions whose names were never visited.

    The first statement after the imports is given a blank line in front.
    The module is returned untouched when there is nothing to add.
    """
    visited = self.visited_classes
    fresh_class_definitions = [
        class_def
        for class_name, class_def in self.annotations.class_definitions.items()
        if class_name not in visited
    ]

    # NOTE: The entire change will also be abandoned if
    # self.annotation_counts is all 0s, so if adding any new category make
    # sure to record it there.
    if (
        not self.toplevel_annotations
        and not fresh_class_definitions
        and not self.annotations.typevars
    ):
        return updated_node

    # First, find the insertion point for imports.
    before_imports, after_imports = self._split_module(original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import.
    after_imports = self._insert_empty_line(after_imports)

    new_statements = [
        cst.SimpleStatementLine(
            [
                self._apply_annotation_to_attribute_or_global(
                    name=name,
                    annotation=annotation,
                    value=None,
                )
            ]
        )
        for name, annotation in self.toplevel_annotations.items()
    ]

    # TypeVar definitions could be scattered through the file, so do not
    # attempt to put new ones with existing ones; just add them at the top.
    # NOTE(review): bare `cst.Newline()` nodes are placed in Module.body to
    # emit the line breaks. In this file the analogous `self.typevars` values
    # are bare `cst.Assign` nodes, so the stub values are presumably bare
    # Assigns as well. Neither is a statement type; consider wrapping them in
    # SimpleStatementLine.
    missing_typevars = [
        stmt
        for var, stmt in self.annotations.typevars.items()
        if var not in self.typevars
    ]
    for stmt in missing_typevars:
        new_statements.extend((cst.Newline(), stmt))
        self.annotation_counts.typevars_and_generics_added += 1
    if missing_typevars:
        new_statements.append(cst.Newline())

    self.annotation_counts.classes_added = len(fresh_class_definitions)
    new_statements.extend(fresh_class_definitions)

    return updated_node.with_changes(
        body=[*before_imports, *new_statements, *after_imports]
    )