1"""Utilities for union (sum type) disambiguation."""
2
3from __future__ import annotations
4
5from collections import defaultdict
6from collections.abc import Mapping
7from dataclasses import MISSING
8from functools import reduce
9from operator import or_
10from typing import TYPE_CHECKING, Any, Callable, Literal, Union, get_origin
11
12from attrs import NOTHING, Attribute, AttrsInstance
13
14from ._compat import (
15 NoneType,
16 adapted_fields,
17 fields_dict,
18 get_args,
19 has,
20 is_literal,
21 is_union_type,
22)
23from .gen import AttributeOverride
24
25if TYPE_CHECKING:
26 from .converters import BaseConverter
27
28__all__ = ["create_default_dis_func", "is_supported_union"]
29
30
31def is_supported_union(typ: Any) -> bool:
32 """Whether the type is a union of attrs classes or dataclasses."""
33 return is_union_type(typ) and all(
34 e is NoneType or has(get_origin(e) or e) for e in typ.__args__
35 )


def create_default_dis_func(
    converter: BaseConverter,
    *classes: type[AttrsInstance],
    use_literals: bool = True,
    overrides: (
        dict[str, AttributeOverride] | Literal["from_converter"]
    ) = "from_converter",
) -> Callable[[Mapping[Any, Any]], type[Any] | None]:
    """Given attrs classes or dataclasses, generate a disambiguation function.

    The function is based on unique required fields or unique Literal values.

    :param use_literals: Whether to try using fields annotated as literals for
        disambiguation.
    :param overrides: Attribute overrides to apply.
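
    As a sketch of typical use (``Apple`` and ``Pear`` are hypothetical attrs
    classes where only ``Apple`` has a required ``core_size`` field)::

        dis = create_default_dis_func(converter, Apple, Pear)
        dis({"core_size": 3})  # -> Apple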

    .. versionchanged:: 24.1.0
        Dataclasses are now supported.
    """
    if len(classes) < 2:
        raise ValueError("At least two classes required.")

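    # Normalize `overrides` into a per-class list aligned with `classes`:
    # either taken from each class's registered structure hook, or the single
    # explicit dict repeated for every class.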
    if overrides == "from_converter":
        overrides = [
            getattr(converter.get_structure_hook(c), "overrides", {}) for c in classes
        ]
    else:
        overrides = [overrides for _ in classes]

    # first, attempt to disambiguate on unique Literal values
    if use_literals:
        # requirements for a discriminator field:
        # - it must always be enumerated (a Literal) in every class
        #   (TODO: a single fallback would also be OK)
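        # Collect, for each class, the names of its fields annotated as Literal.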
        cls_candidates = [
            {
                at.name
                for at in adapted_fields(get_origin(cl) or cl)
                if is_literal(at.type)
            }
            for cl in classes
        ]

        # literal field names common to all members
        discriminators: set[str] = cls_candidates[0]
        for possible_discriminators in cls_candidates:
            discriminators &= possible_discriminators

        best_result = None
        best_discriminator = None
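        # Evaluate each candidate discriminator and keep the most selective
        # one: the one whose most widely shared Literal value still maps to
        # the fewest classes.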
        for discriminator in discriminators:
            # maps Literal values (strings, ints...) to classes
            mapping = defaultdict(list)

            for cl in classes:
                for key in get_args(
                    fields_dict(get_origin(cl) or cl)[discriminator].type
                ):
                    mapping[key].append(cl)

            if best_result is None or max(len(v) for v in mapping.values()) <= max(
                len(v) for v in best_result.values()
            ):
                best_result = mapping
                best_discriminator = discriminator

        if (
            best_result
            and best_discriminator
            and max(len(v) for v in best_result.values()) != len(classes)
        ):
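            # Map each Literal value to its class, or to a Union of the
            # classes that share that value.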
            final_mapping = {
                k: v[0] if len(v) == 1 else Union[tuple(v)]
                for k, v in best_result.items()
            }

            def dis_func(data: Mapping[Any, Any]) -> type | None:
                if not isinstance(data, Mapping):
                    raise ValueError("Only input mappings are supported.")
                return final_mapping[data[best_discriminator]]

            return dis_func

    # next, attempt to disambiguate on unique keys

    # NOTE: This could also work with mere field availability rather than
    # uniqueness (returning Unions), but it doesn't do that right now.
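    # Pair each class with its (possibly renamed) usable field names and a
    # mapping from those names back to the original attribute names.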
    cls_and_attrs = [
        (cl, *_usable_attribute_names(cl, override))
        for cl, override in zip(classes, overrides)
    ]
    # For each class, attempt to generate a single unique required field.
    uniq_attrs_dict: dict[str, type] = {}

    # We start from classes with the largest number of unique fields
    # so we can do easy picks first, making later picks easier.
    cls_and_attrs.sort(key=lambda c_a: len(c_a[1]), reverse=True)

    fallback = None  # If none match, try this.

    for cl, cl_reqs, back_map in cls_and_attrs:
        # We do not have to consider classes we've already processed, since
        # they will have been eliminated by the match dictionary already.
        other_classes = [
            c_and_a
            for c_and_a in cls_and_attrs
            if c_and_a[0] is not cl and c_and_a[0] not in uniq_attrs_dict.values()
        ]
        other_reqs = reduce(or_, (c_a[1] for c_a in other_classes), set())
        uniq = cl_reqs - other_reqs

        # We want a unique attribute with no default.
        cl_fields = fields_dict(get_origin(cl) or cl)
        for maybe_renamed_attr_name in uniq:
            orig_name = back_map[maybe_renamed_attr_name]
            if cl_fields[orig_name].default in (NOTHING, MISSING):
                break
        else:
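            # No unique required attribute was found: this class can only
            # serve as the fallback, and at most one fallback is allowed.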
            if fallback is None:
                fallback = cl
                continue
            raise TypeError(f"{cl} has no usable non-default attributes")
        uniq_attrs_dict[maybe_renamed_attr_name] = cl

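    # With no fallback class, failing to match any unique key is an error;
    # with a fallback, unmatched data resolves to the fallback class.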
    if fallback is None:

        def dis_func(data: Mapping[Any, Any]) -> type[AttrsInstance] | None:
            if not isinstance(data, Mapping):
                raise ValueError("Only input mappings are supported")
            for k, v in uniq_attrs_dict.items():
                if k in data:
                    return v
            raise ValueError("Couldn't disambiguate")

    else:

        def dis_func(data: Mapping[Any, Any]) -> type[AttrsInstance] | None:
            if not isinstance(data, Mapping):
                raise ValueError("Only input mappings are supported")
            for k, v in uniq_attrs_dict.items():
                if k in data:
                    return v
            return fallback

    return dis_func


create_uniq_field_dis_func = create_default_dis_func


def _overriden_name(at: Attribute, override: AttributeOverride | None) -> str:
    if override is None or override.rename is None:
        return at.name
    return override.rename


def _usable_attribute_names(
    cl: type[Any], overrides: dict[str, AttributeOverride]
) -> tuple[set[str], dict[str, str]]:
    """Return renamed fields and a mapping to original field names."""
    res = set()
    mapping = {}

    for at in adapted_fields(get_origin(cl) or cl):
        res.add(n := _overriden_name(at, overrides.get(at.name)))
        mapping[n] = at.name

    return res, mapping