1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5"""
6Astroid hook for the dataclasses library.
7
8Support built-in dataclasses, pydantic.dataclasses, and marshmallow_dataclass-annotated
9dataclasses. References:
10- https://docs.python.org/3/library/dataclasses.html
11- https://pydantic-docs.helpmanual.io/usage/dataclasses/
12- https://lovasoa.github.io/marshmallow_dataclass/
13"""
14
15from __future__ import annotations
16
17from collections.abc import Iterator
18from typing import Literal
19
20from astroid import bases, context, nodes
21from astroid.brain.helpers import is_class_var
22from astroid.builder import parse
23from astroid.const import PY313_PLUS
24from astroid.exceptions import (
25 AstroidSyntaxError,
26 InferenceError,
27 MroError,
28 UseInferenceDefault,
29)
30from astroid.inference_tip import inference_tip
31from astroid.manager import AstroidManager
32from astroid.typing import InferenceResult
33from astroid.util import Uninferable, UninferableBase, safe_infer
34
# Return type of `_get_field_default`: None when a `field(...)` call carries no
# usable default, otherwise a `(kind, node)` pair whose first item tells which
# keyword argument ("default" or "default_factory") supplied the node.
_FieldDefaultReturn = (
    None
    | tuple[Literal["default"], nodes.NodeNG]
    | tuple[Literal["default_factory"], nodes.Call]
)

# Decorator names that mark a class as a dataclass.
DATACLASSES_DECORATORS = frozenset(("dataclass",))
# Name of the function whose calls are treated as dataclass field definitions.
FIELD_NAME = "field"
# Names of the modules the decorator / field function must be inferred from.
DATACLASS_MODULES = frozenset(
    ("dataclasses", "marshmallow_dataclass", "pydantic.dataclasses")
)
# Name of the sentinel used as parameter default for `default_factory` fields in
# the generated `__init__`; `dataclass_transform` adds it to the module locals.
DEFAULT_FACTORY = "_HAS_DEFAULT_FACTORY"  # based on typing.py
47
48
def is_decorated_with_dataclass(
    node: nodes.ClassDef, decorator_names: frozenset[str] = DATACLASSES_DECORATORS
) -> bool:
    """Tell whether ``node`` is a class carrying a `dataclass` decorator.

    Anything that is not a ``ClassDef``, or a class without decorators, is
    rejected. Otherwise each decorator is checked in order against
    ``decorator_names`` until one looks like a dataclass decorator.
    """
    if not isinstance(node, nodes.ClassDef):
        return False
    if not node.decorators:
        return False

    for candidate in node.decorators.nodes:
        if _looks_like_dataclass_decorator(candidate, decorator_names):
            return True
    return False
60
61
def dataclass_transform(node: nodes.ClassDef) -> nodes.ClassDef | None:
    """Rewrite a dataclass to be easily understood by pylint.

    The class is flagged with ``is_dataclass`` and every dataclass attribute
    gets an ``Unknown`` placeholder node registered in ``instance_attrs``.
    Unless the class defines its own ``__init__`` or the dataclass decorator
    was called with ``init=False``, the source of an ``__init__`` method is
    generated, parsed and stored in the class locals.

    Returns ``None`` when no ``__init__`` has to be generated (the attribute
    rewrite has already been applied in place by then), ``node`` otherwise.
    """
    node.is_dataclass = True

    for assign_node in _get_dataclass_attributes(node):
        name = assign_node.target.name

        # Placeholder value of the attribute; the transforms registered for
        # Unknown nodes are applied to it right away.
        rhs_node = nodes.Unknown(
            lineno=assign_node.lineno,
            col_offset=assign_node.col_offset,
            parent=assign_node,
        )
        rhs_node = AstroidManager().visit_transforms(rhs_node)
        node.instance_attrs[name] = [rhs_node]

    if not _check_generate_dataclass_init(node):
        return None

    # Only calls of the dataclass decorator itself can carry `kw_only`.
    # Other decorators (bare names or unrelated calls) are skipped so that
    # stacking them with `@dataclass(kw_only=True)` neither resets the flag
    # nor contributes an unrelated `kw_only` keyword.
    kw_only_decorated = False
    for decorator in node.decorators.nodes:
        if not isinstance(decorator, nodes.Call):
            continue
        if not _looks_like_dataclass_decorator(decorator):
            continue
        for keyword in decorator.keywords:
            if keyword.arg == "kw_only":
                kw_only_decorated = keyword.value.bool_value() is True

    init_str = _generate_dataclass_init(
        node,
        list(_get_dataclass_attributes(node, init=True)),
        kw_only_decorated,
    )

    try:
        init_node = parse(init_str)["__init__"]
    except AstroidSyntaxError:
        # The generated source is not valid Python: leave the class without
        # a synthetic __init__.
        pass
    else:
        init_node.parent = node
        init_node.lineno, init_node.col_offset = None, None
        node.locals["__init__"] = [init_node]

        # The generated signature may reference the DEFAULT_FACTORY sentinel,
        # so make sure that name exists in the enclosing module.
        root = node.root()
        if DEFAULT_FACTORY not in root.locals:
            new_assign = parse(f"{DEFAULT_FACTORY} = object()").body[0]
            new_assign.parent = root
            root.locals[DEFAULT_FACTORY] = [new_assign.targets[0]]
    return node
