1from __future__ import annotations
2
3from collections.abc import MutableSet
4from copy import deepcopy
5
6from .. import exceptions
7from .._internal import _missing
8from .mixins import ImmutableDictMixin
9from .mixins import ImmutableListMixin
10from .mixins import ImmutableMultiDictMixin
11from .mixins import UpdateDictMixin
12
13
def is_immutable(self):
    """Unconditionally raise :exc:`TypeError` naming the type of ``self``.

    :raise TypeError: always.
    """
    cls_name = type(self).__name__
    raise TypeError(f"{cls_name!r} objects are immutable")
16
17
def iter_multi_items(mapping):
    """Yield every ``(key, value)`` pair of *mapping* without losing values.

    - A :class:`MultiDict` yields all values of every key.
    - A plain :class:`dict` yields one pair per value; a ``tuple`` or
      ``list`` value is expanded so each element becomes its own pair.
    - Any other object is treated as an iterable of pairs and passed
      through unchanged.
    """
    # MultiDict must be checked before dict because it is a dict subclass.
    if isinstance(mapping, MultiDict):
        yield from mapping.items(multi=True)
        return

    if not isinstance(mapping, dict):
        yield from mapping
        return

    for key, value in mapping.items():
        if not isinstance(value, (tuple, list)):
            yield key, value
            continue
        yield from ((key, item) for item in value)
33
34
class ImmutableList(ImmutableListMixin, list):
    """An immutable :class:`list`.

    .. versionadded:: 0.5

    :private:
    """

    def __repr__(self):
        # Reuse list's own repr for the contents and wrap it in the
        # concrete class name, e.g. ``ImmutableList([1, 2])``.
        contents = list.__repr__(self)
        return f"{type(self).__name__}({contents})"
45
46
class TypeConversionDict(dict):
    """Works like a regular dict but the :meth:`get` method can perform
    type conversions. :class:`MultiDict` and :class:`CombinedMultiDict`
    are subclasses of this class and provide the same feature.

    .. versionadded:: 0.5
    """

    def get(self, key, default=None, type=None):
        """Return the default value if the requested data doesn't exist.
        If `type` is provided and is a callable it should convert the value,
        return it or raise a :exc:`ValueError` if that is not possible. In
        this case the function will return the default as if the value was not
        found:

        >>> d = TypeConversionDict(foo='42', bar='blub')
        >>> d.get('foo', type=int)
        42
        >>> d.get('bar', -1, type=int)
        -1

        :param key: The key to be looked up.
        :param default: The default value to be returned if the key can't
                        be looked up. If not further specified `None` is
                        returned.
        :param type: A callable that is used to cast the value in the
                     :class:`MultiDict`. If a :exc:`ValueError` or a
                     :exc:`TypeError` is raised by this callable the default
                     value is returned.

        .. versionchanged:: 3.0.2
            Returns the default value on :exc:`TypeError`, too.
        """
        # Go through ``self[key]`` (not dict.__getitem__) so subclasses that
        # override item access, such as MultiDict, are honoured.
        try:
            value = self[key]
        except KeyError:
            return default

        if type is None:
            return value

        try:
            return type(value)
        except (ValueError, TypeError):
            return default
90
91
class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict):
    """A :class:`TypeConversionDict` that does not support modifications.

    .. versionadded:: 0.5
    """

    def copy(self):
        """Return a shallow mutable copy of this object. Keep in mind that
        the standard library's :func:`copy` function is a no-op for this class
        like for any other python immutable type (eg: :class:`tuple`).
        """
        mutable = TypeConversionDict(self)
        return mutable

    def __copy__(self):
        # The instance cannot change, so ``copy.copy`` may hand back itself.
        return self
