1import warnings
2from dataclasses import dataclass
3from dataclasses import field
4from dataclasses import replace
5from enum import Enum
6from itertools import groupby
7from operator import attrgetter
8from typing import NamedTuple
9
10from wtforms import widgets
11from wtforms._compat import get_signature
12from wtforms.fields.core import Field
13from wtforms.validators import ValidationError
14
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = (
    "SelectField",
    "Choice",
    "SelectChoice",
    "SelectMultipleField",
    "RadioField",
)
22
23
def _enum_coerce(enum_cls):
    """Build a ``coerce`` callable that resolves values to *enum_cls* members.

    The returned function passes members of *enum_cls* through unchanged
    and looks every other value up by member *name* (``enum_cls[v]``).
    A failed lookup is re-raised as :exc:`ValueError`, chained to the
    original :exc:`KeyError`, because ``ValueError`` is what the select
    fields' coercion paths catch.
    """

    def coerce(v):
        # Already a member: nothing to convert.
        if isinstance(v, enum_cls):
            return v
        try:
            member = enum_cls[v]
        except KeyError as err:
            raise ValueError(str(err)) from err
        else:
            return member

    return coerce
34
35
class Choice(NamedTuple):
    """
    A rendered option. :meth:`SelectFieldBase.iter_choices` yields these
    directly, and :meth:`SelectFieldBase.iter_groups` yields
    ``(group_label, [Choice, ...])`` pairs whose lists contain them.

    ``selected`` is computed against the field's current data. To
    declare options on a :class:`SelectField`, use
    :class:`SelectChoice` instead.

    The field order matches the legacy ``(value, label, selected,
    render_kw)`` tuples, so code that unpacks four values keeps working.

    :param value:
        The value that will be sent in the request.
    :param label:
        The label of the option.
    :param selected:
        Whether the option is currently selected. Set by ``iter_choices``;
        you rarely set this yourself.
    :param render_kw:
        A dict containing HTML attributes that will be rendered
        with the option.
    """

    value: str
    label: str
    selected: bool
    render_kw: dict
62
63
64@dataclass
65class SelectChoice:
66 """
67 An option declared via :class:`SelectField` and
68 :class:`SelectMultipleField`'s ``choices=`` parameter.
69
70 :param value:
71 The value that will be sent in the request.
72 :param label:
73 The label of the option. Defaults to ``value`` when omitted.
74 :param render_kw:
75 A dict containing HTML attributes that will be rendered
76 with the option. Defaults to an empty dict when omitted.
77 :param optgroup:
78 The ``<optgroup>`` HTML tag in which the option will be rendered.
79 """
80
81 value: str
82 label: str = None # type: ignore[assignment]
83 render_kw: dict = field(default_factory=dict)
84 optgroup: str | None = None
85
86 def __post_init__(self):
87 if self.label is None:
88 self.label = self.value
89
90 def __iter__(self):
91 return iter((self.value, self.label, self.render_kw, self.optgroup))
92
93 @classmethod
94 def from_enum(cls, enum_cls, *, label=None):
95 """Build a list of choices from an :class:`enum.Enum` class.
96
97 The HTML value of each option is the item ``name``. The label
98 defaults to ``str(item)`` when the Enum defines its own
99 ``__str__``, otherwise to ``item.name``. Pass ``label=`` (a
100 callable taking an item) to override.
101 """
102 if label is None:
103 label = str if "__str__" in enum_cls.__dict__ else lambda m: m.name
104 return [cls(value=m.name, label=label(m)) for m in enum_cls]
105
106 @classmethod
107 def from_input(cls, input, optgroup=None):
108 """Coerce a value passed by the user via ``choices=...`` into a
109 :class:`SelectChoice`.
110 """
111 if isinstance(input, SelectChoice):
112 if optgroup:
113 return replace(input, optgroup=optgroup)
114 return input
115
116 if isinstance(input, Choice):
117 warnings.warn(
118 "Passing Choice to a SelectField is deprecated; Choice is the "
119 "output type returned by iter_choices(). Use SelectChoice "
120 "instead. Support for Choice as input will be removed in "
121 "WTForms 4.0.",
122 DeprecationWarning,
123 stacklevel=4,
124 )
125 return cls(
126 value=input.value,
127 label=input.label,
128 render_kw=input.render_kw,
129 optgroup=optgroup,
130 )
131
132 if isinstance(input, str):
133 return cls(value=input, optgroup=optgroup)
134
135 if isinstance(input, tuple):
136 if len(input) not in (2, 3):
137 raise ValueError(
138 f"SelectField choice tuple must have 2 or 3 elements, "
139 f"got {len(input)}"
140 )
141 return cls(*input, optgroup=optgroup)
142
143
def _normalize_iter_choice(choice):
    """Return *choice* as a :class:`Choice`.

    Accepts what :meth:`SelectFieldBase.iter_choices` or
    :meth:`SelectFieldBase.iter_groups` produce. A :class:`Choice` is
    returned unchanged. A plain 4-tuple ``(value, label, selected,
    render_kw)`` or 3-tuple ``(value, label, selected)`` is converted
    after emitting a :class:`DeprecationWarning`; the 3-tuple form gets an
    empty ``render_kw``.

    :raises TypeError: for non-tuples, or tuples of any other length.
    """
    if not isinstance(choice, tuple):
        raise TypeError(
            "iter_choices()/iter_groups() yielded an unsupported type: "
            f"{type(choice).__name__}"
        )

    # Choice is itself a tuple subclass, so it passes the check above.
    if isinstance(choice, Choice):
        return choice

    warnings.warn(
        "Yielding raw tuples from iter_choices() or iter_groups() is "
        "deprecated; yield Choice instances instead. Will be removed "
        "in WTForms 4.0.",
        DeprecationWarning,
        stacklevel=3,
    )

    size = len(choice)
    if size == 4:
        return Choice(*choice)
    if size == 3:
        return Choice(*choice, {})
    raise TypeError(
        "iter_choices()/iter_groups() yielded a tuple of unsupported "
        f"length: {size}"
    )