111
112
def _get_dataclass_attributes(
    node: nodes.ClassDef, init: bool = False
) -> Iterator[nodes.AnnAssign]:
    """Yield the annotated assignments of ``node`` that define dataclass fields.

    Only ``AnnAssign`` statements from the class body whose target is a plain
    name are considered. ClassVar annotations and the KW_ONLY sentinel are
    always skipped; InitVar annotations are only kept when ``init`` is True.
    """
    for statement in node.body:
        if not isinstance(statement, nodes.AnnAssign):
            continue
        if not isinstance(statement.target, nodes.AssignName):
            continue

        # The annotation of an AnnAssign is never None.
        annotation = statement.annotation
        if is_class_var(annotation):  # type: ignore[arg-type]
            continue
        if _is_keyword_only_sentinel(annotation):
            continue
        if not init and _is_init_var(annotation):  # type: ignore[arg-type]
            continue

        yield statement
139
140
def _check_generate_dataclass_init(node: nodes.ClassDef) -> bool:
    """Decide whether an ``__init__`` method has to be generated for ``node``.

    The answer is False when the class already has ``__init__`` in its locals,
    or when the (last) dataclass decorator call passes ``init=False``. In
    every other case, including a decorator used without a call, it is True.
    """
    if "__init__" in node.locals:
        return False

    dataclass_call = None
    for decorator in node.decorators.nodes:
        # Later matches replace earlier ones: the last dataclass call wins.
        if isinstance(decorator, nodes.Call) and _looks_like_dataclass_decorator(
            decorator
        ):
            dataclass_call = decorator

    if dataclass_call is None:
        return True

    for keyword in dataclass_call.keywords:
        # keyword.value is never None
        if keyword.arg == "init" and keyword.value.bool_value() is False:  # type: ignore[union-attr]
            return False
    return True
169
170
def _find_arguments_from_base_classes(
    node: nodes.ClassDef,
) -> tuple[
    dict[str, tuple[str | None, str | None]], dict[str, tuple[str | None, str | None]]
]:
    """Collect the ``__init__`` arguments of every dataclass in the MRO of ``node``.

    The MRO is walked from its last entry to its first, so data found later in
    the walk replaces data stored earlier under the same argument name. Two
    mappings are returned, one for the positional and one for the keyword-only
    arguments; both are empty when the MRO cannot be computed.
    """
    pos_only_store: dict[str, tuple[str | None, str | None]] = {}
    kw_only_store: dict[str, tuple[str | None, str | None]] = {}

    try:
        ancestors = node.mro()
    except MroError:
        return pos_only_store, kw_only_store

    for base in reversed(ancestors):
        if not base.is_dataclass:
            continue
        try:
            base_init: nodes.FunctionDef = base.locals["__init__"][0]
        except KeyError:
            continue

        pos_only, kw_only = base_init.args._get_arguments_data()
        # TODO: A positional argument without default that follows one with a
        # default would raise a TypeError at runtime, so this should result in
        # an Uninferable. However, transforms can't return Uninferables
        # currently.
        pos_only_store.update(pos_only)
        kw_only_store.update(kw_only)

    return pos_only_store, kw_only_store