108
109
class MultiDict(TypeConversionDict):
    """A :class:`MultiDict` is a dictionary subclass customized to deal with
    multiple values for the same key, as used for example by the parsing
    functions in the wrappers. This is necessary because some HTML form
    elements pass multiple values for the same key.

    :class:`MultiDict` implements all standard dictionary methods.
    Internally, it saves all values for a key as a list, but the standard dict
    access methods will only return the first value for a key. If you want to
    gain access to the other values, too, you have to use the `list` methods as
    explained below.

    Basic Usage:

    >>> d = MultiDict([('a', 'b'), ('a', 'c')])
    >>> d
    MultiDict([('a', 'b'), ('a', 'c')])
    >>> d['a']
    'b'
    >>> d.getlist('a')
    ['b', 'c']
    >>> 'a' in d
    True

    It behaves like a normal dict, so all dict functions will only return the
    first value when multiple values for one key are found.

    From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a
    subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will
    render a page for a ``400 BAD REQUEST`` if caught in a catch-all for HTTP
    exceptions.

    A :class:`MultiDict` can be constructed from an iterable of
    ``(key, value)`` tuples, from a dict (a ``list`` or ``tuple`` value
    supplies several values for its key), or from another
    :class:`MultiDict`.

    :param mapping: the initial value for the :class:`MultiDict`. Either a
                    :class:`MultiDict`, a regular dict, an iterable of
                    ``(key, value)`` tuples or `None`.
    """

    def __init__(self, mapping=None):
        if isinstance(mapping, MultiDict):
            # Copy each value list so the two objects never share a list.
            dict.__init__(self, ((k, vs[:]) for k, vs in mapping.lists()))
        elif isinstance(mapping, dict):
            tmp = {}
            for key, value in mapping.items():
                if isinstance(value, (tuple, list)):
                    # Skip empty sequences so no key maps to an empty list.
                    if len(value) == 0:
                        continue
                    value = list(value)
                else:
                    value = [value]
                tmp[key] = value
            dict.__init__(self, tmp)
        else:
            tmp = {}
            for key, value in mapping or ():
                tmp.setdefault(key, []).append(value)
            dict.__init__(self, tmp)

    def __getstate__(self):
        """Return a plain dict mapping each key to a copy of its value list."""
        return dict(self.lists())

    def __setstate__(self, value):
        """Replace all contents with the key -> value-list dict *value*."""
        dict.clear(self)
        dict.update(self, value)

    def __iter__(self):
        # Work around https://bugs.python.org/issue43246.
        # (`return super().__iter__()` also works here, which makes this look
        # even more like it should be a no-op, yet it isn't.)
        return dict.__iter__(self)

    def __getitem__(self, key):
        """Return the first data value for this key;
        raises KeyError if not found.

        :param key: The key to be looked up.
        :raise KeyError: if the key does not exist.
        """

        if key in self:
            lst = dict.__getitem__(self, key)
            # An empty value list is treated the same as a missing key.
            if len(lst) > 0:
                return lst[0]
        raise exceptions.BadRequestKeyError(key)

    def __setitem__(self, key, value):
        """Like :meth:`add` but removes an existing key first.

        :param key: the key for the value.
        :param value: the value to set.
        """
        dict.__setitem__(self, key, [value])

    def add(self, key, value):
        """Adds a new value for the key.

        .. versionadded:: 0.6

        :param key: the key for the value.
        :param value: the value to add.
        """
        dict.setdefault(self, key, []).append(value)

    def getlist(self, key, type=None):
        """Return the list of items for a given key. If that key is not in the
        `MultiDict`, the return value will be an empty list. Just like `get`,
        `getlist` accepts a `type` parameter. All items will be converted
        with the callable defined there.

        :param key: The key to be looked up.
        :param type: A callable that is used to cast the value in the
                     :class:`MultiDict`. If a :exc:`ValueError` is raised
                     by this callable the value will be removed from the list.
        :return: a new :class:`list` of all the values for the key.
        """
        try:
            rv = dict.__getitem__(self, key)
        except KeyError:
            return []
        if type is None:
            return list(rv)
        result = []
        for item in rv:
            try:
                result.append(type(item))
            except ValueError:
                pass
        return result

    def setlist(self, key, new_list):
        """Remove the old values for a key and add new ones. Note that the list
        you pass the values in will be shallow-copied before it is inserted in
        the dictionary.

        >>> d = MultiDict()
        >>> d.setlist('foo', ['1', '2'])
        >>> d['foo']
        '1'
        >>> d.getlist('foo')
        ['1', '2']

        :param key: The key for which the values are set.
        :param new_list: An iterable with the new values for the key. Old values
                         are removed first.
        """
        dict.__setitem__(self, key, list(new_list))

    def setdefault(self, key, default=None):
        """Returns the value for the key if it is in the dict, otherwise it
        returns `default` and sets that value for `key`.

        :param key: The key to be looked up.
        :param default: The default value to be returned if the key is not
                        in the dict. If not further specified it's `None`.
        """
        if key not in self:
            self[key] = default
        else:
            default = self[key]
        return default

    def setlistdefault(self, key, default_list=None):
        """Like `setdefault` but sets multiple values. The list returned
        is not a copy, but the list that is actually used internally. This
        means that you can put new values into the dict by appending items
        to the list:

        >>> d = MultiDict({"foo": 1})
        >>> d.setlistdefault("foo").extend([2, 3])
        >>> d.getlist("foo")
        [1, 2, 3]

        :param key: The key to be looked up.
        :param default_list: An iterable of default values. It is either copied
                             (in case it was a list) or converted into a list
                             before returned.
        :return: a :class:`list`
        """
        if key not in self:
            default_list = list(default_list or ())
            dict.__setitem__(self, key, default_list)
        else:
            default_list = dict.__getitem__(self, key)
        return default_list

    def items(self, multi=False):
        """Return an iterator of ``(key, value)`` pairs.

        :param multi: If set to `True` the iterator returned will have a pair
                      for each value of each key. Otherwise it will only
                      contain pairs for the first value of each key.
        """
        for key, values in dict.items(self):
            if multi:
                for value in values:
                    yield key, value
            else:
                yield key, values[0]

    def lists(self):
        """Return an iterator of ``(key, values)`` pairs, where values is a
        copy of the list of all values associated with the key."""
        for key, values in dict.items(self):
            yield key, list(values)

    def values(self):
        """Returns an iterator of the first value on every key's value list."""
        for values in dict.values(self):
            yield values[0]

    def listvalues(self):
        """Return an iterator of all values associated with a key. Zipping
        :meth:`keys` and this is the same as calling :meth:`lists`:

        >>> d = MultiDict({"foo": [1, 2, 3]})
        >>> list(zip(d.keys(), d.listvalues())) == list(d.lists())
        True

        Unlike :meth:`lists`, the lists yielded here are the internal value
        lists, not copies.
        """
        return dict.values(self)

    def copy(self):
        """Return a shallow copy of this object."""
        return self.__class__(self)

    def deepcopy(self, memo=None):
        """Return a deep copy of this object."""
        return self.__class__(deepcopy(self.to_dict(flat=False), memo))

    def to_dict(self, flat=True):
        """Return the contents as regular dict. If `flat` is `True` the
        returned dict will only have the first item present, if `flat` is
        `False` all values will be returned as lists.

        :param flat: If set to `False` the dict returned will have lists
                     with all the values in it. Otherwise it will only
                     contain the first value for each key.
        :return: a :class:`dict`
        """
        if flat:
            return dict(self.items())
        return dict(self.lists())

    def update(self, mapping):
        """update() extends rather than replaces existing key lists:

        >>> a = MultiDict({'x': 1})
        >>> b = MultiDict({'x': 2, 'y': 3})
        >>> a.update(b)
        >>> a
        MultiDict([('x', 1), ('x', 2), ('y', 3)])

        If the value list for a key in ``mapping`` is empty, no new values
        will be added to the dict and the key will not be created:

        >>> x = {'empty_list': []}
        >>> y = MultiDict()
        >>> y.update(x)
        >>> y
        MultiDict([])

        :param mapping: a :class:`MultiDict`, a dict (list or tuple values
                        add several values) or an iterable of
                        ``(key, value)`` pairs.
        """
        for key, value in iter_multi_items(mapping):
            MultiDict.add(self, key, value)

    def pop(self, key, default=_missing):
        """Pop the first item for a list on the dict. Afterwards the
        key is removed from the dict, so additional values are discarded:

        >>> d = MultiDict({"foo": [1, 2, 3]})
        >>> d.pop("foo")
        1
        >>> "foo" in d
        False

        :param key: the key to pop.
        :param default: if provided the value to return if the key was
                        not in the dictionary.
        """
        try:
            lst = dict.pop(self, key)

            # BadRequestKeyError subclasses KeyError, so an empty list falls
            # through to the handler below and ``default`` is honoured.
            if len(lst) == 0:
                raise exceptions.BadRequestKeyError(key)

            return lst[0]
        except KeyError:
            if default is not _missing:
                return default

            raise exceptions.BadRequestKeyError(key) from None

    def popitem(self):
        """Pop an item from the dict."""
        try:
            item = dict.popitem(self)

            if len(item[1]) == 0:
                raise exceptions.BadRequestKeyError(item[0])

            return (item[0], item[1][0])
        except KeyError as e:
            raise exceptions.BadRequestKeyError(e.args[0]) from None

    def poplist(self, key):
        """Pop the list for a key from the dict. If the key is not in the dict
        an empty list is returned.

        .. versionchanged:: 0.5
           If the key does no longer exist a list is returned instead of
           raising an error.
        """
        return dict.pop(self, key, [])

    def popitemlist(self):
        """Pop a ``(key, list)`` tuple from the dict."""
        try:
            return dict.popitem(self)
        except KeyError as e:
            raise exceptions.BadRequestKeyError(e.args[0]) from None

    def __copy__(self):
        return self.copy()

    def __deepcopy__(self, memo):
        return self.deepcopy(memo=memo)

    def __repr__(self):
        return f"{type(self).__name__}({list(self.items(multi=True))!r})"
