1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5from __future__ import annotations
6
7import inspect
8import random
9
10from astroid import nodes
11from astroid.context import InferenceContext
12from astroid.exceptions import UseInferenceDefault
13from astroid.inference_tip import inference_tip
14from astroid.manager import AstroidManager
15from astroid.util import UninferableBase, safe_infer
16
# Literal container node types whose ``elts`` ``infer_random_sample`` is
# willing to sample from; any other inferred population falls back to the
# default inference.
# NOTE(review): at runtime ``random.sample`` rejects sets on Python 3.11+
# ("population must be a sequence"); confirm Set should remain accepted here.
ACCEPTED_ITERABLES_FOR_SAMPLE = (nodes.List, nodes.Set, nodes.Tuple)
18
19
def _clone_node_with_lineno(node, parent, lineno):
    """Build a copy of *node* attached to *parent* and placed at *lineno*.

    An ``EvaluatedObject`` wrapper is unwrapped first, so the copy is made
    from the node it wraps. Keyword arguments for the node class's
    constructor are filtered against the parameters its ``__init__``
    actually declares. Any ``_other_fields`` entries the constructor does
    not accept are copied onto the clone afterwards with ``setattr``.
    Children listed in ``_astroid_fields`` are handed to ``postinit``
    unchanged, so they are shared with the source node rather than cloned.
    """
    source = node.original if isinstance(node, nodes.EvaluatedObject) else node
    node_class = source.__class__
    extra_fields = source._other_fields
    child_fields = source._astroid_fields

    location = {
        "lineno": lineno,
        "col_offset": source.col_offset,
        "parent": parent,
        "end_lineno": source.end_lineno,
        "end_col_offset": source.end_col_offset,
    }
    children = {field: getattr(source, field) for field in child_fields}

    # Only pass keywords the concrete node class's constructor understands.
    accepted = set(inspect.signature(node_class.__init__).parameters)
    constructor_kwargs = {
        key: value for key, value in location.items() if key in accepted
    }
    constructor_kwargs.update(
        (field, getattr(source, field)) for field in extra_fields if field in accepted
    )

    clone = node_class(**constructor_kwargs)
    if hasattr(source, "postinit") and child_fields:
        clone.postinit(**children)

    # Fields the constructor rejected are assigned directly on the clone.
    late_fields = [field for field in extra_fields if field not in accepted]
    for field in late_fields:
        setattr(clone, field, getattr(source, field))
    return clone
53
54
def _sample_call_arguments(node):
    """Map the arguments of a ``sample(...)`` call onto ``(population, k)``.

    Both parameters may be given positionally or by keyword. Any call that
    cannot be bound unambiguously raises ``UseInferenceDefault`` so that the
    regular inference takes over. That includes:

    * missing or surplus arguments;
    * ``*args`` or ``**kwargs`` unpacking;
    * duplicated keywords;
    * extra keywords such as ``counts=``, which change what can be drawn.
    """
    parameter_names = ("population", "k")
    positional = node.args or []
    if len(positional) > len(parameter_names) or any(
        isinstance(arg, nodes.Starred) for arg in positional
    ):
        raise UseInferenceDefault
    bound = dict(zip(parameter_names, positional))
    for keyword in node.keywords or []:
        # Unknown names (e.g. ``counts``), a ``**kwargs`` unpack, or a value
        # already supplied positionally cannot be modelled here.
        if keyword.arg not in parameter_names or keyword.arg in bound:
            raise UseInferenceDefault
        bound[keyword.arg] = keyword.value
    if len(bound) != len(parameter_names):
        raise UseInferenceDefault
    return bound["population"], bound["k"]


def infer_random_sample(node, context: InferenceContext | None = None):
    """Infer a ``sample(population, k)`` call as a literal ``List`` node.

    The population must infer to one of ``ACCEPTED_ITERABLES_FOR_SAMPLE``
    with no uninferable elements, and ``k`` must infer to an integer
    ``Const`` that is no larger than the population.

    The returned list is parented to the call's enclosing scope. Each chosen
    element is cloned with the list's line number. The elements are drawn
    with ``random.sample``, so the concrete selection can differ between
    inferences.

    Returns:
        An iterator yielding the single inferred ``nodes.List``.

    Raises:
        UseInferenceDefault: when any of the above conditions does not hold,
        or the call uses arguments (e.g. ``counts=``) this tip cannot model.
    """
    sequence_node, length_node = _sample_call_arguments(node)

    inferred_length = safe_infer(length_node, context=context)
    if not isinstance(inferred_length, nodes.Const):
        raise UseInferenceDefault
    if not isinstance(inferred_length.value, int):
        raise UseInferenceDefault

    inferred_sequence = safe_infer(sequence_node, context=context)
    if not inferred_sequence:
        raise UseInferenceDefault

    if not isinstance(inferred_sequence, ACCEPTED_ITERABLES_FOR_SAMPLE):
        raise UseInferenceDefault

    if inferred_length.value > len(inferred_sequence.elts):
        # At runtime this call would raise ValueError.
        raise UseInferenceDefault

    if any(isinstance(elt, UninferableBase) for elt in inferred_sequence.elts):
        raise UseInferenceDefault

    try:
        # Still guards against e.g. a negative ``k``.
        elts = random.sample(inferred_sequence.elts, inferred_length.value)
    except ValueError as exc:
        raise UseInferenceDefault from exc

    new_node = nodes.List(
        lineno=node.lineno,
        col_offset=node.col_offset,
        parent=node.scope(),
        end_lineno=node.end_lineno,
        end_col_offset=node.end_col_offset,
    )
    new_elts = [
        _clone_node_with_lineno(elt, parent=new_node, lineno=new_node.lineno)
        for elt in elts
    ]
    new_node.postinit(new_elts)
    return iter((new_node,))
97
98
def _looks_like_random_sample(node) -> bool:
    """Return True when the call's callee is spelled ``sample``.

    Both attribute access (``x.sample(...)``) and a bare name
    (``sample(...)``) qualify. Any other callee expression does not.
    """
    callee = node.func
    if isinstance(callee, nodes.Attribute):
        called_name = callee.attrname
    elif isinstance(callee, nodes.Name):
        called_name = callee.name
    else:
        return False
    return called_name == "sample"
106
107
def register(manager: AstroidManager) -> None:
    """Attach the ``sample`` inference tip to ``Call`` nodes in *manager*.

    The tip only applies to calls accepted by ``_looks_like_random_sample``.
    """
    sample_tip = inference_tip(infer_random_sample)
    manager.register_transform(nodes.Call, sample_tip, _looks_like_random_sample)