1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5"""Hooks for nose library."""
6
7import re
8import textwrap
9
10from astroid.bases import BoundMethod
11from astroid.brain.helpers import register_module_extender
12from astroid.builder import AstroidBuilder
13from astroid.exceptions import InferenceError
14from astroid.manager import AstroidManager
15from astroid.nodes import List, Module
16
# Matches a single ASCII capital letter and captures it as group 1.
CAPITALS = re.compile("([A-Z])")


def _pep8(name, caps=CAPITALS):
    """Return *name* with every match of *caps* turned into "_" + lowercase.

    With the default pattern this converts a camelCase identifier such as
    ``assertEqual`` into its snake_case spelling ``assert_equal``.
    """

    def _underscore_lower(match):
        # Group 1 is the captured capital letter.
        return "_" + match.group(1).lower()

    return caps.sub(_underscore_lower, name)
22
23
def _nose_tools_functions():
    """Yield ``(name, bound method)`` pairs for ``unittest`` assert methods.

    A tiny module that instantiates a ``unittest.TestCase`` subclass is built
    and the instance is inferred. Every method of that instance whose name
    starts with ``assert`` and contains no underscore is yielded under its
    snake_case name, bound to the instance. ``assertEqual`` is additionally
    yielded as ``assert_equals``.

    Nothing is yielded when the instance cannot be inferred.
    """
    source = textwrap.dedent(
        """
    import unittest

    class Test(unittest.TestCase):
        pass
    a = Test()
    """
    )
    module = AstroidBuilder().string_build(source)
    try:
        case = next(module["a"].infer())
    except (InferenceError, StopIteration):
        # Inference failed or produced no result: there is nothing to export.
        return
    for method in case.methods():
        name = method.name
        if not name.startswith("assert") or "_" in name:
            continue
        yield _pep8(name), BoundMethod(method, case)
        if name == "assertEqual":
            # nose also exports assert_equals.
            yield "assert_equals", BoundMethod(method, case)
48
49
def _nose_tools_transform(node):
    """Install the generated assert helpers into the locals of *node*.

    Each name produced by ``_nose_tools_functions`` is bound to a
    single-element list holding the corresponding bound method, replacing any
    existing entry of the same name.
    """
    for exported_name, bound in _nose_tools_functions():
        node.locals[exported_name] = [bound]
53
54
def _nose_tools_trivial_transform():
    """Build the stub module used to extend ``nose.tools.trivial``.

    The stub starts out as ``__all__ = []``. Every helper produced by
    ``_nose_tools_functions`` is stored in it, and the value assigned to
    ``__all__`` is replaced by a ``List`` node created from ``"ok_"``,
    ``"eq_"`` and the helper names.
    """
    stub = AstroidBuilder().string_build("__all__ = []")
    exported = ["ok_", "eq_"]

    for name, bound in _nose_tools_functions():
        exported.append(name)
        stub[name] = bound

    # nose.tools fills __all__ by hand with .append, so the assigned value
    # is swapped for a node built from the collected names.
    # NOTE(review): the names are passed as List's first positional
    # argument - confirm that parameter is the element list.
    assignment = stub["__all__"].parent
    replacement = List(exported)
    replacement.parent = assignment
    assignment.value = replacement
    return stub
71
72
def register(manager: AstroidManager) -> None:
    """Register the nose hooks on *manager*.

    Two hooks are installed: a module extender for ``nose.tools.trivial`` and
    a ``Module`` transform that only applies to the module named
    ``nose.tools``.
    """

    def _is_nose_tools(module):
        return module.name == "nose.tools"

    register_module_extender(
        manager, "nose.tools.trivial", _nose_tools_trivial_transform
    )
    manager.register_transform(Module, _nose_tools_transform, _is_nose_tools)