Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/cattrs/disambiguators.py: 20%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

90 statements  

1"""Utilities for union (sum type) disambiguation.""" 

2 

3from __future__ import annotations 

4 

5from collections import defaultdict 

6from collections.abc import Mapping 

7from dataclasses import MISSING 

8from functools import reduce 

9from operator import or_ 

10from typing import TYPE_CHECKING, Any, Callable, Literal, Union 

11 

12from attrs import NOTHING, Attribute, AttrsInstance 

13 

14from ._compat import ( 

15 NoneType, 

16 adapted_fields, 

17 fields_dict, 

18 get_args, 

19 get_origin, 

20 has, 

21 is_literal, 

22 is_union_type, 

23) 

24from .gen import AttributeOverride 

25 

26if TYPE_CHECKING: 

27 from .converters import BaseConverter 

28 

# Public names of this module (what a star-import of it exposes).
__all__ = ["create_default_dis_func", "is_supported_union"]

30 

31 

def is_supported_union(typ: Any) -> bool:
    """Whether the type is a union of attrs classes or dataclasses.

    ``None`` members are also accepted. A parametrized generic member is
    checked through its origin class.
    """
    if not is_union_type(typ):
        return False
    for member in typ.__args__:
        if member is NoneType:
            continue
        if not has(get_origin(member) or member):
            return False
    return True

37 

38 

def create_default_dis_func(
    converter: BaseConverter,
    *classes: type[AttrsInstance],
    use_literals: bool = True,
    overrides: (
        dict[str, AttributeOverride] | Literal["from_converter"]
    ) = "from_converter",
) -> Callable[[Mapping[Any, Any]], type[Any] | None]:
    """Given attrs classes or dataclasses, generate a disambiguation function.

    The function is based on unique fields without defaults or unique values.

    Two strategies are tried, in order:

    1. If ``use_literals`` is set: find a ``Literal``-annotated field present
       in every class under the same serialized key (its name after any
       rename override). The field whose largest group of classes sharing one
       literal value is smallest is chosen. It is only used if it separates
       at least some classes. The generated function then looks up the value
       under that key and returns the class (or ``Union`` of classes) mapped
       to it.
    2. Otherwise, pick for each class a serialized key belonging to a field
       without a default (or default factory) that no not-yet-processed class
       has. At most one class may lack such a key; it becomes the fallback
       returned when no picked key is present.

    :param converter: Converter whose structure hooks provide the per-class
        overrides when ``overrides`` is ``"from_converter"``.
    :param classes: The attrs classes or dataclasses to disambiguate.
    :param use_literals: Whether to try using fields annotated as literals for
        disambiguation.
    :param overrides: Attribute overrides to apply.
    :return: A function mapping input data to the class to structure it as.
        It raises ``ValueError`` for non-mapping input. In the literal
        strategy, a missing key or an unknown value surfaces as a
        ``KeyError`` from the lookup. In the unique-key strategy without a
        fallback, it raises ``ValueError`` when no picked key is present.
    :raises ValueError: If fewer than two classes are given.
    :raises TypeError: If more than one class has no usable unique
        attribute without a default.

    .. versionchanged:: 24.1.0
        Dataclasses are now supported.
    """
    if len(classes) < 2:
        raise ValueError("At least two classes required.")

    if overrides == "from_converter":
        overrides = [
            getattr(converter.get_structure_hook(c), "overrides", {}) for c in classes
        ]
    else:
        overrides = [overrides for _ in classes]

    # first, attempt for unique values
    if use_literals:
        # requirements for a discriminator field:
        # (... TODO: a single fallback is OK)
        #  - it must always be enumerated
        #
        # Per class, map the *serialized* key of every Literal-annotated field
        # (the name after any rename override, i.e. the key the input data
        # carries) to that field's type. The type is the same object that
        # passed `is_literal`, so `get_args` below reads the checked literal.
        cls_candidates: list[dict[str, Any]] = [
            {
                _overriden_name(at, override.get(at.name)): at.type
                for at in adapted_fields(get_origin(cl) or cl)
                if is_literal(at.type)
            }
            for cl, override in zip(classes, overrides)
        ]

        # literal keys common to all members (computed without mutating the
        # per-class candidate dicts)
        discriminators = set(cls_candidates[0]).intersection(*cls_candidates[1:])

        best_result = None
        best_discriminator = None
        best_size = 0
        # Sorted so the choice among equally good discriminators does not
        # depend on per-process string hash randomization.
        for discriminator in sorted(discriminators):
            # maps Literal values (strings, ints...) to classes
            mapping = defaultdict(list)
            for cl, candidates in zip(classes, cls_candidates):
                for key in get_args(candidates[discriminator]):
                    mapping[key].append(cl)

            # Largest number of classes sharing a single literal value;
            # smaller is better (1 means every value identifies one class).
            size = max(len(v) for v in mapping.values())
            if best_result is None or size <= best_size:
                best_result = mapping
                best_discriminator = discriminator
                best_size = size

        # A discriminator whose single value covers every class is useless.
        if best_result and best_discriminator and best_size != len(classes):
            final_mapping = {
                k: v[0] if len(v) == 1 else Union[tuple(v)]
                for k, v in best_result.items()
            }

            def dis_func(data: Mapping[Any, Any]) -> type | None:
                if not isinstance(data, Mapping):
                    raise ValueError("Only input mappings are supported.")
                return final_mapping[data[best_discriminator]]

            return dis_func

    # next, attempt for unique keys

    # NOTE: This could just as well work with just field availability and not
    # uniqueness, returning Unions ... it doesn't do that right now.
    cls_and_attrs = [
        (cl, *_usable_attribute_names(cl, override))
        for cl, override in zip(classes, overrides)
    ]
    # For each class, attempt to generate a single unique required field.
    uniq_attrs_dict: dict[str, type] = {}

    # We start from classes with the largest number of unique fields
    # so we can do easy picks first, making later picks easier.
    cls_and_attrs.sort(key=lambda c_a: len(c_a[1]), reverse=True)

    fallback = None  # If none match, try this.

    for cl, cl_reqs, back_map in cls_and_attrs:
        # We do not have to consider classes we've already processed, since
        # they will have been eliminated by the match dictionary already.
        other_classes = [
            c_and_a
            for c_and_a in cls_and_attrs
            if c_and_a[0] is not cl and c_and_a[0] not in uniq_attrs_dict.values()
        ]
        other_reqs = reduce(or_, (c_a[1] for c_a in other_classes), set())
        uniq = cl_reqs - other_reqs

        # We want a unique attribute with no default. Candidates are tried in
        # sorted order so the pick is reproducible across interpreter runs.
        cl_fields = fields_dict(get_origin(cl) or cl)
        for maybe_renamed_attr_name in sorted(uniq):
            orig_name = back_map[maybe_renamed_attr_name]
            if _has_no_default(cl_fields[orig_name]):
                break
        else:
            # No unique required field: allow exactly one such class as the
            # fallback; a second one makes the union ambiguous.
            if fallback is None:
                fallback = cl
                continue
            raise TypeError(f"{cl} has no usable non-default attributes")
        uniq_attrs_dict[maybe_renamed_attr_name] = cl

    def dis_func(data: Mapping[Any, Any]) -> type[AttrsInstance] | None:
        if not isinstance(data, Mapping):
            raise ValueError("Only input mappings are supported")
        # Keys are checked in the order classes were processed above.
        for k, v in uniq_attrs_dict.items():
            if k in data:
                return v
        if fallback is None:
            raise ValueError("Couldn't disambiguate")
        return fallback

    return dis_func


