Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/libcst/testing/utils.py: 52%

75 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:43 +0000

1# Copyright (c) Meta Platforms, Inc. and affiliates. 

2# 

3# This source code is licensed under the MIT license found in the 

4# LICENSE file in the root directory of this source tree. 

5# pyre-unsafe 

6 

7import inspect 

8import re 

9from functools import wraps 

10from typing import ( 

11 Any, 

12 Callable, 

13 Dict, 

14 Iterable, 

15 List, 

16 Mapping, 

17 Optional, 

18 Sequence, 

19 Tuple, 

20 TypeVar, 

21 Union, 

22) 

23from unittest import TestCase 

24 

# Attribute names under which provider metadata is stored (via setattr /
# getattr) on decorated test functions.
DATA_PROVIDER_DATA_ATTR_NAME = "__data_provider_data"
PROVIDER_TEST_LIMIT_ATTR_NAME = "__provider_test_limit"

# Inserted between the method name and a dict key when naming generated tests.
DATA_PROVIDER_DESCRIPTION_PREFIX = "_data_provider_"

# Default value of data_provider's ``test_limit`` keyword.
DEFAULT_TEST_LIMIT = 256


T = TypeVar("T")

32 

33 

def none_throws(value: Optional[T], message: str = "Unexpected None value") -> T:
    """Return ``value`` unchanged after checking that it is not ``None``.

    Args:
        value: The optional value to unwrap.
        message: Message carried by the error raised when ``value`` is None.

    Returns:
        ``value``, narrowed from ``Optional[T]`` to ``T``.

    Raises:
        AssertionError: If ``value`` is None. It is raised explicitly rather
            than through an ``assert`` statement so that the check is not
            stripped when Python runs with ``-O``.
    """
    if value is None:
        raise AssertionError(message)
    return value

37 

38 

def update_test_limit(test_method: Any, test_limit: int) -> None:
    """Store ``test_limit`` on ``test_method``, keeping the largest limit seen.

    A provider decorator may be applied to the same method more than once, so
    the stored value is the maximum of any previously stored limit and
    ``test_limit``. When nothing has been stored yet, ``test_limit`` is used.
    """
    stored_limit = getattr(test_method, PROVIDER_TEST_LIMIT_ATTR_NAME, test_limit)
    new_limit = max(stored_limit, test_limit)
    setattr(test_method, PROVIDER_TEST_LIMIT_ATTR_NAME, new_limit)

49 

50 

def try_get_provider_attr(
    member_name: str, member: Any, attr_name: str
) -> Optional[Any]:
    """Read provider metadata ``attr_name`` from a test function, if present.

    Only plain Python functions whose name starts with ``"test"`` qualify.
    For any other member, or when the attribute is absent, ``None`` is
    returned.
    """
    if not inspect.isfunction(member):
        return None
    if not member_name.startswith("test"):
        return None
    return getattr(member, attr_name, None)

57 

58 

def populate_data_provider_tests(dct: Dict[str, Any]) -> None:
    """Expand every data-provider test in a class namespace into concrete tests.

    For each function member named ``test*`` that carries provider data, one
    new method is generated per data entry, and the original method is
    removed from ``dct``.

    Generated names:
      * dict data: ``<member>_<DATA_PROVIDER_DESCRIPTION_PREFIX><key>``
      * any other iterable: ``<member>_<index>`` (from ``enumerate``)

    When a generated test runs, a mapping entry is passed to the original
    method as keyword arguments. Any other entry is unpacked as positional
    arguments.

    Args:
        dct: The class namespace; modified in place.

    Raises:
        AssertionError: If a test description contains characters outside
            ``[a-zA-Z0-9_]``.
        ValueError: If a provider-decorated method yields no test cases.
    """
    import collections.abc

    test_methods_to_add: Dict[str, Callable] = {}
    test_methods_to_remove: List[str] = []
    for member_name, member in dct.items():
        provider_data = try_get_provider_attr(
            member_name, member, DATA_PROVIDER_DATA_ATTR_NAME
        )
        if provider_data is None:
            continue

        # Count tests generated for *this* member only. Checking the shared
        # ``test_methods_to_add`` dict would miss an empty provider that
        # comes after a non-empty one.
        generated_for_member = 0
        for description, data in (
            provider_data.items()
            if isinstance(provider_data, dict)
            else enumerate(provider_data)
        ):
            if isinstance(provider_data, dict):
                description = f"{DATA_PROVIDER_DESCRIPTION_PREFIX}{description}"

            assert re.fullmatch(
                r"[a-zA-Z0-9_]+", str(description)
            ), f"Testcase description must be a valid python identifier: '{description}'"

            # ``data`` and ``member`` are bound as defaults so each generated
            # test captures this iteration's values, not the last ones.
            @wraps(member)
            def new_test(
                self: object,
                data: Iterable[object] = data,
                member: Callable[..., object] = member,
            ) -> object:
                # Any Mapping (not only dict) supplies keyword arguments, as
                # declared by TestCaseType; splatting it with * would pass
                # its keys positionally.
                if isinstance(data, collections.abc.Mapping):
                    return member(self, **data)
                else:
                    return member(self, *data)

            name = f"{member_name}_{description}"
            new_test.__name__ = name
            test_methods_to_add[name] = new_test
            generated_for_member += 1

        if not generated_for_member:
            raise ValueError(
                f"No data_provider tests were created for {member_name}! Please double check your data."
            )
        test_methods_to_remove.append(member_name)

    # Mutate ``dct`` only after iterating it: adding keys while iterating
    # ``dct.items()`` would raise RuntimeError.
    dct.update(test_methods_to_add)

    # Remove all old methods
    for test_name in test_methods_to_remove:
        del dct[test_name]