173
174
class SelectFieldBase(Field):
    """
    Base class for fields which can be iterated to produce options.

    This isn't meant to be used directly, but as an abstract base class for
    fields which want to provide this functionality. Subclasses implement
    :meth:`iter_choices` (and optionally :meth:`has_groups` /
    :meth:`iter_groups`).
    """

    # Widget used for each option subfield produced by ``__iter__``.
    option_widget = widgets.Option()

    def __init__(self, label=None, validators=None, option_widget=None, **kwargs):
        super().__init__(label, validators, **kwargs)

        # Per-instance override of the class-level option widget.
        if option_widget is not None:
            self.option_widget = option_widget

    def iter_choices(self):
        """Provide data for choice widget rendering.

        Should yield :class:`Choice` instances.
        """
        raise NotImplementedError()

    def _iter_choices_normalized(self):
        """Wrap :meth:`iter_choices` to always yield :class:`Choice`.

        Legacy tuples are converted (with a deprecation warning) by
        ``_normalize_iter_choice``.
        """
        for choice in self.iter_choices():
            yield _normalize_iter_choice(choice)

    def has_groups(self):
        """Whether the field's choices include any ``optgroup`` hint.

        The base implementation always returns ``False``.
        """
        return False

    def iter_groups(self):
        """Yield ``(group_label, [Choice, ...])`` pairs for grouped rendering."""
        raise NotImplementedError()

    def _iter_groups_normalized(self):
        """Wrap :meth:`iter_groups` to always yield ``(name, [Choice, ...])``."""
        for name, group in self.iter_groups():
            yield name, [_normalize_iter_choice(c) for c in group]

    def __iter__(self):
        """Yield one ``_Option`` subfield per choice, for custom rendering.

        Subfield ids are ``"<field id>-<index>"``. Each subfield uses
        ``option_widget``, exposes its source :class:`Choice` as ``choice``
        and its selection state as ``checked``.
        """
        opts = dict(
            widget=self.option_widget,
            validators=self.validators,
            name=self.name,
            render_kw=self.render_kw,
            _form=self._form,
            _meta=self.meta,
        )
        for i, choice in enumerate(self._iter_choices_normalized()):
            opt = self._Option(
                id=f"{self.id}-{i}",
                # Fall back to the value when the label is empty or None.
                label=choice.label or choice.value,
                **opts,
            )
            opt.choice = choice
            opt.checked = choice.selected
            opt.process(None, choice.value)
            yield opt

    def _choices_from_input(self, choices):
        """Parse the user-supplied ``choices`` into a list of :class:`SelectChoice`.

        ``choices`` may be:

        * a callable, invoked via :meth:`_invoke_choices_callback` first;
        * ``None``, in which case ``None`` is returned;
        * a shorthand dict ``{value: label}``, where a value may itself be a
          dict ``{value: label}`` whose key becomes the ``optgroup``;
        * a legacy dict ``{optgroup: [choice, ...]}``;
        * an iterable of anything :meth:`SelectChoice.from_input` accepts.
        """
        if callable(choices):
            choices = self._invoke_choices_callback(choices)

        if choices is None:
            return None

        if isinstance(choices, dict):
            if self._is_shorthand_dict(choices):
                result = []
                for key, value in choices.items():
                    if isinstance(value, dict):
                        # Nested dict: ``key`` names the optgroup.
                        for inner_value, inner_label in value.items():
                            result.append(
                                SelectChoice(
                                    value=inner_value, label=inner_label, optgroup=key
                                )
                            )
                    else:
                        result.append(SelectChoice(value=key, label=value))
                return result
            return [
                SelectChoice.from_input(input, optgroup)
                for optgroup, inputs in choices.items()
                for input in inputs
            ]

        return [SelectChoice.from_input(input) for input in choices]

    @staticmethod
    def _is_shorthand_dict(choices):
        """``True`` if ``choices`` matches the shorthand dict syntax.

        Shorthand is ``dict[str, str | dict[str, str]]``. An empty dict
        counts as shorthand, since ``all()`` of nothing is ``True``.
        """
        return all(isinstance(v, (str, dict)) for v in choices.values())

    @staticmethod
    def _warn_legacy_choices(choices):
        """Emit a one-shot ``DeprecationWarning`` for legacy ``choices`` shapes.

        Legacy shapes are raw tuples or a non-shorthand ``dict``. At most one
        tuple warning is emitted, for the first tuple found; other items are
        skipped. This iterates ``choices``, so pass a re-iterable collection
        rather than a one-shot iterator.
        """
        if isinstance(choices, dict):
            if SelectFieldBase._is_shorthand_dict(choices):
                return
            warnings.warn(
                "Passing SelectField choices in a dict is deprecated and will be "
                "removed in wtforms 3.4. Please pass a list of SelectChoice "
                "objects with a custom optgroup attribute instead.",
                DeprecationWarning,
                stacklevel=3,
            )
            items = (i for v in choices.values() for i in v)
        else:
            items = choices
        for item in items:
            # Choice is a tuple subclass but is handled by its own warning.
            if isinstance(item, (SelectChoice, Choice)):
                continue
            if isinstance(item, tuple):
                warnings.warn(
                    "Passing SelectField choices as tuples is deprecated and "
                    "will be removed in wtforms 3.4. Please use SelectChoice "
                    "instead.",
                    DeprecationWarning,
                    stacklevel=3,
                )
                # One warning is enough; stop scanning.
                return

    def _invoke_choices_callback(self, cb):
        """Call a ``choices`` callback.

        ``cb`` is called as ``cb(form, field)`` when its signature can bind
        those two arguments, and as ``cb()`` otherwise, including when the
        signature cannot be inspected.
        """
        try:
            sig = get_signature(cb)
        except (ValueError, TypeError):
            return cb()
        try:
            sig.bind(self._form, self)
        except TypeError:
            return cb()
        return cb(self._form, self)

    class _Option(Field):
        # Subfield type produced by ``__iter__``; its value is ``str(data)``.
        def _value(self):
            return str(self.data)