440
441
442class _omd_bucket:
443 """Wraps values in the :class:`OrderedMultiDict`. This makes it
444 possible to keep an order over multiple different keys. It requires
445 a lot of extra memory and slows down access a lot, but makes it
446 possible to access elements in O(1) and iterate in O(n).
447 """
448
449 __slots__ = ("prev", "key", "value", "next")
450
451 def __init__(self, omd, key, value):
452 self.prev = omd._last_bucket
453 self.key = key
454 self.value = value
455 self.next = None
456
457 if omd._first_bucket is None:
458 omd._first_bucket = self
459 if omd._last_bucket is not None:
460 omd._last_bucket.next = self
461 omd._last_bucket = self
462
463 def unlink(self, omd):
464 if self.prev:
465 self.prev.next = self.next
466 if self.next:
467 self.next.prev = self.prev
468 if omd._first_bucket is self:
469 omd._first_bucket = self.next
470 if omd._last_bucket is self:
471 omd._last_bucket = self.prev
472
473
class OrderedMultiDict(MultiDict):
    """Works like a regular :class:`MultiDict` but preserves the
    order of the fields. To convert the ordered multi dict into a
    list you can use the :meth:`items` method and pass it ``multi=True``.

    In general an :class:`OrderedMultiDict` is an order of magnitude
    slower than a :class:`MultiDict`.

    .. admonition:: note

       Due to a limitation in Python you cannot convert an ordered
       multi dict into a regular dict by using ``dict(multidict)``.
       Instead you have to use the :meth:`to_dict` method, otherwise
       the internal bucket objects are exposed.
    """

    # Storage layout: each underlying dict value is a list of _omd_bucket
    # nodes for that key, and all nodes (across keys) are also chained into
    # one doubly linked list, _first_bucket ... _last_bucket, which records
    # insertion order.

    def __init__(self, mapping=None):
        dict.__init__(self)
        self._first_bucket = self._last_bucket = None
        if mapping is not None:
            OrderedMultiDict.update(self, mapping)

    def __eq__(self, other):
        """Compare with another :class:`MultiDict`.

        Against another :class:`OrderedMultiDict`, the full sequence of
        ``(key, value)`` pairs must match in order. Against any other
        :class:`MultiDict`, only the keys and their value lists are compared.
        """
        if not isinstance(other, MultiDict):
            return NotImplemented
        if isinstance(other, OrderedMultiDict):
            iter1 = iter(self.items(multi=True))
            iter2 = iter(other.items(multi=True))
            try:
                for k1, v1 in iter1:
                    k2, v2 = next(iter2)
                    if k1 != k2 or v1 != v2:
                        return False
            except StopIteration:
                # ``other`` ran out first.
                return False
            try:
                next(iter2)
            except StopIteration:
                return True
            # ``other`` has extra items.
            return False
        if len(self) != len(other):
            return False
        for key, values in self.lists():
            if other.getlist(key) != values:
                return False
        return True

    __hash__ = None

    def __reduce_ex__(self, protocol):
        # Rebuild by calling the class with the ordered list of pairs.
        return type(self), (list(self.items(multi=True)),)

    def __getstate__(self):
        return list(self.items(multi=True))

    def __setstate__(self, values):
        """Replace all contents with the ``(key, value)`` pairs in *values*."""
        dict.clear(self)
        # Reset the linked list as well; otherwise new buckets would be
        # appended onto stale ones (or fail if __init__ never ran).
        self._first_bucket = self._last_bucket = None
        for key, value in values:
            self.add(key, value)

    def __getitem__(self, key):
        if key in self:
            return dict.__getitem__(self, key)[0].value
        raise exceptions.BadRequestKeyError(key)

    def __setitem__(self, key, value):
        self.poplist(key)
        self.add(key, value)

    def __delitem__(self, key):
        self.pop(key)

    def keys(self):
        return (key for key, value in self.items())

    def __iter__(self):
        return iter(self.keys())

    def values(self):
        return (value for key, value in self.items())

    def items(self, multi=False):
        """Yield ``(key, value)`` pairs in insertion order.

        :param multi: if `True`, yield every value; otherwise only the
                      first value of each key.
        """
        ptr = self._first_bucket
        if multi:
            while ptr is not None:
                yield ptr.key, ptr.value
                ptr = ptr.next
        else:
            returned_keys = set()
            while ptr is not None:
                if ptr.key not in returned_keys:
                    returned_keys.add(ptr.key)
                    yield ptr.key, ptr.value
                ptr = ptr.next

    def lists(self):
        """Yield ``(key, values)`` in order of each key's first appearance."""
        returned_keys = set()
        ptr = self._first_bucket
        while ptr is not None:
            if ptr.key not in returned_keys:
                yield ptr.key, self.getlist(ptr.key)
                returned_keys.add(ptr.key)
            ptr = ptr.next

    def listvalues(self):
        for _key, values in self.lists():
            yield values

    def add(self, key, value):
        dict.setdefault(self, key, []).append(_omd_bucket(self, key, value))

    def getlist(self, key, type=None):
        try:
            rv = dict.__getitem__(self, key)
        except KeyError:
            return []
        if type is None:
            return [x.value for x in rv]
        result = []
        for item in rv:
            try:
                result.append(type(item.value))
            except ValueError:
                pass
        return result

    def setlist(self, key, new_list):
        self.poplist(key)
        for value in new_list:
            self.add(key, value)

    def setlistdefault(self, key, default_list=None):
        raise TypeError("setlistdefault is unsupported for ordered multi dicts")

    def deepcopy(self, memo=None):
        """Return a deep copy of this object.

        The inherited implementation goes through ``to_dict(flat=False)``,
        which groups values by key and loses the interleaved insertion
        order. Copying the ordered list of pairs keeps it intact.
        """
        return self.__class__(deepcopy(list(self.items(multi=True)), memo))

    def update(self, mapping):
        for key, value in iter_multi_items(mapping):
            OrderedMultiDict.add(self, key, value)

    def poplist(self, key):
        buckets = dict.pop(self, key, ())
        for bucket in buckets:
            bucket.unlink(self)
        return [x.value for x in buckets]

    def pop(self, key, default=_missing):
        try:
            buckets = dict.pop(self, key)
        except KeyError:
            if default is not _missing:
                return default

            raise exceptions.BadRequestKeyError(key) from None

        for bucket in buckets:
            bucket.unlink(self)

        return buckets[0].value

    def popitem(self):
        try:
            key, buckets = dict.popitem(self)
        except KeyError as e:
            raise exceptions.BadRequestKeyError(e.args[0]) from None

        for bucket in buckets:
            bucket.unlink(self)

        return key, buckets[0].value

    def popitemlist(self):
        try:
            key, buckets = dict.popitem(self)
        except KeyError as e:
            raise exceptions.BadRequestKeyError(e.args[0]) from None

        for bucket in buckets:
            bucket.unlink(self)

        return key, [x.value for x in buckets]
