Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/cattrs/disambiguators.py: 20%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

90 statements  

1"""Utilities for union (sum type) disambiguation.""" 

2 

3from __future__ import annotations 

4 

5from collections import defaultdict 

6from collections.abc import Mapping 

7from dataclasses import MISSING 

8from functools import reduce 

9from operator import or_ 

10from typing import TYPE_CHECKING, Any, Callable, Literal, Union, get_origin 

11 

12from attrs import NOTHING, Attribute, AttrsInstance 

13 

14from ._compat import ( 

15 NoneType, 

16 adapted_fields, 

17 fields_dict, 

18 get_args, 

19 has, 

20 is_literal, 

21 is_union_type, 

22) 

23from .gen import AttributeOverride 

24 

25if TYPE_CHECKING: 

26 from .converters import BaseConverter 

27 

# Names exported by a star-import of this module. The alias
# `create_uniq_field_dis_func` defined further down is not listed here.
__all__ = ["create_default_dis_func", "is_supported_union"]

29 

30 

def is_supported_union(typ: Any) -> bool:
    """Whether the type is a union of attrs classes or dataclasses.

    ``None`` members are permitted, and generic members (e.g. ``A[int]``) are
    checked through their origin class.
    """
    if not is_union_type(typ):
        return False
    for member in typ.__args__:
        if member is NoneType:
            continue
        if not has(get_origin(member) or member):
            return False
    return True

36 

37 

def create_default_dis_func(
    converter: BaseConverter,
    *classes: type[AttrsInstance],
    use_literals: bool = True,
    overrides: (
        dict[str, AttributeOverride] | Literal["from_converter"]
    ) = "from_converter",
) -> Callable[[Mapping[Any, Any]], type[Any] | None]:
    """Given attrs classes or dataclasses, generate a disambiguation function.

    The function is based on unique fields without defaults or unique values.
    Two strategies are tried, in order:

    1. If `use_literals` is true, find a data key that every class has and
       that is annotated with a ``Literal`` in each of them, then dispatch on
       its value. A value allowed by several classes maps to a ``Union`` of
       those classes.
    2. Otherwise, give each class a key that no other remaining class has and
       whose field has no default, then dispatch on which key is present. At
       most one class may lack such a key; it becomes the fallback returned
       when none of the unique keys is present.

    Keys are compared as they appear in the unstructured data, i.e. after
    applying ``rename`` overrides, for both strategies.

    :param converter: Used to look up each class's structure hook overrides
        when `overrides` is ``"from_converter"``.
    :param classes: The classes to disambiguate between; at least two.
    :param use_literals: Whether to try using fields annotated as literals for
        disambiguation.
    :param overrides: Attribute overrides to apply: either one mapping shared
        by all classes, or ``"from_converter"`` to read them from the
        converter's structure hooks.
    :raises ValueError: If fewer than two classes are given.
    :raises TypeError: If the unique-key strategy is used and more than one
        class has no unique key without a default.

    The returned function raises ``ValueError`` for non-mapping input. With
    the literal strategy it raises ``KeyError`` if the discriminator key is
    missing or holds an unknown value. With the unique-key strategy and no
    fallback, it raises ``ValueError`` if no unique key is present.

    .. versionchanged:: 24.1.0
        Dataclasses are now supported.
    """
    if len(classes) < 2:
        raise ValueError("At least two classes required.")

    if overrides == "from_converter":
        per_class_overrides = [
            getattr(converter.get_structure_hook(c), "overrides", {}) for c in classes
        ]
    else:
        per_class_overrides = [overrides for _ in classes]

    # First, attempt to dispatch on unique literal values.
    if use_literals:
        literal_dis_func = _literal_dis_func(classes, per_class_overrides)
        if literal_dis_func is not None:
            return literal_dis_func

    # Next, attempt to dispatch on unique keys.
    return _unique_key_dis_func(classes, per_class_overrides)