209
210
211def _parse_arguments_into_strings(
212 pos_only_store: dict[str, tuple[str | None, str | None]],
213 kw_only_store: dict[str, tuple[str | None, str | None]],
214) -> tuple[str, str]:
215 """Parse positional and keyword arguments into strings for an __init__ method."""
216 pos_only, kw_only = "", ""
217 for pos_arg, data in pos_only_store.items():
218 pos_only += pos_arg
219 if data[0]:
220 pos_only += ": " + data[0]
221 if data[1]:
222 pos_only += " = " + data[1]
223 pos_only += ", "
224 for kw_arg, data in kw_only_store.items():
225 kw_only += kw_arg
226 if data[0]:
227 kw_only += ": " + data[0]
228 if data[1]:
229 kw_only += " = " + data[1]
230 kw_only += ", "
231
232 return pos_only, kw_only
233
234
def _get_previous_field_default(node: nodes.ClassDef, name: str) -> nodes.NodeNG | None:
    """Look up the default that a dataclass in the MRO of ``node`` gives ``name``.

    The MRO is walked from its last entry to its first. Only annotated
    assignments whose value is a dataclass field call are considered, and the
    node of the first usable default found is returned. None is returned when
    the MRO cannot be computed or no such default exists.
    """
    try:
        ancestors = node.mro()
    except MroError:
        return None

    for base in reversed(ancestors):
        if not base.is_dataclass or name not in base.locals:
            continue
        for assign in base.locals[name]:
            statement = assign.parent
            if not isinstance(statement, nodes.AnnAssign):
                continue
            field_call = statement.value
            if not (field_call and isinstance(field_call, nodes.Call)):
                continue
            if not _looks_like_dataclass_field_call(field_call):
                continue
            default = _get_field_default(field_call)
            if default:
                return default[1]
    return None
257
258
def _generate_dataclass_init(
    node: nodes.ClassDef, assigns: list[nodes.AnnAssign], kw_only_decorated: bool
) -> str:
    """Return an init method for a dataclass given the targets.

    The result is the source code of a ``def __init__(...) -> None:`` method.
    Its parameters are made of the arguments collected from the dataclass
    bases of ``node`` followed by one parameter per entry of ``assigns``
    (fields declared with ``init=False`` are left out), and its body assigns
    the parameters to ``self`` (or is ``pass`` when there is nothing to
    assign).

    ``kw_only_decorated`` tells whether parameters are keyword-only by
    default; a field that passes ``kw_only`` itself takes precedence.
    """
    # pylint: disable = too-many-locals, too-many-branches, too-many-statements

    # Source snippets of the positional parameters, the keyword-only
    # parameters and the body statements of the generated method.
    params: list[str] = []
    kw_only_params: list[str] = []
    assignments: list[str] = []

    # name -> (annotation, default) of the arguments inherited from the bases
    prev_pos_only_store, prev_kw_only_store = _find_arguments_from_base_classes(node)

    for assign in assigns:
        name, annotation, value = assign.target.name, assign.annotation, assign.value

        # Check whether this assign is overridden by a property assignment
        property_node: nodes.FunctionDef | None = None
        for additional_assign in node.locals[name]:
            if not isinstance(additional_assign, nodes.FunctionDef):
                continue
            if not additional_assign.decorators:
                continue
            if "builtins.property" in additional_assign.decoratornames():
                property_node = additional_assign
                break

        is_field = isinstance(value, nodes.Call) and _looks_like_dataclass_field_call(
            value, check_scope=False
        )

        if is_field:
            # Skip any fields that have `init=False`
            if any(
                keyword.arg == "init" and (keyword.value.bool_value() is False)
                for keyword in value.keywords  # type: ignore[union-attr] # value is never None
            ):
                # Also remove the name from the previous arguments to be inserted later
                prev_pos_only_store.pop(name, None)
                prev_kw_only_store.pop(name, None)
                continue

        if _is_init_var(annotation):  # type: ignore[arg-type] # annotation is never None
            init_var = True
            if isinstance(annotation, nodes.Subscript):
                # InitVar[T]: the parameter is annotated with T
                annotation = annotation.slice
            else:
                # Cannot determine type annotation for parameter from InitVar
                annotation = None
            assignment_str = ""
        else:
            init_var = False
            assignment_str = f"self.{name} = {name}"

        ann_str, default_str = None, None
        if annotation is not None:
            ann_str = annotation.as_string()

        if value:
            if is_field:
                result = _get_field_default(value)  # type: ignore[arg-type]
                if result:
                    default_type, default_node = result
                    if default_type == "default":
                        default_str = default_node.as_string()
                    elif default_type == "default_factory":
                        # The parameter defaults to the sentinel and the factory
                        # is only called in the body when the sentinel is received.
                        default_str = DEFAULT_FACTORY
                        assignment_str = (
                            f"self.{name} = {default_node.as_string()} "
                            f"if {name} is {DEFAULT_FACTORY} else {name}"
                        )
            else:
                default_str = value.as_string()
        elif property_node:
            # We set the result of the property call as default
            # This hides the fact that this would normally be a 'property object'
            # But we can't represent those as string
            try:
                # Call str to make sure also Uninferable gets stringified
                default_str = str(
                    next(property_node.infer_call_result(None)).as_string()
                )
            except (InferenceError, StopIteration):
                pass
        else:
            # Even with `init=False` the default value still can be propagated to
            # later assignments. Creating weird signatures like:
            # (self, a: str = 1) -> None
            previous_default = _get_previous_field_default(node, name)
            if previous_default:
                default_str = previous_default.as_string()

        # Construct the param string to add to the init if necessary
        param_str = name
        if ann_str is not None:
            param_str += f": {ann_str}"
        if default_str is not None:
            param_str += f" = {default_str}"

        # If the field is a kw_only field, we need to add it to the kw_only_params
        # This overwrites whether or not the class is kw_only decorated
        if is_field:
            kw_only = [k for k in value.keywords if k.arg == "kw_only"]  # type: ignore[union-attr]
            if kw_only:
                if kw_only[0].value.bool_value() is True:
                    kw_only_params.append(param_str)
                else:
                    params.append(param_str)
                # NOTE(review): this also skips the `assignments.append` at the
                # end of the loop, so no `self.<name> = ...` line is emitted
                # for fields that pass `kw_only` explicitly -- confirm intended.
                continue
        # If kw_only decorated, we need to add all parameters to the kw_only_params
        if kw_only_decorated:
            if name in prev_kw_only_store:
                prev_kw_only_store[name] = (ann_str, default_str)
            else:
                kw_only_params.append(param_str)
        else:
            # If the name was previously seen, overwrite that data
            # pylint: disable-next=else-if-used
            if name in prev_pos_only_store:
                prev_pos_only_store[name] = (ann_str, default_str)
            elif name in prev_kw_only_store:
                # An inherited keyword-only argument becomes the first
                # positional parameter.
                # NOTE(review): the bare `name` is inserted here, so the
                # annotation and default held in `param_str` are dropped --
                # confirm intended.
                params = [name, *params]
                prev_kw_only_store.pop(name)
            else:
                params.append(param_str)

        if not init_var:
            assignments.append(assignment_str)

    prev_pos_only, prev_kw_only = _parse_arguments_into_strings(
        prev_pos_only_store, prev_kw_only_store
    )

    # Construct the new init method parameter string
    # First we do the positional only parameters, making sure to add
    # the self parameter and the comma to allow adding keyword only parameters
    params_string = "" if "self" in prev_pos_only else "self, "
    params_string += prev_pos_only + ", ".join(params)
    if not params_string.endswith(", "):
        params_string += ", "

    # Then we add the keyword only parameters
    if prev_kw_only or kw_only_params:
        params_string += "*, "
    params_string += f"{prev_kw_only}{', '.join(kw_only_params)}"

    assignments_string = "\n ".join(assignments) if assignments else "pass"
    return f"def __init__({params_string}) -> None:\n {assignments_string}"
