Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/libcst/testing/utils.py: 52%
75 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:43 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:43 +0000
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5# pyre-unsafe
7import inspect
8import re
9from functools import wraps
10from typing import (
11 Any,
12 Callable,
13 Dict,
14 Iterable,
15 List,
16 Mapping,
17 Optional,
18 Sequence,
19 Tuple,
20 TypeVar,
21 Union,
22)
23from unittest import TestCase
# Attribute under which @data_provider stores its test cases on a test method.
DATA_PROVIDER_DATA_ATTR_NAME = "__data_provider_data"
# Prefixed to mapping keys when building the names of generated tests.
DATA_PROVIDER_DESCRIPTION_PREFIX = "_data_provider_"
# Attribute holding the maximum number of tests a provider method may generate.
PROVIDER_TEST_LIMIT_ATTR_NAME = "__provider_test_limit"
# Limit used when @data_provider is called without an explicit test_limit.
DEFAULT_TEST_LIMIT = 256

# Generic type variable (e.g. for none_throws).
T = TypeVar("T")
def none_throws(value: Optional[T], message: str = "Unexpected None value") -> T:
    """Return ``value`` unchanged, failing loudly if it is ``None``.

    Narrows ``Optional[T]`` to ``T`` for type checkers.

    Args:
        value: The possibly-``None`` value to check.
        message: Error message used when ``value`` is ``None``.

    Returns:
        ``value`` itself.

    Raises:
        AssertionError: If ``value`` is ``None``.
    """
    # Raise explicitly instead of using `assert`: assert statements are removed
    # under `python -O`, which would let None escape typed as T.
    if value is None:
        raise AssertionError(message)
    return value
def update_test_limit(test_method: Any, test_limit: int) -> None:
    """Record ``test_limit`` on ``test_method``, keeping the largest value seen.

    Provider decorators can be applied to the same method more than once, so an
    already-recorded limit is only ever raised, never lowered.

    Args:
        test_method: The test function being decorated; the limit is stored on
            it under ``PROVIDER_TEST_LIMIT_ATTR_NAME``.
        test_limit: The limit requested by the current decorator.
    """
    # With nothing recorded yet, fall back to the requested limit so that the
    # max() below simply yields test_limit.
    limit_so_far = getattr(test_method, PROVIDER_TEST_LIMIT_ATTR_NAME, test_limit)
    new_limit = max(limit_so_far, test_limit)
    setattr(test_method, PROVIDER_TEST_LIMIT_ATTR_NAME, new_limit)
def try_get_provider_attr(
    member_name: str, member: Any, attr_name: str
) -> Optional[Any]:
    """Fetch ``attr_name`` from ``member`` if it is a test function.

    Only plain functions whose name starts with ``"test"`` are considered.

    Args:
        member_name: Name under which ``member`` appears in the class namespace.
        member: The namespace value to inspect.
        attr_name: Attribute to read from ``member``.

    Returns:
        The attribute value, or ``None`` when ``member`` is not a ``test*``
        function or does not have the attribute.
    """
    is_test_function = inspect.isfunction(member) and member_name.startswith("test")
    if not is_test_function:
        return None
    return getattr(member, attr_name, None)
def populate_data_provider_tests(dct: Dict[str, Any]) -> None:
    """Expand every data-provider test in ``dct`` into one test per data case.

    Each ``test*`` function carrying provider data (stored under
    ``DATA_PROVIDER_DATA_ATTR_NAME``) is replaced by generated methods named
    ``{member_name}_{description}``. For mapping data, the description is
    ``DATA_PROVIDER_DESCRIPTION_PREFIX`` followed by the key. For any other
    iterable, it is the case's position.

    Each generated method calls the original one with the case unpacked. A
    mapping case is passed as keyword arguments; anything else is passed as
    positional arguments.

    Args:
        dct: The class namespace being built; it is mutated in place.

    Raises:
        AssertionError: If a case description contains characters other than
            ``[a-zA-Z0-9_]``.
        ValueError: If a decorated test has no data cases at all.
    """
    test_methods_to_add: Dict[str, Callable] = {}
    test_methods_to_remove: List[str] = []
    for member_name, member in dct.items():
        provider_data = try_get_provider_attr(
            member_name, member, DATA_PROVIDER_DATA_ATTR_NAME
        )
        if provider_data is None:
            continue

        is_mapping = isinstance(provider_data, Mapping)
        cases = provider_data.items() if is_mapping else enumerate(provider_data)
        # Count per member: test_methods_to_add accumulates across all members,
        # so checking it for emptiness would only ever catch the first one.
        num_generated = 0
        for description, data in cases:
            if is_mapping:
                description = f"{DATA_PROVIDER_DESCRIPTION_PREFIX}{description}"

            # Raise explicitly rather than `assert` so the check survives `python -O`.
            if not re.fullmatch(r"[a-zA-Z0-9_]+", str(description)):
                raise AssertionError(
                    f"Testcase description must be a valid python identifier: '{description}'"
                )

            # `data` and `member` are bound as defaults so each generated test
            # captures this iteration's values instead of the loop's last ones.
            @wraps(member)
            def new_test(
                self: object,
                data: Iterable[object] = data,
                member: Callable[..., object] = member,
            ) -> object:
                if isinstance(data, Mapping):
                    return member(self, **data)
                else:
                    return member(self, *data)

            name = f"{member_name}_{description}"
            new_test.__name__ = name
            test_methods_to_add[name] = new_test
            num_generated += 1

        if num_generated == 0:
            raise ValueError(
                f"No data_provider tests were created for {member_name}! Please double check your data."
            )
        test_methods_to_remove.append(member_name)
    dct.update(test_methods_to_add)

    # Remove the original (template) methods now that their expansions exist.
    for test_name in test_methods_to_remove:
        del dct[test_name]