def _literal_dis_func(
    classes: tuple[type[Any], ...],
    overrides: list[dict[str, AttributeOverride]],
) -> Callable[[Mapping[Any, Any]], type[Any] | None] | None:
    """Build a dispatcher on a shared ``Literal``-typed key, if one helps.

    Returns ``None`` when the classes share no literal key, or when the best
    shared key has a value that every class allows (so it separates nothing).
    """
    # Requirements for a discriminator field:
    # (... TODO: a single fallback is OK)
    # - it must always be enumerated
    #
    # Per class: data key (rename applied) -> literal-typed attribute. Both the
    # `is_literal` check and the enumerated values are read from the same
    # adapted attribute. `fields_dict` may expose a different (e.g. unresolved
    # string) annotation for dataclasses.
    literal_fields = [
        {
            _overriden_name(at, override.get(at.name)): at
            for at in adapted_fields(get_origin(cl) or cl)
            if is_literal(at.type)
        }
        for cl, override in zip(classes, overrides)
    ]

    # Literal keys common to all members, sorted so ties below are broken the
    # same way on every run (set order depends on string hashing).
    common_keys = sorted(set(literal_fields[0]).intersection(*literal_fields[1:]))

    best_mapping = None
    best_key = None
    best_score = 0
    for key in common_keys:
        # Maps Literal values (strings, ints...) to the classes allowing them.
        mapping = defaultdict(list)
        for cl, cls_literal_fields in zip(classes, literal_fields):
            for value in get_args(cls_literal_fields[key].type):
                mapping[value].append(cl)
        if not mapping:
            continue
        # A key is better the fewer classes share any single value.
        score = max(len(members) for members in mapping.values())
        if best_mapping is None or score < best_score:
            best_mapping, best_key, best_score = mapping, key, score

    if best_mapping is None or best_score == len(classes):
        return None

    final_mapping = {
        value: members[0] if len(members) == 1 else Union[tuple(members)]
        for value, members in best_mapping.items()
    }
    discriminator = best_key

    def dis_func(data: Mapping[Any, Any]) -> type | None:
        if not isinstance(data, Mapping):
            raise ValueError("Only input mappings are supported.")
        return final_mapping[data[discriminator]]

    return dis_func


def _unique_key_dis_func(
    classes: tuple[type[Any], ...],
    overrides: list[dict[str, AttributeOverride]],
) -> Callable[[Mapping[Any, Any]], type[Any] | None]:
    """Build a dispatcher on the presence of a unique, required key per class.

    :raises TypeError: If more than one class has no such key.
    """
    # NOTE: This could just as well work with just field availability and not
    # uniqueness, returning Unions ... it doesn't do that right now.
    cls_and_attrs = [
        (cl, *_usable_attribute_names(cl, override))
        for cl, override in zip(classes, overrides)
    ]

    # Start with the classes that have the most keys, so the easy picks come
    # first. Every class that receives a key drops out of the comparison set,
    # which makes later picks easier.
    cls_and_attrs.sort(key=lambda c_a: len(c_a[1]), reverse=True)

    # Insertion order matters: the dispatcher checks keys in this order.
    uniq_attrs_dict: dict[str, type] = {}
    fallback = None  # If none match, try this.

    for cl, cl_keys, back_map in cls_and_attrs:
        # Classes that already received a key are checked before this one by
        # the dispatcher, so this class only has to differ from the rest.
        other_keys = reduce(
            or_,
            (
                keys
                for other, keys, _ in cls_and_attrs
                if other is not cl and other not in uniq_attrs_dict.values()
            ),
            set(),
        )

        # We want a unique key whose field has no default, so that it is
        # always present in data for this class.
        cl_fields = fields_dict(get_origin(cl) or cl)
        # Sorted for a deterministic pick when several keys qualify.
        for key in sorted(cl_keys - other_keys):
            if _is_required(cl_fields[back_map[key]]):
                uniq_attrs_dict[key] = cl
                break
        else:
            if fallback is not None:
                raise TypeError(f"{cl} has no usable non-default attributes")
            fallback = cl

    def dis_func(data: Mapping[Any, Any]) -> type[AttrsInstance] | None:
        if not isinstance(data, Mapping):
            raise ValueError("Only input mappings are supported")
        for key, matched_cls in uniq_attrs_dict.items():
            if key in data:
                return matched_cls
        if fallback is None:
            raise ValueError("Couldn't disambiguate")
        return fallback

    return dis_func


def _is_required(field: Any) -> bool:
    """Whether an attrs attribute or dataclass field has no default at all."""
    # Identity checks avoid calling a default value's own `__eq__`.
    # Dataclass fields declared with `default_factory` keep `default` set to
    # MISSING, so the factory has to be checked separately. attrs attributes
    # have no `default_factory` attribute; attrs factories live in `default`.
    return (
        field.default is NOTHING or field.default is MISSING
    ) and getattr(field, "default_factory", MISSING) is MISSING

183 

184 

# Alternate name bound to the same function as `create_default_dis_func`.
# It is not listed in `__all__`. NOTE(review): presumably kept for backwards
# compatibility; confirm before removing.
create_uniq_field_dis_func = create_default_dis_func

186 

187 

def _overriden_name(at: Attribute, override: AttributeOverride | None) -> str:
    """Return the key under which `at` appears in unstructured data.

    That is the override's ``rename`` when one is set, else the field name.
    """
    rename = None if override is None else override.rename
    return at.name if rename is None else rename

192 

193 

def _usable_attribute_names(
    cl: type[Any], overrides: dict[str, AttributeOverride]
) -> tuple[set[str], dict[str, str]]:
    """Return renamed fields and a mapping to original field names.

    The first element holds the data key of every field of `cl` (its
    ``rename`` override if set, else its name). The second maps each key back
    to the field name it came from; if two fields share a key, the later
    field wins.
    """
    back_map = {
        _overriden_name(at, overrides.get(at.name)): at.name
        for at in adapted_fields(get_origin(cl) or cl)
    }
    return set(back_map), back_map