406
407
def infer_dataclass_attribute(
    node: nodes.Unknown, ctx: context.InferenceContext | None = None
) -> Iterator[InferenceResult]:
    """Inference tip for the Unknown node generated for a dataclass attribute.

    When the parent statement provides a default value, the results of
    inferring that value come first. They are followed by an instance inferred
    from the annotation, or by Uninferable when there is no annotation.
    Uninferable alone is yielded when the parent is not an annotated
    assignment.
    """
    parent = node.parent
    if not isinstance(parent, nodes.AnnAssign):
        yield Uninferable
        return

    default_value = parent.value
    if default_value is not None:
        yield from default_value.infer(context=ctx)

    annotation = parent.annotation
    if annotation is None:
        yield Uninferable
        return
    yield from _infer_instance_from_annotation(annotation, ctx=ctx)
429
430
def infer_dataclass_field_call(
    node: nodes.Call, ctx: context.InferenceContext | None = None
) -> Iterator[InferenceResult]:
    """Inference tip for dataclass field calls.

    A field call infers to whatever its default infers to: the ``default``
    value itself, or the result of calling the ``default_factory``.
    Uninferable is yielded when the call provides no usable default.

    Raises UseInferenceDefault when the call is not the direct child of an
    assignment statement.
    """
    if not isinstance(node.parent, (nodes.AnnAssign, nodes.Assign)):
        raise UseInferenceDefault

    field_default = _get_field_default(node)
    if not field_default:
        yield Uninferable
        return

    kind, default_node = field_default
    if kind == "default":
        yield from default_node.infer(context=ctx)
        return

    # default_factory: re-parse the factory call from its source and infer
    # it as if it was written where the field call is.
    factory_call = parse(default_node.as_string()).body[0].value
    factory_call.parent = node.parent
    yield from factory_call.infer(context=ctx)