653
654
class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict):
    """A read only :class:`MultiDict` that you can pass multiple :class:`MultiDict`
    instances as sequence and it will combine the return values of all wrapped
    dicts:

    >>> from werkzeug.datastructures import CombinedMultiDict, MultiDict
    >>> post = MultiDict([('foo', 'bar')])
    >>> get = MultiDict([('blub', 'blah')])
    >>> combined = CombinedMultiDict([get, post])
    >>> combined['foo']
    'bar'
    >>> combined['blub']
    'blah'

    This works for all read operations and will raise a `TypeError` for
    methods that usually change data which isn't possible.

    From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a
    subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will
    render a page for a ``400 BAD REQUEST`` if caught in a catch-all for HTTP
    exceptions.

    :param dicts: an iterable of :class:`MultiDict` instances, searched in
                  order. `None` (the default) means no dicts.
    """

    def __reduce_ex__(self, protocol):
        return type(self), (self.dicts,)

    def __init__(self, dicts=None):
        # ``list(None)`` raises TypeError, so handle the default explicitly.
        self.dicts = [] if dicts is None else list(dicts)

    @classmethod
    def fromkeys(cls, keys, value=None):
        raise TypeError(f"cannot create {cls.__name__!r} instances by fromkeys")

    def __getitem__(self, key):
        """Return the first value for *key* from the first dict containing it.

        :raise KeyError: (a ``BadRequestKeyError``) if no dict has the key.
        """
        for d in self.dicts:
            if key in d:
                return d[key]
        raise exceptions.BadRequestKeyError(key)

    def get(self, key, default=None, type=None):
        """Return the first value for *key*, optionally converted by *type*.

        Dicts are searched in order. If *type* raises :exc:`ValueError` or
        :exc:`TypeError` for one dict's value, the search continues with the
        next dict; *default* is returned when nothing matches. Catching
        :exc:`TypeError` matches :meth:`TypeConversionDict.get`.
        """
        for d in self.dicts:
            if key in d:
                if type is not None:
                    try:
                        return type(d[key])
                    except (ValueError, TypeError):
                        continue
                return d[key]
        return default

    def getlist(self, key, type=None):
        """Concatenate the value lists for *key* from all wrapped dicts."""
        rv = []
        for d in self.dicts:
            rv.extend(d.getlist(key, type))
        return rv

    def _keys_impl(self):
        """This function exists so __len__ can be implemented more efficiently,
        saving one list creation from an iterator.
        """
        rv = set()
        rv.update(*self.dicts)
        return rv

    def keys(self):
        return self._keys_impl()

    def __iter__(self):
        return iter(self.keys())

    def items(self, multi=False):
        """Yield ``(key, value)`` pairs from all dicts in order.

        With ``multi=False`` only the first pair seen for each key is yielded.
        """
        found = set()
        for d in self.dicts:
            for key, value in d.items(multi):
                if multi:
                    yield key, value
                elif key not in found:
                    found.add(key)
                    yield key, value

    def values(self):
        for _key, value in self.items():
            yield value

    def lists(self):
        """Return a list of ``(key, values)`` with values merged across dicts."""
        rv = {}
        for d in self.dicts:
            for key, values in d.lists():
                rv.setdefault(key, []).extend(values)
        return list(rv.items())

    def listvalues(self):
        return (x[1] for x in self.lists())

    def copy(self):
        """Return a shallow mutable copy of this object.

        This returns a :class:`MultiDict` representing the data at the
        time of copying. The copy will no longer reflect changes to the
        wrapped dicts.

        .. versionchanged:: 0.15
            Return a mutable :class:`MultiDict`.
        """
        return MultiDict(self)

    def to_dict(self, flat=True):
        """Return the contents as regular dict. If `flat` is `True` the
        returned dict will only have the first item present, if `flat` is
        `False` all values will be returned as lists.

        :param flat: If set to `False` the dict returned will have lists
                     with all the values in it. Otherwise it will only
                     contain the first item for each key.
        :return: a :class:`dict`
        """
        if flat:
            return dict(self.items())

        return dict(self.lists())

    def __len__(self):
        return len(self._keys_impl())

    def __contains__(self, key):
        for d in self.dicts:
            if key in d:
                return True
        return False

    def __repr__(self):
        return f"{type(self).__name__}({self.dicts!r})"