320
321
class SelectField(SelectFieldBase):
    """
    A field rendered as a ``<select>`` element that accepts one value
    from ``choices``.

    :param coerce:
        Callable applied to submitted data and to each choice value before
        they are compared. An :class:`enum.Enum` subclass is replaced by a
        lookup of its members by name.
    :param choices:
        The available options: an iterable of :class:`SelectChoice` (or a
        legacy shape), a dict, ``None``, or a callable that is invoked in
        :meth:`post_process` to produce them.
    :param validate_choice:
        When false, :meth:`pre_validate` does not check the data against
        ``choices``.
    :param invalid_value_message:
        Error message used when ``coerce`` rejects submitted data.
    :param invalid_choice_message:
        Error message used when the data is not one of the choices.
    """

    widget = widgets.Select()

    def __init__(
        self,
        label=None,
        validators=None,
        coerce=str,
        choices=None,
        validate_choice=True,
        invalid_value_message=None,
        invalid_choice_message=None,
        **kwargs,
    ):
        super().__init__(label, validators, **kwargs)
        if isinstance(coerce, type) and issubclass(coerce, Enum):
            coerce = _enum_coerce(coerce)
        self.coerce = coerce
        if callable(choices):
            # Resolved later, in post_process().
            self._choices_callable = choices
            self.choices = None
        else:
            self._choices_callable = None
            if choices is not None:
                # Copy *before* scanning for legacy shapes: the scan iterates
                # ``choices``, which would exhaust a one-shot iterator (e.g. a
                # generator) and leave nothing to store.
                choices = (
                    dict(choices) if isinstance(choices, dict) else list(choices)
                )
                self._warn_legacy_choices(choices)
            self.choices = choices
        self.validate_choice = validate_choice
        self.invalid_value_message = invalid_value_message or self.gettext(
            "Invalid Choice: could not coerce."
        )
        self.invalid_choice_message = invalid_choice_message or self.gettext(
            "Not a valid choice."
        )

    def iter_choices(self):
        """Return one :class:`Choice` per option.

        An option is selected when its coerced value equals the field's data.
        """
        choices = self._choices_from_input(self.choices) or []
        return [
            Choice(
                value=c.value,
                label=c.label,
                selected=self.coerce(c.value) == self.data,
                render_kw=c.render_kw,
            )
            for c in choices
        ]

    def has_groups(self):
        """Whether any choice declares an ``optgroup``."""
        choices = self._choices_from_input(self.choices) or []
        return any(c.optgroup is not None for c in choices)

    def iter_groups(self):
        """Yield ``(optgroup, [Choice, ...])`` pairs.

        Groups are formed with :func:`itertools.groupby`, so only *adjacent*
        choices sharing an ``optgroup`` are merged; a group name that
        reappears later yields a second pair.
        """
        choices = self._choices_from_input(self.choices) or []
        for optgroup, group in groupby(choices, key=attrgetter("optgroup")):
            yield (
                optgroup,
                [
                    Choice(
                        value=c.value,
                        label=c.label,
                        selected=self.coerce(c.value) == self.data,
                        render_kw=c.render_kw,
                    )
                    for c in group
                ],
            )

    def post_process(self):
        """Resolve a callable ``choices`` into ``self.choices``."""
        super().post_process()
        if self._choices_callable is not None:
            self.choices = self._invoke_choices_callback(self._choices_callable)

    def process_data(self, value):
        """Coerce object data; ``None`` and uncoercible values become ``None``."""
        try:
            # If value is None, don't coerce to a value
            self.data = self.coerce(value) if value is not None else None
        except (ValueError, TypeError):
            self.data = None

    def process_formdata(self, valuelist):
        """Coerce the first submitted value.

        :raises ValueError: with ``invalid_value_message`` when ``coerce``
            raises ``ValueError``.
        """
        if not valuelist:
            return

        try:
            self.data = self.coerce(valuelist[0])
        except ValueError as exc:
            raise ValueError(self.invalid_value_message) from exc

    def pre_validate(self, form):
        """Check that the data matches one of the choices.

        Skipped when processing already failed or ``validate_choice`` is
        false.

        :raises TypeError: if ``choices`` is still ``None``.
        :raises ValidationError: if no choice is selected.
        """
        if self.process_errors:
            return

        if not self.validate_choice:
            return

        if self.choices is None:
            raise TypeError(self.gettext("Choices cannot be None."))

        if not any(choice.selected for choice in self._iter_choices_normalized()):
            raise ValidationError(self.invalid_choice_message)