103 

104 

def validate_provider_tests(dct: Dict[str, Any]) -> None:
    """Replace provider tests that would generate more tests than allowed.

    For every function member named ``test*`` with a stored test limit, the
    number of tests it would generate is the length of its provider data, or
    1 when that data is missing or empty. A method over its limit is replaced
    in ``dct`` by a stub that raises ``AssertionError`` when run.

    Args:
        dct: The class namespace; modified in place.
    """
    replacements = {}

    for member_name, member in dct.items():
        limit = try_get_provider_attr(
            member_name, member, PROVIDER_TEST_LIMIT_ATTR_NAME
        )
        if limit is None:
            continue

        provider_data = try_get_provider_attr(
            member_name, member, DATA_PROVIDER_DATA_ATTR_NAME
        )
        generated = len(provider_data) if provider_data else 1

        if generated > limit:
            # The stub is deliberately not built with wraps() and carries no
            # provider data, so it is not expanded like the original would be.
            def test_replacement(
                self: Any,
                member_name: Any = member_name,
                num_tests: Any = generated,
                test_limit: Any = limit,
            ) -> None:
                raise AssertionError(
                    f"{member_name} generated {num_tests} tests but the limit is "
                    f"{test_limit}. You can increase the number of "
                    "allowed tests by specifying test_limit, but please "
                    "consider whether you really need to test all of "
                    "these combinations."
                )

            test_replacement.__name__ = member_name
            replacements[member_name] = test_replacement

    dct.update(replacements)

140 

141 

# One test case: it supplies the decorated test's arguments, positionally for
# a sequence and by keyword for a dict.
TestCaseType = Union[
    Sequence[object],
    Mapping[str, object],
]

# All cases for one test: a mapping keyed by test description, or any
# iterable of cases. Iterable is used rather than Sequence[TestCaseType]
# because some clients pass a Generator of cases.
StaticDataType = Union[
    Iterable[TestCaseType],
    Mapping[str, TestCaseType],
]

145 

146 

def data_provider(
    static_data: StaticDataType, *, test_limit: int = DEFAULT_TEST_LIMIT
) -> Callable[[Callable], Callable]:
    """Mark a test method to be expanded into one test per data entry.

    The expansion happens when a class whose metaclass is ``BaseTestMeta`` is
    created.

    Args:
        static_data: The test cases. A mapping's keys become part of the
            generated test names; any other iterable is enumerated. A
            non-``dict`` mapping is copied into a ``dict``. Any other
            non-list, non-tuple iterable (e.g. a generator) is materialized
            into a list.
        test_limit: Maximum number of tests this method may generate. If
            limits are recorded more than once on the same method, the
            largest one is kept.

    Returns:
        A decorator that records ``static_data`` and ``test_limit`` on the
        test method and returns the same function object.
    """
    import collections.abc

    # The data is iterated more than once (limit validation, then expansion),
    # so single-pass iterables must be materialized.
    if isinstance(static_data, collections.abc.Mapping):
        # Downstream expansion only treats real dicts as keyed data. Passing
        # any other Mapping through list() would keep only its keys.
        if not isinstance(static_data, dict):
            static_data = dict(static_data)
    elif not isinstance(static_data, (list, tuple)):
        static_data = list(static_data)

    def test_decorator(test_method: Callable) -> Callable:
        update_test_limit(test_method, test_limit)

        setattr(test_method, DATA_PROVIDER_DATA_ATTR_NAME, static_data)
        return test_method

    return test_decorator

163 

164 

class BaseTestMeta(type):
    """Metaclass that checks test limits and expands provider tests.

    Both steps run on the class namespace before the class object is built.
    """

    def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> object:
        # Order matters: limit validation first swaps any over-limit method
        # for a failing stub that carries no provider data, so the expansion
        # step leaves that stub alone.
        validate_provider_tests(dct)
        populate_data_provider_tests(dct)
        namespace = dict(dct)
        return super().__new__(mcs, name, bases, namespace)

170 

171 

class UnitTest(TestCase, metaclass=BaseTestMeta):
    """``unittest.TestCase`` base whose subclasses support ``data_provider``.

    Because the metaclass is ``BaseTestMeta``, decorated ``test*`` methods in
    subclasses are limit-checked and expanded when the subclass is created.
    """