1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5"""Astroid hooks for numpy.core.multiarray module."""
6
7import functools
8
9from astroid import nodes
10from astroid.brain.brain_numpy_utils import (
11 attribute_name_looks_like_numpy_member,
12 infer_numpy_attribute,
13 infer_numpy_name,
14 member_name_looks_like_numpy_member,
15)
16from astroid.brain.helpers import register_module_extender
17from astroid.builder import parse
18from astroid.inference_tip import inference_tip
19from astroid.manager import AstroidManager
20from astroid.nodes.node_classes import Attribute, Name
21
22
def numpy_core_multiarray_transform() -> nodes.Module:
    """Return a module parsed from stub definitions of ``inner`` and ``vdot``.

    ``register`` installs this function as the module extender for
    ``numpy.core.multiarray``. Both stubs return a ``numpy.ndarray([0, 0])``
    placeholder.
    """
    stub_source = """
    # different functions defined in multiarray.py
    def inner(a, b):
        return numpy.ndarray([0, 0])

    def vdot(a, b):
        return numpy.ndarray([0, 0])
    """
    return parse(stub_source)
34
35
# Source text of stub definitions for numpy functions, keyed by function name.
#
# register() binds this mapping as the first argument of infer_numpy_attribute
# and infer_numpy_name, and uses its keys as the set of names handed to the
# "looks like numpy member" predicates.
#
# Each stub returns a placeholder for the kind of value the function yields:
# a numpy.ndarray, a tuple holding one (unravel_index), True, or None.
# NOTE(review): presumably only the type of the returned placeholder matters
# to inference, not the [0, 0] / (0, 0) constructor arguments -- confirm
# against brain_numpy_utils.
METHODS_TO_BE_INFERRED = {
    "array": """def array(object, dtype=None, copy=True, order='K', subok=False, ndmin=0):
 return numpy.ndarray([0, 0])""",
    "dot": """def dot(a, b, out=None):
 return numpy.ndarray([0, 0])""",
    "empty_like": """def empty_like(a, dtype=None, order='K', subok=True):
 return numpy.ndarray((0, 0))""",
    "concatenate": """def concatenate(arrays, axis=None, out=None):
 return numpy.ndarray((0, 0))""",
    "where": """def where(condition, x=None, y=None):
 return numpy.ndarray([0, 0])""",
    "empty": """def empty(shape, dtype=float, order='C'):
 return numpy.ndarray([0, 0])""",
    "bincount": """def bincount(x, weights=None, minlength=0):
 return numpy.ndarray([0, 0])""",
    "busday_count": """def busday_count(
 begindates, enddates, weekmask='1111100', holidays=[], busdaycal=None, out=None
 ):
 return numpy.ndarray([0, 0])""",
    "busday_offset": """def busday_offset(
 dates, offsets, roll='raise', weekmask='1111100', holidays=None,
 busdaycal=None, out=None
 ):
 return numpy.ndarray([0, 0])""",
    # The next two stubs return a plain bool / None instead of an array.
    "can_cast": """def can_cast(from_, to, casting='safe'):
 return True""",
    "copyto": """def copyto(dst, src, casting='same_kind', where=True):
 return None""",
    "datetime_as_string": """def datetime_as_string(arr, unit=None, timezone='naive', casting='same_kind'):
 return numpy.ndarray([0, 0])""",
    "is_busday": """def is_busday(dates, weekmask='1111100', holidays=None, busdaycal=None, out=None):
 return numpy.ndarray([0, 0])""",
    "lexsort": """def lexsort(keys, axis=-1):
 return numpy.ndarray([0, 0])""",
    "may_share_memory": """def may_share_memory(a, b, max_work=None):
 return True""",
    # Not yet available because dtype is not yet present in those brains
    # "min_scalar_type": """def min_scalar_type(a):
    # return numpy.dtype('int16')""",
    "packbits": """def packbits(a, axis=None, bitorder='big'):
 return numpy.ndarray([0, 0])""",
    # Not yet available because dtype is not yet present in those brains
    # "result_type": """def result_type(*arrays_and_dtypes):
    # return numpy.dtype('int16')""",
    "shares_memory": """def shares_memory(a, b, max_work=None):
 return True""",
    "unpackbits": """def unpackbits(a, axis=None, count=None, bitorder='big'):
 return numpy.ndarray([0, 0])""",
    # Stub result is a one-element tuple wrapping the ndarray placeholder.
    "unravel_index": """def unravel_index(indices, shape, order='C'):
 return (numpy.ndarray([0, 0]),)""",
    "zeros": """def zeros(shape, dtype=float, order='C'):
 return numpy.ndarray([0, 0])""",
}
89
90
def register(manager: AstroidManager) -> None:
    """Install the ``numpy.core.multiarray`` brain hooks on *manager*.

    First registers ``numpy_core_multiarray_transform`` as the module
    extender for ``numpy.core.multiarray``. Then, for ``Attribute`` and
    ``Name`` nodes in that order, registers an inference tip bound to
    ``METHODS_TO_BE_INFERRED``, guarded by the matching "looks like numpy
    member" predicate bound to the set of stubbed names.
    """
    register_module_extender(
        manager, "numpy.core.multiarray", numpy_core_multiarray_transform
    )

    # Iterating the mapping yields its keys, i.e. the stubbed function names.
    stubbed_names = frozenset(METHODS_TO_BE_INFERRED)

    # (node class, inference function, predicate); Attribute is registered
    # before Name.
    hooks = (
        (Attribute, infer_numpy_attribute, attribute_name_looks_like_numpy_member),
        (Name, infer_numpy_name, member_name_looks_like_numpy_member),
    )
    for node_class, infer_function, looks_like_member in hooks:
        manager.register_transform(
            node_class,
            inference_tip(functools.partial(infer_function, METHODS_TO_BE_INFERRED)),
            functools.partial(looks_like_member, stubbed_names),
        )