def _has_no_default(field: Any) -> bool:
    """Whether an attrs or dataclass field has neither a default nor a factory.

    A field with a default (attrs ``NOTHING`` / dataclasses ``MISSING`` mark
    "no default") or with a dataclass ``default_factory`` may be absent from
    the input, so it cannot serve as a required disambiguation key.
    Identity checks are used so unusual default values are never compared
    with ``==``.
    """
    default = field.default
    if default is not NOTHING and default is not MISSING:
        return False
    # attrs attributes have no `default_factory`; treat that as "no factory".
    return getattr(field, "default_factory", MISSING) is MISSING

# Alternative name bound to the very same function object as
# `create_default_dis_func` (presumably kept for backwards compatibility).
create_uniq_field_dis_func = create_default_dis_func

187 

188 

def _overriden_name(at: Attribute, override: AttributeOverride | None) -> str:
    """Return the key under which ``at`` appears in input data.

    This is ``override.rename`` when an override with a rename is given,
    otherwise the attribute's own name.
    """
    renamed = None if override is None else override.rename
    return at.name if renamed is None else renamed

193 

194 

def _usable_attribute_names(
    cl: type[Any], overrides: dict[str, AttributeOverride]
) -> tuple[set[str], dict[str, str]]:
    """Return renamed fields and a mapping to original field names.

    Each field of ``cl`` (or of its origin, for parametrized generics) is
    keyed by its name after applying any rename in ``overrides``. If two
    fields map to the same key, the later field wins.
    """
    back_map = {
        _overriden_name(field, overrides.get(field.name)): field.name
        for field in adapted_fields(get_origin(cl) or cl)
    }
    return set(back_map), back_map