448
449
def _looks_like_dataclass_decorator(
    node: nodes.NodeNG, decorator_names: frozenset[str] = DATACLASSES_DECORATORS
) -> bool:
    """Return True if node looks like a dataclass decorator.

    For a decorator used with arguments the called expression is examined.
    When it can be inferred, it must be a function named after one of
    ``decorator_names`` that lives in one of the supported dataclass modules.
    When inference fails, the decision falls back to comparing the name (or
    attribute name) written in the source with ``decorator_names``.
    """
    target = node.func if isinstance(node, nodes.Call) else node

    try:
        inferred = next(target.infer())
    except (InferenceError, StopIteration):
        inferred = Uninferable

    if not isinstance(inferred, UninferableBase):
        return (
            isinstance(inferred, nodes.FunctionDef)
            and inferred.name in decorator_names
            and inferred.root().name in DATACLASS_MODULES
        )

    # Inference gave nothing usable: match on the spelling instead.
    if isinstance(target, nodes.Name):
        return target.name in decorator_names
    if isinstance(target, nodes.Attribute):
        return target.attrname in decorator_names
    return False
478
479
def _looks_like_dataclass_attribute(node: nodes.Unknown) -> bool:
    """Return True if node is the child of an AnnAssign statement whose scope
    is a class decorated as a dataclass.
    """
    assign = node.parent
    if not assign:
        return False

    enclosing = assign.scope()
    if not isinstance(assign, nodes.AnnAssign):
        return False
    if not isinstance(enclosing, nodes.ClassDef):
        return False
    return is_decorated_with_dataclass(enclosing)
494
495
def _looks_like_dataclass_field_call(
    node: nodes.Call, check_scope: bool = True
) -> bool:
    """Return True if node calls the ``field`` function of a dataclass module.

    With ``check_scope`` enabled, the call must additionally belong to an
    AnnAssign statement that has a value and sits directly in the body of a
    class decorated as a dataclass. With ``check_scope`` disabled, only the
    called function is examined.
    """
    if check_scope:
        stmt = node.statement()
        scope = stmt.scope()
        in_dataclass_body = (
            isinstance(stmt, nodes.AnnAssign)
            and stmt.value is not None
            and isinstance(scope, nodes.ClassDef)
            and is_decorated_with_dataclass(scope)
        )
        if not in_dataclass_body:
            return False

    try:
        inferred = next(node.func.infer())
    except (InferenceError, StopIteration):
        return False

    return (
        isinstance(inferred, nodes.FunctionDef)
        and inferred.name == FIELD_NAME
        and inferred.root().name in DATACLASS_MODULES
    )
524
525
def _looks_like_dataclasses(node: nodes.Module) -> bool:
    """Return True if the qualified name of ``node`` is exactly ``dataclasses``."""
    module_qname = node.qname()
    return module_qname == "dataclasses"
528
529
def _resolve_private_replace_to_public(node: nodes.Module) -> None:
    """Make ``replace`` resolve to the definitions of ``_replace``.

    In python/cpython@6f3c138, a _replace() method was extracted from
    replace(), and this indirection made replace() uninferable. Modules
    without a ``_replace`` local are left untouched.
    """
    module_locals = node.locals
    if "_replace" not in module_locals:
        return
    module_locals["replace"] = module_locals["_replace"]
535
536
def _get_field_default(field_call: nodes.Call) -> _FieldDefaultReturn:
    """Return the default value of a field call, together with the name of the
    keyword argument that provided it.

    field(default=...) results in ("default", <the ... node>)
    field(default_factory=...) results in ("default_factory", <a new Call node
    with func ... and no arguments>)

    If neither or both arguments are present, return None instead, indicating
    that there is not a valid default value.
    """
    by_name = {keyword.arg: keyword.value for keyword in field_call.keywords}
    default = by_name.get("default")
    default_factory = by_name.get("default_factory")

    has_default = default is not None
    has_factory = default_factory is not None
    if has_default == has_factory:
        # Either nothing was given or both were: no valid default.
        return None

    if has_default:
        return "default", default

    # Build the call `<default_factory>()`, located where the field call is.
    new_call = nodes.Call(
        lineno=field_call.lineno,
        col_offset=field_call.col_offset,
        parent=field_call.parent,
        end_lineno=field_call.end_lineno,
        end_col_offset=field_call.end_col_offset,
    )
    new_call.postinit(func=default_factory, args=[], keywords=[])
    return "default_factory", new_call
