1from collections.abc import Callable, Iterable, Iterator, Mapping
2from itertools import islice, tee, zip_longest
3
4from django.utils.functional import Promise
5
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "BaseChoiceIterator",
    "BlankChoiceIterator",
    "CallableChoiceIterator",
    "flatten_choices",
    "normalize_choices",
]
13
14
class BaseChoiceIterator:
    """
    Abstract base for choice containers that produce their items lazily.

    Subclasses provide ``__iter__()``. This class adds sequence-like
    equality and indexing on top of whatever that iteration yields.
    """

    def __eq__(self, other):
        # Anything that is not iterable gets the default object comparison.
        if not isinstance(other, Iterable):
            return super().__eq__(other)
        # A private sentinel pads the shorter side. For ordinary elements,
        # a length mismatch therefore makes the comparison fail.
        padding = object()
        paired = zip_longest(self, other, fillvalue=padding)
        return all(mine == theirs for mine, theirs in paired)

    def __getitem__(self, index):
        if isinstance(index, slice) or index < 0:
            # Slices and negative positions depend on the total length, so
            # fall back to materializing every item (simple but wasteful).
            return list(self)[index]
        # Non-negative position: advance only as far as needed.
        for item in islice(self, index, index + 1):
            return item
        raise IndexError("index out of range") from None

    def __iter__(self):
        # Concrete subclasses must override this.
        raise NotImplementedError(
            "BaseChoiceIterator subclasses must implement __iter__()."
        )
37
38
class BlankChoiceIterator(BaseChoiceIterator):
    """
    Lazily prepend a blank option to a collection of choices.

    The items of ``blank_choice`` are emitted first, but only when none of
    the flattened choice values is already ``""`` or ``None``. The choices
    themselves then follow unchanged.
    """

    def __init__(self, choices, blank_choice):
        self.choices = choices
        self.blank_choice = blank_choice

    def __iter__(self):
        # Split the source into two independent streams. One is scanned for
        # an existing blank value; the other is what actually gets emitted.
        emitted, probe = tee(self.choices)
        already_blank = any(
            value in ("", None) for value, _ in flatten_choices(probe)
        )
        if not already_blank:
            yield from self.blank_choice
        yield from emitted
51
52
class CallableChoiceIterator(BaseChoiceIterator):
    """
    Defer calling ``func`` until iteration, then normalize what it returns.

    ``func`` is invoked again on every pass over the iterator. Each pass
    sees freshly generated, normalized choices.
    """

    def __init__(self, func):
        self.func = func

    def __iter__(self):
        # This is a generator, so ``func`` runs only once the first item is
        # requested, not when ``iter()`` is called.
        produced = self.func()
        yield from normalize_choices(produced)
61
62
def flatten_choices(choices):
    """
    Yield choice pairs with one level of grouping removed.

    An entry whose second element is a list or tuple is treated as a named
    group. The group name is dropped and the group's members are yielded in
    its place. Every other entry is yielded as a ``(first, second)`` tuple.
    A falsy ``choices`` (such as ``None``) yields nothing.
    """
    for head, tail in choices or ():
        if not isinstance(tail, (list, tuple)):
            yield head, tail
            continue
        yield from tail
70
71
def normalize_choices(value, *, depth=0):
    """
    Normalize choices values consistently for fields and widgets.

    ``depth`` tracks nesting. At ``depth`` 0 (the choices themselves) and
    ``depth`` 1 (the members of a named group), containers are converted
    into a list of ``(key, normalized_value)`` 2-tuples. Each value is
    normalized recursively with ``depth + 1``. Containers reached at
    ``depth`` 2 or deeper are returned unchanged.

    Special cases, checked in this order:

    - ``BaseChoiceIterator``, lazy ``Promise`` objects, ``bytes`` and
      ``str`` are returned as-is at any depth.
    - ``ChoicesType`` enumerations are replaced by their ``.choices``.
    - Mappings (``depth`` < 2) are normalized through ``value.items()``.
    - Iterators (``depth`` < 2) are normalized without being pre-scanned.
    - Other iterables (``depth`` < 2) are normalized only if none of their
      elements is string-like. Otherwise they are returned unchanged.
    - A callable at ``depth`` 0 is wrapped in ``CallableChoiceIterator`` so
      that it is evaluated lazily. A callable below the top level (``depth``
      < 2) is called immediately and its result is normalized.
    - Anything else is returned unchanged.

    If building the list of pairs raises ``TypeError`` or ``ValueError``
    (e.g. items that are not 2-tuples), ``value`` is returned as it stands
    at that point, leaving the problem for the system check to report.
    """
    # Avoid circular import when importing django.forms.
    from django.db.models.enums import ChoicesType

    # Case order matters: the first matching pattern wins, and guards are
    # evaluated top to bottom.
    match value:
        case BaseChoiceIterator() | Promise() | bytes() | str():
            # Avoid prematurely normalizing iterators that should be lazy.
            # Because string-like types are iterable, return early to avoid
            # iterating over them in the guard for the Iterable case below.
            return value
        case ChoicesType():
            # Choices enumeration helpers already output in canonical form.
            return value.choices
        case Mapping() if depth < 2:
            # Normalize a mapping through its (key, value) items.
            value = value.items()
        case Iterator() if depth < 2:
            # Although Iterator would be handled by the Iterable case below,
            # the iterator would be consumed prematurely while checking that
            # its elements are not string-like in the guard, so we handle it
            # separately.
            pass
        case Iterable() if depth < 2 and not any(
            isinstance(x, (Promise, bytes, str)) for x in value
        ):
            # Only iterables with no string-like elements are treated as
            # containers of choices. An iterable containing any str, bytes
            # or Promise element fails this guard and is returned unchanged
            # by the default case below.
            pass
        case Callable() if depth == 0:
            # If at the top level, wrap callables to be evaluated lazily.
            return CallableChoiceIterator(value)
        case Callable() if depth < 2:
            # Below the top level (normally depth 1), call the callable now
            # and normalize its result.
            value = value()
        case _:
            # Everything else (e.g. labels, or containers at depth >= 2) is
            # left untouched.
            return value

    try:
        # Recursive call to convert any nested values to a list of 2-tuples.
        return [(k, normalize_choices(v, depth=depth + 1)) for k, v in value]
    except (TypeError, ValueError):
        # Return the value as it stands (it may have been rebound above,
        # e.g. to ``value.items()`` or to a callable's result), so that the
        # system check can raise if it has items that are not iterable or
        # not 2-tuples:
        # - TypeError: cannot unpack non-iterable <type> object
        # - ValueError: <not enough / too many> values to unpack
        # NOTE(review): if ``value`` is an iterator, the failed attempt
        # above has already consumed part of it.
        return value