1"""Utilities for union (sum type) disambiguation."""
2
3from __future__ import annotations
4
5from collections import defaultdict
6from collections.abc import Mapping
7from dataclasses import MISSING
8from functools import reduce
9from operator import or_
10from typing import TYPE_CHECKING, Any, Callable, Literal, Union
11
12from attrs import NOTHING, Attribute, AttrsInstance
13
14from ._compat import (
15 NoneType,
16 adapted_fields,
17 fields_dict,
18 get_args,
19 get_origin,
20 has,
21 is_literal,
22 is_union_type,
23)
24from .gen import AttributeOverride
25
26if TYPE_CHECKING:
27 from .converters import BaseConverter
28
29__all__ = ["create_default_dis_func", "is_supported_union"]
30
31
32def is_supported_union(typ: Any) -> bool:
33 """Whether the type is a union of attrs classes or dataclasses."""
34 return is_union_type(typ) and all(
35 e is NoneType or has(get_origin(e) or e) for e in typ.__args__
36 )
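

# A minimal behavior sketch, assuming `A` and `B` are hypothetical attrs
# classes (or dataclasses):
#
#     is_supported_union(Union[A, B, None])  # True -- None is allowed
#     is_supported_union(Union[A, int])      # False -- int is neither attrs nor a dataclass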


def create_default_dis_func(
    converter: BaseConverter,
    *classes: type[AttrsInstance],
    use_literals: bool = True,
    overrides: (
        dict[str, AttributeOverride] | Literal["from_converter"]
    ) = "from_converter",
) -> Callable[[Mapping[Any, Any]], type[Any] | None]:
    """Given attrs classes or dataclasses, generate a disambiguation function.

    The returned function disambiguates based on unique literal values, or on
    unique required fields (fields without defaults).

    :param use_literals: Whether to try using fields annotated as literals for
        disambiguation.
    :param overrides: Attribute overrides to apply.

    .. versionchanged:: 24.1.0
        Dataclasses are now supported.
    """
    if len(classes) < 2:
        raise ValueError("At least two classes required.")

    if overrides == "from_converter":
        overrides = [
            getattr(converter.get_structure_hook(c), "overrides", {}) for c in classes
        ]
    else:
        overrides = [overrides for _ in classes]

    # First, attempt disambiguation by unique literal values.
    if use_literals:
        # requirements for a discriminator field:
        # (... TODO: a single fallback is OK)
        # - it must be annotated as a Literal in every member class
        cls_candidates = [
            {
                at.name
                for at in adapted_fields(get_origin(cl) or cl)
                if is_literal(at.type)
            }
            for cl in classes
        ]

        # literal field names common to all members
        discriminators: set[str] = cls_candidates[0]
        for possible_discriminators in cls_candidates:
            discriminators &= possible_discriminators
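        # e.g. if every member class declares a Literal-typed field named
        # "type" (a hypothetical name), `discriminators` ends up as {"type"}.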

        best_result = None
        best_discriminator = None
        for discriminator in discriminators:
            # maps Literal values (strings, ints...) to classes
            mapping = defaultdict(list)

            for cl in classes:
                for key in get_args(
                    fields_dict(get_origin(cl) or cl)[discriminator].type
                ):
                    mapping[key].append(cl)

            if best_result is None or max(len(v) for v in mapping.values()) <= max(
                len(v) for v in best_result.values()
            ):
                best_result = mapping
                best_discriminator = discriminator

        if (
            best_result
            and best_discriminator
            and max(len(v) for v in best_result.values()) != len(classes)
        ):
            final_mapping = {
                k: v[0] if len(v) == 1 else Union[tuple(v)]
                for k, v in best_result.items()
            }
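            # e.g. for three hypothetical classes Apple (type: Literal["apple"]),
            # Pear (type: Literal["pear"]) and Quince (type: Literal["pear"]),
            # this would be {"apple": Apple, "pear": Union[Pear, Quince]}.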

            def dis_func(data: Mapping[Any, Any]) -> type | None:
                if not isinstance(data, Mapping):
                    raise ValueError("Only input mappings are supported.")
                return final_mapping[data[best_discriminator]]

            return dis_func

    # Next, attempt disambiguation by unique required fields (keys).

    # NOTE: This could just as well work with field availability instead of
    # uniqueness, returning Unions ... it doesn't do that right now.
    cls_and_attrs = [
        (cl, *_usable_attribute_names(cl, override))
        for cl, override in zip(classes, overrides)
    ]
    # For each class, attempt to generate a single unique required field.
    uniq_attrs_dict: dict[str, type] = {}
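    # e.g. once filled in, this might look like {"num_seeds": Apple,
    # "weight": Pear} (hypothetical field and class names).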

    # We start from classes with the largest number of fields, so we can do
    # easy picks first, making later picks easier.
    cls_and_attrs.sort(key=lambda c_a: len(c_a[1]), reverse=True)

    fallback = None  # If none match, try this.

    for cl, cl_reqs, back_map in cls_and_attrs:
        # We do not have to consider classes we've already processed, since
        # they will have been eliminated by the match dictionary already.
        other_classes = [
            c_and_a
            for c_and_a in cls_and_attrs
            if c_and_a[0] is not cl and c_and_a[0] not in uniq_attrs_dict.values()
        ]
        other_reqs = reduce(or_, (c_a[1] for c_a in other_classes), set())
        uniq = cl_reqs - other_reqs

        # We want a unique attribute with no default.
        cl_fields = fields_dict(get_origin(cl) or cl)
        for maybe_renamed_attr_name in uniq:
            orig_name = back_map[maybe_renamed_attr_name]
            if cl_fields[orig_name].default in (NOTHING, MISSING):
                break
        else:
            if fallback is None:
                fallback = cl
                continue
            raise TypeError(f"{cl} has no usable non-default attributes")
        uniq_attrs_dict[maybe_renamed_attr_name] = cl

    if fallback is None:

        def dis_func(data: Mapping[Any, Any]) -> type[AttrsInstance] | None:
            if not isinstance(data, Mapping):
                raise ValueError("Only input mappings are supported")
            for k, v in uniq_attrs_dict.items():
                if k in data:
                    return v
            raise ValueError("Couldn't disambiguate")

    else:

        def dis_func(data: Mapping[Any, Any]) -> type[AttrsInstance] | None:
            if not isinstance(data, Mapping):
                raise ValueError("Only input mappings are supported")
            for k, v in uniq_attrs_dict.items():
                if k in data:
                    return v
            return fallback

    return dis_func
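

# A rough usage sketch, assuming two hypothetical attrs classes: `Apple` with
# required fields `num_seeds` and `weight`, and `Pear` with a required field
# `weight`. Neither has Literal fields, so disambiguation falls back to
# unique required keys:
#
#     from attrs import define
#     from cattrs import Converter
#
#     @define
#     class Apple:
#         num_seeds: int
#         weight: float
#
#     @define
#     class Pear:
#         weight: float
#
#     converter = Converter()
#     dis = create_default_dis_func(converter, Apple, Pear)
#     dis({"num_seeds": 8, "weight": 0.2})  # -> Apple (unique key `num_seeds`)
#     dis({"weight": 0.3})                  # -> Pear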


create_uniq_field_dis_func = create_default_dis_func


def _overriden_name(at: Attribute, override: AttributeOverride | None) -> str:
    if override is None or override.rename is None:
        return at.name
    return override.rename


def _usable_attribute_names(
    cl: type[Any], overrides: dict[str, AttributeOverride]
) -> tuple[set[str], dict[str, str]]:
    """Return renamed fields and a mapping to original field names."""
    res = set()
    mapping = {}

    for at in adapted_fields(get_origin(cl) or cl):
        res.add(n := _overriden_name(at, overrides.get(at.name)))
        mapping[n] = at.name

    return res, mapping
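

# A small sketch of the expected return value, assuming a hypothetical attrs
# class `Pear(weight: float, oblong: bool)` whose `oblong` attribute is
# renamed via `cattrs.gen.override(rename="is_oblong")`:
#
#     _usable_attribute_names(Pear, {"oblong": override(rename="is_oblong")})
#     # -> ({"weight", "is_oblong"}, {"weight": "weight", "is_oblong": "oblong"})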