425
426
class SelectMultipleField(SelectField):
    """
    No different from a normal select field, except this one can take (and
    validate) multiple choices. You'll need to specify the HTML
    :mdn-attr:`size` attribute on the :mdn-tag:`select` field when rendering.

    ``invalid_choice_message`` may be a string, or a callable taking the
    number of invalid submitted values and returning the message. The
    returned message must contain ``%(value)s``, which is replaced with the
    comma-separated list of unacceptable values.
    """

    widget = widgets.Select(multiple=True)

    def __init__(
        self,
        label=None,
        validators=None,
        coerce=str,
        choices=None,
        validate_choice=True,
        invalid_value_message=None,
        invalid_choice_message=None,
        **kwargs,
    ):
        super().__init__(
            label,
            validators,
            coerce=coerce,
            choices=choices,
            validate_choice=validate_choice,
            **kwargs,
        )
        self.invalid_value_message = invalid_value_message or self.gettext(
            "Invalid choice(s): one or more data inputs could not be coerced."
        )
        # ``None`` selects the pluralized default message in pre_validate().
        self.invalid_choice_message = invalid_choice_message

    def iter_choices(self):
        """Return one :class:`Choice` per option.

        An option is selected when its coerced value is in the field's data.
        """
        choices = self._choices_from_input(self.choices) or []
        data = self.data or ()
        return [
            Choice(
                value=c.value,
                label=c.label,
                selected=self.coerce(c.value) in data,
                render_kw=c.render_kw,
            )
            for c in choices
        ]

    def iter_groups(self):
        """Yield ``(optgroup, [Choice, ...])`` pairs for adjacent choices
        sharing an ``optgroup``, marking each value present in the data."""
        choices = self._choices_from_input(self.choices) or []
        data = self.data or ()
        for optgroup, group in groupby(choices, key=attrgetter("optgroup")):
            yield (
                optgroup,
                [
                    Choice(
                        value=c.value,
                        label=c.label,
                        selected=self.coerce(c.value) in data,
                        render_kw=c.render_kw,
                    )
                    for c in group
                ],
            )

    def process_data(self, value):
        """Coerce each item of object data; on failure the data is ``None``."""
        try:
            self.data = [self.coerce(v) for v in value]
        except (ValueError, TypeError):
            self.data = None

    def process_formdata(self, valuelist):
        """Coerce every submitted value.

        :raises ValueError: with ``invalid_value_message`` when ``coerce``
            raises ``ValueError``.
        """
        try:
            self.data = [self.coerce(x) for x in valuelist]
        except ValueError as exc:
            raise ValueError(self.invalid_value_message) from exc

    def pre_validate(self, form):
        """Check that every submitted value is one of the choices.

        Skipped when processing already failed, ``validate_choice`` is false
        or nothing was submitted.

        :raises TypeError: if ``choices`` is still ``None``.
        :raises ValidationError: listing each unacceptable value once, in
            submission order.
        """
        if self.process_errors:
            return

        if not self.validate_choice or not self.data:
            return

        if self.choices is None:
            raise TypeError(self.gettext("Choices cannot be None."))

        acceptable = {
            self.coerce(choice.value) for choice in self._iter_choices_normalized()
        }
        # dict.fromkeys de-duplicates while preserving submission order, so
        # the message is deterministic; set iteration order is arbitrary.
        unacceptable = [
            str(value) for value in dict.fromkeys(self.data) if value not in acceptable
        ]
        if not unacceptable:
            return

        if callable(self.invalid_choice_message):
            message = self.invalid_choice_message(len(unacceptable))
        elif self.invalid_choice_message is not None:
            message = self.invalid_choice_message
        else:
            message = self.ngettext(
                "'%(value)s' is not a valid choice for this field.",
                "'%(value)s' are not valid choices for this field.",
                len(unacceptable),
            )
        raise ValidationError(message % dict(value="', '".join(unacceptable)))
535
536
class RadioField(SelectField):
    """
    A :class:`SelectField` displayed as a list of :mdn-input:`radio` buttons.

    Iterating the field produces one subfield per choice, each with its own
    label, so the individual radio inputs can be rendered by hand.
    """

    option_widget = widgets.RadioInput()
    widget = widgets.ListWidget(prefix_label=False)

    def __init__(self, label=None, validators=None, **kwargs):
        super().__init__(label=label, validators=validators, **kwargs)
        # Clear the label's field id; presumably because the label describes
        # the whole radio group rather than a single input (confirm against
        # the Label rendering code).
        self.label.field_id = False