787
788
class ImmutableDict(ImmutableDictMixin, dict):
    """An immutable :class:`dict`.

    .. versionadded:: 0.5
    """

    def __repr__(self):
        # Render as ``ClassName({...})`` using dict's own repr for the body.
        contents = dict.__repr__(self)
        return f"{type(self).__name__}({contents})"

    def copy(self):
        """Return a shallow mutable copy of this object. Keep in mind that
        the standard library's :func:`copy` function is a no-op for this class
        like for any other python immutable type (eg: :class:`tuple`).
        """
        mutable = dict(self)
        return mutable

    def __copy__(self):
        # Immutable, so ``copy.copy`` can return the same object.
        return self
807
808
class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict):
    """An immutable :class:`MultiDict`.

    .. versionadded:: 0.5
    """

    def copy(self):
        """Return a shallow mutable copy of this object. Keep in mind that
        the standard library's :func:`copy` function is a no-op for this class
        like for any other python immutable type (eg: :class:`tuple`).
        """
        mutable = MultiDict(self)
        return mutable

    def __copy__(self):
        # Immutable, so ``copy.copy`` can return the same object.
        return self
824
825
class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict):
    """An immutable :class:`OrderedMultiDict`.

    .. versionadded:: 0.6
    """

    def _iter_hashitems(self):
        # Pair every (key, value) with its position, so order is part of the
        # result. NOTE(review): presumably consumed by the immutable mixin's
        # hashing; confirm against ImmutableMultiDictMixin.
        ordered_pairs = self.items(multi=True)
        return enumerate(ordered_pairs)

    def copy(self):
        """Return a shallow mutable copy of this object. Keep in mind that
        the standard library's :func:`copy` function is a no-op for this class
        like for any other python immutable type (eg: :class:`tuple`).
        """
        mutable = OrderedMultiDict(self)
        return mutable

    def __copy__(self):
        # Immutable, so ``copy.copy`` can return the same object.
        return self