def validate_provider_tests(dct: Dict[str, Any]) -> None:
    """Replace provider tests that would generate more tests than allowed.

    For every ``test*`` function in ``dct`` that has a recorded test limit, the
    number of tests it would generate is taken to be the length of its provider
    data. That count is 1 when the data is missing or empty. When the count
    exceeds the limit, the function is replaced by a stub of the same name. The
    stub raises ``AssertionError`` explaining the limit when it runs.

    Args:
        dct: The class namespace being built; it is mutated in place.
    """
    replacements: Dict[str, Callable[..., None]] = {}

    for test_name, test_fn in dct.items():
        limit = try_get_provider_attr(
            test_name, test_fn, PROVIDER_TEST_LIMIT_ATTR_NAME
        )
        if limit is None:
            continue

        cases = try_get_provider_attr(
            test_name, test_fn, DATA_PROVIDER_DATA_ATTR_NAME
        )
        generated = len(cases) if cases else 1
        if generated <= limit:
            continue

        # Deliberately not decorated with functools.wraps: that would copy the
        # original's __dict__, including its provider data, and the stub would
        # then be expanded like a normal provider test.
        def test_replacement(
            self: Any,
            member_name: Any = test_name,
            num_tests: Any = generated,
            test_limit: Any = limit,
        ) -> None:
            raise AssertionError(
                f"{member_name} generated {num_tests} tests but the limit is "
                f"{test_limit}. You can increase the number of "
                "allowed tests by specifying test_limit, but please "
                "consider whether you really need to test all of "
                "these combinations."
            )

        test_replacement.__name__ = test_name
        replacements[test_name] = test_replacement

    dct.update(replacements)
# One test case: a sequence of positional arguments or a mapping of keyword
# arguments for the decorated test method.
TestCaseType = Union[Sequence[object], Mapping[str, object]]
# Accepted @data_provider input: named cases (a mapping) or unnamed cases (any
# iterable). Sequence[TestCaseType] is not used because some callers pass in a
# Generator[TestCaseType].
StaticDataType = Union[Iterable[TestCaseType], Mapping[str, TestCaseType]]
def data_provider(
    static_data: StaticDataType, *, test_limit: int = DEFAULT_TEST_LIMIT
) -> Callable[[Callable], Callable]:
    """Mark a test method to be expanded into one test per data case.

    The decorator only attaches metadata to the method. The expansion happens
    when a class whose metaclass is ``BaseTestMeta`` (e.g. ``UnitTest``) is
    created.

    Args:
        static_data: The test cases. A mapping gives named cases, and its keys
            become part of the generated test names. Any other iterable gives
            cases numbered by position. Each case is a sequence of positional
            arguments or a dict of keyword arguments for the test method.
        test_limit: Maximum number of cases allowed. Above it, the test is
            replaced by a failing stub. When several provider decorators are
            stacked on one method, the largest limit is kept.

    Returns:
        A decorator that records ``static_data`` and ``test_limit`` on the test
        method and returns the same method object.
    """
    # The data is iterated more than once (limit validation, then expansion),
    # so one-shot iterables such as generators are materialized up front.
    if isinstance(static_data, Mapping):
        if not isinstance(static_data, dict):
            # Copy non-dict mappings (e.g. MappingProxyType) into a dict:
            # list() would keep only the keys and silently drop every case.
            static_data = dict(static_data)
    elif not isinstance(static_data, (list, tuple)):
        static_data = list(static_data)

    def test_decorator(test_method: Callable) -> Callable:
        update_test_limit(test_method, test_limit)

        setattr(test_method, DATA_PROVIDER_DATA_ATTR_NAME, static_data)
        return test_method

    return test_decorator
class BaseTestMeta(type):
    """Metaclass that expands data-provider tests when a test class is created."""

    def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> object:
        # Order matters: over-limit tests are first swapped for a failing stub
        # that carries no provider data, so the expansion step leaves it alone.
        validate_provider_tests(dct)
        populate_data_provider_tests(dct)
        # Pass type.__new__ a plain-dict copy of the rewritten namespace.
        namespace = dict(dct)
        return super().__new__(mcs, name, bases, namespace)
class UnitTest(TestCase, metaclass=BaseTestMeta):
    """``unittest.TestCase`` base class with ``@data_provider`` support.

    Subclass this instead of ``TestCase`` so that provider tests are checked
    against their limits and expanded when the subclass is created.
    """