569
570
def _is_keyword_only_sentinel(node: nodes.NodeNG) -> bool:
    """Return True if node is the KW_ONLY sentinel.

    The node must infer to an instance. An instance of
    ``dataclasses._KW_ONLY_TYPE`` is accepted right away. An instance of
    ``builtins.sentinel`` is accepted only when the node is a name bound by a
    ``from dataclasses import KW_ONLY`` statement, or a ``KW_ONLY`` attribute
    looked up on the ``dataclasses`` module.
    """
    inferred = safe_infer(node)
    if not isinstance(inferred, bases.Instance):
        return False

    qualified_name = inferred.qname()
    if qualified_name == "dataclasses._KW_ONLY_TYPE":
        return True
    if qualified_name != "builtins.sentinel":
        return False

    if isinstance(node, nodes.Name):
        _, assignments = node.lookup(node.name)
        for assignment in assignments:
            if not isinstance(assignment, nodes.ImportFrom):
                continue
            if assignment.modname != "dataclasses":
                continue
            if any(imported == "KW_ONLY" for imported, _ in assignment.names):
                return True
        return False

    if isinstance(node, nodes.Attribute) and node.attrname == "KW_ONLY":
        owner = safe_infer(node.expr)
        return isinstance(owner, nodes.Module) and owner.qname() == "dataclasses"

    return False
595
596
def _is_init_var(node: nodes.NodeNG) -> bool:
    """Return True if node is an InitVar, with or without subscripting.

    The decision is based on the name of the first inferred value; a node
    that cannot be inferred is not an InitVar.
    """
    try:
        first_inferred = next(node.infer())
    except (InferenceError, StopIteration):
        return False
    else:
        inferred_name = getattr(first_inferred, "name", "")
        return inferred_name == "InitVar"
605
606
# Allowed typing classes for which we support inferring instances.
# `_infer_instance_from_annotation` only instantiates a class rooted in
# "typing", "_collections_abc" or "" when its name is listed here.
_INFERABLE_TYPING_TYPES = frozenset(
    (
        "Dict",
        "FrozenSet",
        "List",
        "Set",
        "Tuple",
    )
)
617
618
def _infer_instance_from_annotation(
    node: nodes.NodeNG, ctx: context.InferenceContext | None = None
) -> Iterator[UninferableBase | bases.Instance]:
    """Infer an instance corresponding to the type annotation represented by node.

    Exactly one result is yielded: an instance of the class the annotation
    infers to, or Uninferable when the annotation cannot be inferred, does not
    infer to a class, or infers to a typing class we do not support.

    Currently has limited support for the typing module.
    """
    try:
        klass = next(node.infer(context=ctx))
    except (InferenceError, StopIteration):
        # Stop here: falling through would report Uninferable a second time.
        yield Uninferable
        return

    if not isinstance(klass, nodes.ClassDef):
        yield Uninferable
        return

    from_typing = klass.root().name in {
        "typing",
        "_collections_abc",
        "",
    }  # "" because of synthetic nodes in brain_typing.py
    if from_typing and klass.name not in _INFERABLE_TYPING_TYPES:
        yield Uninferable
        return

    yield klass.instantiate_class()
644
645
def register(manager: AstroidManager) -> None:
    """Register the dataclass transforms and inference tips on ``manager``.

    On Python 3.13+ a module transform for ``dataclasses`` is registered
    first; then, in this order, the class transform, the inference tip for
    field calls and the inference tip for generated attribute nodes.
    """
    if PY313_PLUS:
        manager.register_transform(
            nodes.Module,
            _resolve_private_replace_to_public,
            _looks_like_dataclasses,
        )

    field_call_tip = inference_tip(infer_dataclass_field_call, raise_on_overwrite=True)
    attribute_tip = inference_tip(infer_dataclass_attribute, raise_on_overwrite=True)

    registrations = (
        (nodes.ClassDef, dataclass_transform, is_decorated_with_dataclass),
        (nodes.Call, field_call_tip, _looks_like_dataclass_field_call),
        (nodes.Unknown, attribute_tip, _looks_like_dataclass_attribute),
    )
    for node_class, transform, predicate in registrations:
        manager.register_transform(node_class, transform, predicate)