844
845
class CallbackDict(UpdateDictMixin, dict):
    """A dict that calls a function passed every time something is changed.
    The function is passed the dict instance.

    :param initial: optional initial contents; any falsy value means empty.
    :param on_update: the callback, stored as ``self.on_update``.
    """

    def __init__(self, initial=None, on_update=None):
        # Populate through dict directly, before the callback is attached.
        source = initial if initial else ()
        dict.__init__(self, source)
        self.on_update = on_update

    def __repr__(self):
        contents = dict.__repr__(self)
        return f"<{type(self).__name__} {contents}>"
857
858
class HeaderSet(MutableSet):
    """Similar to the :class:`ETags` class this implements a set-like structure.
    Unlike :class:`ETags` this is case insensitive and used for vary, allow, and
    content-language headers.

    If not constructed using the :func:`parse_set_header` function the
    instantiation works like this:

    >>> hs = HeaderSet(['foo', 'bar', 'baz'])
    >>> hs
    HeaderSet(['foo', 'bar', 'baz'])

    Internally ``_headers`` keeps the original spelling and order, while
    ``_set`` holds the lowercased names for case-insensitive membership.
    Every mutating method keeps the two in sync and then calls
    ``on_update(self)`` if a callback was given.
    """

    def __init__(self, headers=None, on_update=None):
        self._headers = list(headers or ())
        self._set = {x.lower() for x in self._headers}
        self.on_update = on_update

    def add(self, header):
        """Add a new header to the set."""
        self.update((header,))

    def remove(self, header):
        """Remove a header from the set. This raises an :exc:`KeyError` if the
        header is not in the set.

        .. versionchanged:: 0.5
            In older versions a :exc:`IndexError` was raised instead of a
            :exc:`KeyError` if the object was missing.

        :param header: the header to be removed (matched case-insensitively).
        """
        key = header.lower()
        if key not in self._set:
            raise KeyError(header)
        self._set.remove(key)
        for idx, existing in enumerate(self._headers):
            # Compare against the lowercased ``key``. Comparing with the raw
            # ``header`` missed mixed-case entries and left ``_headers`` out
            # of sync with ``_set``.
            if existing.lower() == key:
                del self._headers[idx]
                break
        if self.on_update is not None:
            self.on_update(self)

    def update(self, iterable):
        """Add all the headers from the iterable to the set.

        :param iterable: updates the set with the items from the iterable.
        """
        inserted_any = False
        for header in iterable:
            key = header.lower()
            if key not in self._set:
                self._headers.append(header)
                self._set.add(key)
                inserted_any = True
        # Only notify when something actually changed.
        if inserted_any and self.on_update is not None:
            self.on_update(self)

    def discard(self, header):
        """Like :meth:`remove` but ignores errors.

        :param header: the header to be discarded.
        """
        try:
            self.remove(header)
        except KeyError:
            pass

    def find(self, header):
        """Return the index of the header in the set or return -1 if not found.

        :param header: the header to be looked up.
        """
        header = header.lower()
        for idx, item in enumerate(self._headers):
            if item.lower() == header:
                return idx
        return -1

    def index(self, header):
        """Return the index of the header in the set or raise an
        :exc:`IndexError`.

        :param header: the header to be looked up.
        """
        rv = self.find(header)
        if rv < 0:
            raise IndexError(header)
        return rv

    def clear(self):
        """Clear the set."""
        self._set.clear()
        del self._headers[:]
        if self.on_update is not None:
            self.on_update(self)

    def as_set(self, preserve_casing=False):
        """Return the set as a real python set type. Unless *preserve_casing*
        is set, all the items are converted to lowercase. In either case the
        ordering is lost.

        :param preserve_casing: if set to `True` the items in the set returned
                                will have the original case like in the
                                :class:`HeaderSet`, otherwise they will
                                be lowercase.
        """
        if preserve_casing:
            return set(self._headers)
        return set(self._set)

    def to_header(self):
        """Convert the header set into an HTTP header string."""
        return ", ".join(map(http.quote_header_value, self._headers))

    def __getitem__(self, idx):
        return self._headers[idx]

    def __delitem__(self, idx):
        rv = self._headers.pop(idx)
        self._set.remove(rv.lower())
        if self.on_update is not None:
            self.on_update(self)

    def __setitem__(self, idx, value):
        old = self._headers[idx]
        self._set.remove(old.lower())
        self._headers[idx] = value
        self._set.add(value.lower())
        if self.on_update is not None:
            self.on_update(self)

    def __contains__(self, header):
        return header.lower() in self._set

    def __len__(self):
        return len(self._set)

    def __iter__(self):
        return iter(self._headers)

    def __bool__(self):
        return bool(self._set)

    def __str__(self):
        return self.to_header()

    def __repr__(self):
        return f"{type(self).__name__}({self._headers!r})"
1007
1008
1009# circular dependencies
1010from .. import http