1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5from __future__ import annotations
6
7import random
8
9from astroid.context import InferenceContext
10from astroid.exceptions import UseInferenceDefault
11from astroid.inference_tip import inference_tip
12from astroid.manager import AstroidManager
13from astroid.nodes.node_classes import (
14 Attribute,
15 Call,
16 Const,
17 EvaluatedObject,
18 List,
19 Name,
20 Set,
21 Tuple,
22)
23from astroid.util import safe_infer
24
# Literal container node types accepted as the population argument of
# ``sample``; their ``elts`` list is what the inference draws from.
ACCEPTED_ITERABLES_FOR_SAMPLE = (
    List,
    Set,
    Tuple,
)
26
27
def _clone_node_with_lineno(node, parent, lineno):
    """Build a shallow copy of *node* attached to *parent* at line *lineno*.

    An ``EvaluatedObject`` is unwrapped first, so the copy is made from the
    node stored in its ``original`` attribute. The copy keeps the source
    node's ``col_offset``, ``end_lineno`` and ``end_col_offset``; only
    ``lineno`` and ``parent`` are replaced.

    Values named by ``_other_fields`` are forwarded to the constructor.
    Values named by ``_astroid_fields`` are forwarded unchanged to
    ``postinit`` (they are not cloned themselves). ``postinit`` is called
    only when the node defines it and has at least one such field.
    """
    source = node.original if isinstance(node, EvaluatedObject) else node
    node_class = source.__class__
    plain_fields = source._other_fields
    child_fields = source._astroid_fields

    constructor_kwargs = dict(
        lineno=lineno,
        col_offset=source.col_offset,
        parent=parent,
        end_lineno=source.end_lineno,
        end_col_offset=source.end_col_offset,
    )
    # Read child values up front so a missing attribute fails before the
    # clone is built.
    child_values = {field: getattr(source, field) for field in child_fields}
    for field in plain_fields:
        constructor_kwargs[field] = getattr(source, field)

    clone = node_class(**constructor_kwargs)
    if child_fields and hasattr(source, "postinit"):
        clone.postinit(**child_values)
    return clone
48
49
def _sample_call_arguments(node):
    """Return the ``(population, k)`` argument nodes of a ``sample(...)`` call.

    The arguments are bound the way ``sample(population, k, *, counts=None)``
    would bind them. Either argument may be given positionally or by keyword.

    Raises ``UseInferenceDefault`` in any of these cases:
    - a keyword other than ``population``/``k`` is present (for instance
      ``counts=``, which changes the effective population, or an unpacked
      ``**kwargs``);
    - an argument is supplied twice;
    - there are too many or too few arguments.
    """
    parameter_names = ("population", "k")
    if len(node.args) > len(parameter_names):
        raise UseInferenceDefault
    bound = dict(zip(parameter_names, node.args))
    for keyword in node.keywords or ():
        if keyword.arg not in parameter_names or keyword.arg in bound:
            raise UseInferenceDefault
        bound[keyword.arg] = keyword.value
    if len(bound) != len(parameter_names):
        raise UseInferenceDefault
    return bound["population"], bound["k"]


def infer_random_sample(node, context: InferenceContext | None = None):
    """Infer ``sample(population, k)`` as a new ``List`` node.

    The call is evaluated statically when both of these hold:
    - ``k`` infers to a ``Const`` holding an ``int``;
    - ``population`` infers to one of ``ACCEPTED_ITERABLES_FOR_SAMPLE``.

    ``random.sample`` is then run over the population's ``elts``. The
    result is a ``List`` positioned like the call and parented to the
    call's scope. It contains clones of the chosen elements, each placed at
    the call's line.

    Returns an iterator yielding that single ``List`` node.

    Raises ``UseInferenceDefault`` when the call shape or the inferred
    arguments cannot be evaluated statically, so that normal inference is
    used instead.
    """
    population_node, length_node = _sample_call_arguments(node)

    inferred_length = safe_infer(length_node, context=context)
    if not isinstance(inferred_length, Const):
        raise UseInferenceDefault
    if not isinstance(inferred_length.value, int):
        raise UseInferenceDefault

    # ``isinstance`` also rejects ``None`` returned by ``safe_infer``.
    inferred_sequence = safe_infer(population_node, context=context)
    if not isinstance(inferred_sequence, ACCEPTED_ITERABLES_FOR_SAMPLE):
        raise UseInferenceDefault

    try:
        # random.sample raises ValueError for a negative k or one larger
        # than the population; such a call has no static result.
        elts = random.sample(inferred_sequence.elts, inferred_length.value)
    except ValueError as exc:
        raise UseInferenceDefault from exc

    new_node = List(
        lineno=node.lineno,
        col_offset=node.col_offset,
        parent=node.scope(),
        end_lineno=node.end_lineno,
        end_col_offset=node.end_col_offset,
    )
    new_elts = [
        _clone_node_with_lineno(elt, parent=new_node, lineno=new_node.lineno)
        for elt in elts
    ]
    new_node.postinit(new_elts)
    return iter((new_node,))
89
90
def _looks_like_random_sample(node) -> bool:
    """Cheap pre-filter: does this ``Call`` invoke something named ``sample``?

    Both ``<expr>.sample(...)`` and a bare ``sample(...)`` qualify. The
    callee is not resolved here, so any callable with that name passes.
    """
    callee = node.func
    if isinstance(callee, Name):
        called_name = callee.name
    elif isinstance(callee, Attribute):
        called_name = callee.attrname
    else:
        return False
    return called_name == "sample"
98
99
def register(manager: AstroidManager) -> None:
    """Attach the ``sample`` inference tip to ``Call`` nodes via *manager*.

    ``_looks_like_random_sample`` decides which ``Call`` nodes receive the
    tip.
    """
    sample_tip = inference_tip(infer_random_sample)
    manager.register_transform(Call, sample_tip, _looks_like_random_sample)