1# util/_collections.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""Collection classes and helpers."""
9
10from __future__ import absolute_import
11
12import operator
13import types
14import weakref
15
16from .compat import binary_types
17from .compat import collections_abc
18from .compat import itertools_filterfalse
19from .compat import py2k
20from .compat import py37
21from .compat import string_types
22from .compat import threading
23
24
# Shared empty set; a frozenset cannot be mutated, so one instance can be
# reused wherever an empty set value is needed.
EMPTY_SET = frozenset()
26
27
class ImmutableContainer(object):
    """Mixin whose item assignment, item deletion and attribute assignment
    all raise ``TypeError``.
    """

    def _immutable(self, *arg, **kw):
        """Reject a mutation attempt, naming the concrete class."""
        class_name = self.__class__.__name__
        raise TypeError("{} object is immutable".format(class_name))

    __setattr__ = _immutable
    __setitem__ = _immutable
    __delitem__ = _immutable
33
34
def _immutabledict_py_fallback():
    """Build and return the pure-Python ``immutabledict`` class.

    The class is a ``dict`` subclass whose public mutators raise
    ``TypeError``.  ``union``, ``_union_w_kw`` and ``merge_with`` never
    modify the receiver; they return ``self`` when there is nothing to add,
    or a new instance otherwise.
    """

    class immutabledict(ImmutableContainer, dict):

        # disable every in-place dict mutator
        clear = ImmutableContainer._immutable
        pop = ImmutableContainer._immutable
        popitem = ImmutableContainer._immutable
        setdefault = ImmutableContainer._immutable
        update = ImmutableContainer._immutable

        def __new__(cls, *args):
            instance = dict.__new__(cls)
            # fill through dict directly, since our own mutators refuse
            dict.__init__(instance, *args)
            return instance

        def __init__(self, *args):
            # contents were already populated in __new__
            pass

        def __reduce__(self):
            return _immutabledict_reconstructor, (dict(self),)

        def union(self, __d=None):
            if not __d:
                return self
            merged = dict.__new__(self.__class__)
            dict.__init__(merged, self)
            dict.update(merged, __d)
            return merged

        def _union_w_kw(self, __d=None, **kw):
            # not sure if C version works correctly w/ this yet
            if not (__d or kw):
                return self
            merged = dict.__new__(self.__class__)
            dict.__init__(merged, self)
            if __d:
                dict.update(merged, __d)
            dict.update(merged, kw)
            return merged

        def merge_with(self, *dicts):
            merged = None
            for other in dicts:
                if not other:
                    continue
                if merged is None:
                    # copy lazily: only once there is something to merge
                    merged = dict.__new__(self.__class__)
                    dict.__init__(merged, self)
                dict.update(merged, other)
            return self if merged is None else merged

        def __repr__(self):
            return "immutabledict(" + dict.__repr__(self) + ")"

    return immutabledict
91
92
# Prefer the C extension's immutabledict; fall back to the pure-Python
# class built by _immutabledict_py_fallback() when it cannot be imported.
try:
    from sqlalchemy.cimmutabledict import immutabledict

    # register the C type as a virtual subclass of Mapping so that
    # isinstance(obj, collections_abc.Mapping) is True for it
    collections_abc.Mapping.register(immutabledict)

except ImportError:
    immutabledict = _immutabledict_py_fallback()

    # Only defined on the fallback path: the pure-Python class's
    # __reduce__ names this function as its unpickling callable.
    # NOTE(review): a pickle written while the fallback is active refers to
    # this name; unpickling it in a process where the C extension imported
    # successfully would not find it — confirm whether that matters.
    def _immutabledict_reconstructor(*arg):
        """Rebuild an immutabledict from ``(dict,)`` during unpickling."""
        return immutabledict(*arg)
104
105
def coerce_to_immutabledict(d):
    """Return *d* as an ``immutabledict``.

    Any falsy input (``None``, an empty mapping) yields the shared
    ``EMPTY_DICT``.  An existing ``immutabledict`` is returned unchanged.
    Anything else is copied into a new ``immutabledict``.
    """
    if not d:
        return EMPTY_DICT
    if isinstance(d, immutabledict):
        return d
    return immutabledict(d)
113
114
# Shared empty immutabledict; coerce_to_immutabledict() returns it for any
# falsy input.
EMPTY_DICT = immutabledict()
116
117
class FacadeDict(ImmutableContainer, dict):
    """A dictionary that is not publicly mutable.

    Public mutators raise ``TypeError``.  The owning code adds entries
    through ``_insert_item``, which writes through ``dict`` directly.
    """

    clear = ImmutableContainer._immutable
    pop = ImmutableContainer._immutable
    popitem = ImmutableContainer._immutable
    setdefault = ImmutableContainer._immutable
    update = ImmutableContainer._immutable

    def __new__(cls, *args):
        # positional args are not used here; no __init__ is defined on this
        # class or ImmutableContainer, so the inherited dict.__init__
        # receives them after __new__ returns
        return dict.__new__(cls)

    def copy(self):
        raise NotImplementedError(
            "an immutabledict shouldn't need to be copied. use dict(d) "
            "if you need a mutable dictionary."
        )

    def __reduce__(self):
        snapshot = dict(self)
        return FacadeDict, (snapshot,)

    def _insert_item(self, key, value):
        """insert an item into the dictionary directly."""
        dict.__setitem__(self, key, value)

    def __repr__(self):
        return "FacadeDict(" + dict.__repr__(self) + ")"
142
143
class Properties(object):
    """Provide a __getattr__/__setattr__ interface over a dict.

    Entries of the wrapped dict are readable and writable both as
    attributes (``props.name``) and as items (``props["name"]``).  The dict
    is shared, not copied.
    """

    __slots__ = ("_data",)

    def __init__(self, data):
        # write the slot via object.__setattr__; our own __setattr__ would
        # store into the dict instead
        object.__setattr__(self, "_data", data)

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        # iterates over values (not keys), using a snapshot list
        values_snapshot = list(self._data.values())
        return iter(values_snapshot)

    def __dir__(self):
        names = dir(super(Properties, self))
        names.extend(str(k) for k in self._data.keys())
        return names

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined

    def __setitem__(self, key, obj):
        self._data[key] = obj

    def __getitem__(self, key):
        return self._data[key]

    def __delitem__(self, key):
        del self._data[key]

    def __setattr__(self, key, obj):
        # every attribute assignment becomes a dict entry
        self._data[key] = obj

    def __getstate__(self):
        return {"_data": self._data}

    def __setstate__(self, state):
        object.__setattr__(self, "_data", state["_data"])

    def __getattr__(self, key):
        # only reached when normal attribute lookup fails; turn a missing
        # dict key into the AttributeError that getattr()/hasattr() expect
        data = self._data
        try:
            return data[key]
        except KeyError:
            raise AttributeError(key)

    def __contains__(self, key):
        return key in self._data

    def as_immutable(self):
        """Return an immutable proxy for this :class:`.Properties`."""
        return ImmutableProperties(self._data)

    def update(self, value):
        self._data.update(value)

    def get(self, key, default=None):
        if key not in self:
            return default
        return self[key]

    def keys(self):
        return list(self._data)

    def values(self):
        return list(self._data.values())

    def items(self):
        return list(self._data.items())

    def has_key(self, key):
        return key in self._data

    def clear(self):
        self._data.clear()
221
222
class OrderedProperties(Properties):
    """A :class:`.Properties` backed by a fresh ``OrderedDict``, so
    iteration follows insertion order.
    """

    __slots__ = ()

    def __init__(self):
        backing = OrderedDict()
        Properties.__init__(self, backing)
231
232
class ImmutableProperties(ImmutableContainer, Properties):
    """Provide immutable dict/object attribute to an underlying dictionary."""

    # ImmutableContainer precedes Properties in the MRO, so its
    # __setitem__/__delitem__/__setattr__ (which raise TypeError) take
    # precedence over the mutating versions defined on Properties.
    # NOTE(review): update() and clear() are inherited from Properties
    # unchanged and still mutate the shared dict — confirm that is intended.

    __slots__ = ()
237
238
239def _ordered_dictionary_sort(d, key=None):
240 """Sort an OrderedDict in-place."""
241
242 items = [(k, d[k]) for k in sorted(d, key=key)]
243
244 d.clear()
245
246 d.update(items)
247
248
if py37:
    # The builtin dict preserves insertion order from Python 3.7 on, so it
    # is used directly and can be sorted with the module-level helper.
    OrderedDict = dict
    sort_dictionary = _ordered_dictionary_sort

else:
    # prevent sort_dictionary from being used against a plain dictionary
    # for Python < 3.7

    def sort_dictionary(d, key=None):
        """Sort an OrderedDict in place."""

        # only the fallback OrderedDict below defines this method, so a
        # plain dict passed here raises AttributeError
        d._ordered_dictionary_sort(key=key)

    class OrderedDict(dict):
        """Dictionary that maintains insertion order.

        Superseded by Python dict as of Python 3.7

        Key order is tracked in the ``_list`` slot, alongside the ordinary
        dict storage; each mutator below keeps the two in sync.
        """

        __slots__ = ("_list",)

        def _ordered_dictionary_sort(self, key=None):
            # delegates to the module-level function of the same name
            _ordered_dictionary_sort(self, key=key)

        def __reduce__(self):
            # pickle as an ordered list of (key, value) pairs so __init__
            # restores the order on load
            return OrderedDict, (self.items(),)

        def __init__(self, ____sequence=None, **kwargs):
            self._list = []
            if ____sequence is None:
                if kwargs:
                    self.update(**kwargs)
            else:
                self.update(____sequence, **kwargs)

        def clear(self):
            self._list = []
            dict.clear(self)

        def copy(self):
            return self.__copy__()

        def __copy__(self):
            return OrderedDict(self)

        def update(self, ____sequence=None, **kwargs):
            # accepts either a mapping (anything with .keys()) or an
            # iterable of (key, value) pairs, then keyword arguments
            if ____sequence is not None:
                if hasattr(____sequence, "keys"):
                    for key in ____sequence.keys():
                        self.__setitem__(key, ____sequence[key])
                else:
                    for key, value in ____sequence:
                        self[key] = value
            if kwargs:
                self.update(kwargs)

        def setdefault(self, key, value):
            # NOTE(review): unlike dict.setdefault, ``value`` is required
            if key not in self:
                self.__setitem__(key, value)
                return value
            else:
                return self.__getitem__(key)

        def __iter__(self):
            return iter(self._list)

        def keys(self):
            return list(self)

        def values(self):
            return [self[key] for key in self._list]

        def items(self):
            return [(key, self[key]) for key in self._list]

        if py2k:

            def itervalues(self):
                return iter(self.values())

            def iterkeys(self):
                return iter(self)

            def iteritems(self):
                return iter(self.items())

        def __setitem__(self, key, obj):
            # only new keys are appended, so re-assigning an existing key
            # keeps its original position
            if key not in self:
                try:
                    self._list.append(key)
                except AttributeError:
                    # work around Python pickle loads() with
                    # dict subclass (seems to ignore __setstate__?)
                    self._list = [key]
            dict.__setitem__(self, key, obj)

        def __delitem__(self, key):
            dict.__delitem__(self, key)
            self._list.remove(key)

        def pop(self, key, *default):
            # record presence first: after dict.pop the key is gone either
            # way, and _list must only be touched if it was really there
            present = key in self
            value = dict.pop(self, key, *default)
            if present:
                self._list.remove(key)
            return value

        def popitem(self):
            # dict.popitem chooses the item; _list is then brought in sync
            item = dict.popitem(self)
            self._list.remove(item[0])
            return item
361
362
class OrderedSet(set):
    """A ``set`` that remembers the order in which elements were added.

    Membership lives in the underlying ``set``; the order lives in the
    ``_list`` attribute.  Every mutator defined here updates both, so
    iteration, indexing and ``repr()`` follow insertion order.
    """

    def __init__(self, d=None):
        set.__init__(self)
        if d is not None:
            self._list = unique_list(d)
            set.update(self, self._list)
        else:
            self._list = []

    def add(self, element):
        if element not in self:
            self._list.append(element)
        set.add(self, element)

    def remove(self, element):
        set.remove(self, element)
        self._list.remove(element)

    def insert(self, pos, element):
        """Add *element* at list position *pos* if it is not present."""
        if element not in self:
            self._list.insert(pos, element)
        set.add(self, element)

    def discard(self, element):
        if element in self:
            self._list.remove(element)
            set.remove(self, element)

    def clear(self):
        set.clear(self)
        self._list = []

    def pop(self):
        """Remove and return the most recently added element.

        Overrides ``set.pop``, which would drop an arbitrary element from
        the set while leaving it in ``_list``.

        :raises KeyError: if the set is empty.
        """
        if not self._list:
            raise KeyError("pop from an empty set")
        value = self._list.pop()
        set.remove(self, value)
        return value

    def __getitem__(self, key):
        return self._list[key]

    def __iter__(self):
        return iter(self._list)

    def __add__(self, other):
        return self.union(other)

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self._list)

    __str__ = __repr__

    def update(self, iterable):
        for e in iterable:
            if e not in self:
                self._list.append(e)
                set.add(self, e)
        return self

    __ior__ = update

    def union(self, other):
        result = self.__class__(self)
        result.update(other)
        return result

    __or__ = union

    def intersection(self, other):
        other = set(other)
        return self.__class__(a for a in self if a in other)

    __and__ = intersection

    def symmetric_difference(self, other):
        other = set(other)
        result = self.__class__(a for a in self if a not in other)
        result.update(a for a in other if a not in self)
        return result

    __xor__ = symmetric_difference

    def difference(self, other):
        other = set(other)
        return self.__class__(a for a in self if a not in other)

    __sub__ = difference

    def intersection_update(self, other):
        other = set(other)
        set.intersection_update(self, other)
        self._list = [a for a in self._list if a in other]
        return self

    __iand__ = intersection_update

    def symmetric_difference_update(self, other):
        """Keep elements found in exactly one of ``self`` and *other*.

        *other* may be any iterable, not only an ``OrderedSet``.  Surviving
        elements of ``self`` keep their order, followed by the new elements
        in the order *other* first yields them.
        """
        # Materialize (so a one-shot iterator can be read twice) and dedupe
        # while preserving first-seen order.
        seen = set()
        incoming = []
        for a in other:
            if a not in seen:
                seen.add(a)
                incoming.append(a)
        set.symmetric_difference_update(self, incoming)
        self._list = [a for a in self._list if a in self]
        self._list += [a for a in incoming if a in self]
        return self

    __ixor__ = symmetric_difference_update

    def difference_update(self, other):
        set.difference_update(self, other)
        self._list = [a for a in self._list if a in self]
        return self

    __isub__ = difference_update
467
468
class IdentitySet(object):
    """A set that considers only object id() for uniqueness.

    This strategy has edge cases for builtin types- it's possible to have
    two 'foo' strings in one of these sets, for example. Use sparingly.

    Members are stored in ``_members``, a dict mapping ``id(obj)`` to
    ``obj``.
    """

    def __init__(self, iterable=None):
        self._members = dict()
        if iterable:
            self.update(iterable)

    def add(self, value):
        self._members[id(value)] = value

    def __contains__(self, value):
        return id(value) in self._members

    def remove(self, value):
        del self._members[id(value)]

    def discard(self, value):
        try:
            self.remove(value)
        except KeyError:
            pass

    def pop(self):
        try:
            pair = self._members.popitem()
            return pair[1]
        except KeyError:
            raise KeyError("pop from an empty set")

    def clear(self):
        self._members.clear()

    def __cmp__(self, other):
        raise TypeError("cannot compare sets using cmp()")

    def __eq__(self, other):
        if isinstance(other, IdentitySet):
            return self._members == other._members
        else:
            return False

    def __ne__(self, other):
        if isinstance(other, IdentitySet):
            return self._members != other._members
        else:
            return True

    def issubset(self, iterable):
        if isinstance(iterable, self.__class__):
            other = iterable
        else:
            other = self.__class__(iterable)

        if len(self) > len(other):
            return False
        # every member id of self must also be a member id of other
        return all(key in other._members for key in self._members)

    def __le__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.issubset(other)

    def __lt__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return len(self) < len(other) and self.issubset(other)

    def issuperset(self, iterable):
        if isinstance(iterable, self.__class__):
            other = iterable
        else:
            other = self.__class__(iterable)

        if len(self) < len(other):
            return False
        # every member id of other must also be a member id of self
        return all(key in self._members for key in other._members)

    def __ge__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.issuperset(other)

    def __gt__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return len(self) > len(other) and self.issuperset(other)

    def union(self, iterable):
        result = self.__class__()
        members = self._members
        result._members.update(members)
        result._members.update((id(obj), obj) for obj in iterable)
        return result

    def __or__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.union(other)

    def update(self, iterable):
        self._members.update((id(obj), obj) for obj in iterable)

    def __ior__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.update(other)
        return self

    def difference(self, iterable):
        result = self.__class__()
        members = self._members
        if isinstance(iterable, self.__class__):
            other = set(iterable._members.keys())
        else:
            other = {id(obj) for obj in iterable}
        result._members.update(
            ((k, v) for k, v in members.items() if k not in other)
        )
        return result

    def __sub__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.difference(other)

    def difference_update(self, iterable):
        self._members = self.difference(iterable)._members

    def __isub__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.difference_update(other)
        return self

    def intersection(self, iterable):
        result = self.__class__()
        members = self._members
        if isinstance(iterable, self.__class__):
            other = set(iterable._members.keys())
        else:
            other = {id(obj) for obj in iterable}
        result._members.update(
            (k, v) for k, v in members.items() if k in other
        )
        return result

    def __and__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.intersection(other)

    def intersection_update(self, iterable):
        self._members = self.intersection(iterable)._members

    def __iand__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.intersection_update(other)
        return self

    def symmetric_difference(self, iterable):
        result = self.__class__()
        members = self._members
        if isinstance(iterable, self.__class__):
            other = iterable._members
        else:
            other = {id(obj): obj for obj in iterable}
        result._members.update(
            ((k, v) for k, v in members.items() if k not in other)
        )
        result._members.update(
            ((k, v) for k, v in other.items() if k not in members)
        )
        return result

    def __xor__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.symmetric_difference(other)

    def symmetric_difference_update(self, iterable):
        self._members = self.symmetric_difference(iterable)._members

    def __ixor__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        # must mutate in place; calling the non-mutating
        # symmetric_difference() here would discard the result
        self.symmetric_difference_update(other)
        return self

    def copy(self):
        return type(self)(iter(self._members.values()))

    __copy__ = copy

    def __len__(self):
        return len(self._members)

    def __iter__(self):
        return iter(self._members.values())

    def __hash__(self):
        raise TypeError("set objects are unhashable")

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, list(self._members.values()))
689
690
class WeakSequence(object):
    """An append-only sequence that holds its elements by weak reference.

    Each element is stored as a ``weakref.ref`` whose callback removes that
    reference from the sequence once the element is garbage collected.
    """

    def __init__(self, __elements=()):
        # adapted from weakref.WeakKeyDictionary, prevent reference
        # cycles in the collection itself
        def _remove(item, selfref=weakref.ref(self)):
            self = selfref()
            if self is not None:
                self._storage.remove(item)

        self._remove = _remove
        self._storage = [
            weakref.ref(element, _remove) for element in __elements
        ]

    def append(self, item):
        self._storage.append(weakref.ref(item, self._remove))

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        return (
            obj for obj in (ref() for ref in self._storage) if obj is not None
        )

    def __getitem__(self, index):
        try:
            obj = self._storage[index]
        except IndexError:
            # list indexing raises IndexError (never KeyError); re-raise
            # with the sequence's own message
            raise IndexError("Index %s out of range" % index)
        else:
            return obj()
723
724
class OrderedIdentitySet(IdentitySet):
    """An :class:`.IdentitySet` whose ``_members`` store is an
    ``OrderedDict``, so members iterate in insertion order.
    """

    def __init__(self, iterable=None):
        IdentitySet.__init__(self)
        self._members = OrderedDict()
        for member in iterable or ():
            self.add(member)
732
733
class PopulateDict(dict):
    """A dict which populates missing values via a creation function.

    Note the creation function takes a key, unlike
    collections.defaultdict.

    """

    def __init__(self, creator):
        self.creator = creator

    def __missing__(self, key):
        # called by dict.__getitem__ on a miss; the created value is
        # stored so later lookups hit directly
        value = self.creator(key)
        self[key] = value
        return value
748
749
class WeakPopulateDict(dict):
    """Like PopulateDict, but assumes a self + a method and does not create
    a reference cycle.

    """

    def __init__(self, creator_method):
        # keep the plain function plus a weak reference to the instance,
        # rather than the bound method, so the instance is not kept alive
        self.creator = creator_method.__func__
        owner = creator_method.__self__
        self.weakself = weakref.ref(owner)

    def __missing__(self, key):
        owner = self.weakself()
        value = self.creator(owner, key)
        self[key] = value
        return value
764
765
# Define collections that are capable of storing
# ColumnElement objects as hashable keys/elements.
# These aliases are now mostly historical; they once had more complicated
# implementations.
column_set = set
column_dict = dict
ordered_column_set = OrderedSet


# Cache of operator.itemgetter objects, keyed by the argument passed to
# itemgetter; each one is created on first lookup.
_getters = PopulateDict(operator.itemgetter)

# Cache of read-only properties whose getter is operator.itemgetter(idx),
# keyed by idx and created on first lookup.
_property_getters = PopulateDict(
    lambda idx: property(operator.itemgetter(idx))
)
780
781
def unique_list(seq, hashfunc=None):
    """Return the items of *seq* with duplicates removed, keeping the first
    occurrence of each and preserving order.

    :param seq: any iterable.
    :param hashfunc: optional callable.  When given, two items count as
     duplicates if ``hashfunc`` returns equal values for them.  It is
     called exactly once per item.
    """
    seen = set()
    seen_add = seen.add
    if not hashfunc:
        # seen_add() returns None, so "not seen_add(x)" records x and is True
        return [x for x in seq if x not in seen and not seen_add(x)]

    result = []
    for x in seq:
        token = hashfunc(x)
        if token not in seen:
            seen_add(token)
            result.append(x)
    return result
793
794
class UniqueAppender(object):
    """Appends items to a collection ensuring uniqueness.

    Additional appends() of the same object are ignored. Membership is
    determined by identity (``is a``) not equality (``==``).

    The collection method used is the attribute named by *via* when given,
    otherwise ``append`` if present, otherwise ``add``.
    """

    def __init__(self, data, via=None):
        self.data = data
        self._unique = {}
        if via:
            self._data_appender = getattr(data, via)
        else:
            for method_name in ("append", "add"):
                if hasattr(data, method_name):
                    self._data_appender = getattr(data, method_name)
                    break
        # NOTE(review): when no method is found, _data_appender stays unset
        # and append() raises AttributeError.

    def append(self, item):
        key = id(item)
        if key in self._unique:
            return
        self._data_appender(item)
        self._unique[key] = True

    def __iter__(self):
        return iter(self.data)
820
821
def coerce_generator_arg(arg):
    """Expand a lone generator argument.

    If the sequence *arg* (e.g. an ``*args`` tuple) has exactly one element
    and that element is a generator, return the generator's items as a
    list.  Otherwise return *arg* unchanged.
    """
    is_lone_generator = len(arg) == 1 and isinstance(
        arg[0], types.GeneratorType
    )
    return list(arg[0]) if is_lone_generator else arg
827
828
def to_list(x, default=None):
    """Coerce *x* to a list.

    ``None`` yields *default*.  Strings, bytes and non-iterables are wrapped
    in a one-element list.  An existing list is returned as-is (not copied).
    Any other iterable is materialized with ``list()``.
    """
    if x is None:
        return default
    scalar = not isinstance(x, collections_abc.Iterable) or isinstance(
        x, string_types + binary_types
    )
    if scalar:
        return [x]
    if isinstance(x, list):
        return x
    return list(x)
840
841
def has_intersection(set_, iterable):
    r"""return True if any items of set\_ are present in iterable.

    Goes through special effort to ensure __hash__ is not called
    on items in iterable that don't support it.

    """
    # TODO: optimize, write in C, etc.
    # unhashable objects have __hash__ set to None, so they are skipped
    hashable_items = [item for item in iterable if item.__hash__]
    common = set_.intersection(hashable_items)
    return bool(common)
851
852
def to_set(x):
    """Coerce *x* to a ``set``.

    ``None`` yields a new empty set.  An existing ``set`` is returned as-is.
    Anything else is passed through :func:`to_list` and then converted.
    """
    if x is None:
        return set()
    if isinstance(x, set):
        return x
    return set(to_list(x))
860
861
def to_column_set(x):
    """Coerce *x* to a ``column_set``.

    ``None`` yields a new empty one.  An existing ``column_set`` is returned
    as-is.  Anything else is passed through :func:`to_list` and then
    converted.
    """
    if x is None:
        return column_set()
    if isinstance(x, column_set):
        return x
    return column_set(to_list(x))
869
870
def update_copy(d, _new=None, **kw):
    """Copy the given dict and update with the given values.

    *d* itself is left untouched.  Entries from *_new* (if truthy) are
    applied first, then keyword arguments.
    """
    result = d.copy()
    if _new:
        result.update(_new)
    result.update(**kw)
    return result
879
880
def flatten_iterator(x):
    """Given an iterator of which further sub-elements may also be
    iterators, flatten the sub-elements into a single iterator.

    Text (``str``) and binary (``bytes``) values are yielded whole rather
    than being split into characters or integers.
    """
    for elem in x:
        # bytes has __iter__ on Python 3 and would otherwise be flattened
        # into ints; treat it as atomic like str (cf. to_list)
        if not isinstance(elem, (str, bytes)) and hasattr(elem, "__iter__"):
            for y in flatten_iterator(elem):
                yield y
        else:
            yield elem
892
893
class LRUCache(dict):
    """Dictionary with 'squishy' removal of least
    recently used items.

    Note that either get() or [] should be used here, but
    generally its not safe to do an "in" check first as the dictionary
    can change subsequent to that call.

    Each stored entry is a list ``[key, value, counter]``.  ``counter`` is
    refreshed on every read, and pruning keeps the ``capacity`` entries
    with the highest counters.
    """

    __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex"

    def __init__(self, capacity=100, threshold=0.5, size_alert=None):
        self.capacity = capacity
        self.threshold = threshold
        self.size_alert = size_alert
        self._counter = 0
        self._mutex = threading.Lock()

    def _inc_counter(self):
        self._counter += 1
        return self._counter

    def get(self, key, default=None):
        entry = dict.get(self, key, default)
        if entry is default:
            return default
        entry[2] = self._inc_counter()
        return entry[1]

    def __getitem__(self, key):
        entry = dict.__getitem__(self, key)
        entry[2] = self._inc_counter()
        return entry[1]

    def values(self):
        return [entry[1] for entry in dict.values(self)]

    def setdefault(self, key, value):
        if key not in self:
            self[key] = value
            return value
        return self[key]

    def __setitem__(self, key, value):
        entry = dict.get(self, key)
        if entry is not None:
            # overwrite in place; the recency counter is left unchanged
            entry[1] = value
        else:
            dict.__setitem__(self, key, [key, value, self._inc_counter()])
        self._manage_size()

    @property
    def size_threshold(self):
        return self.capacity + self.capacity * self.threshold

    def _manage_size(self):
        # non-blocking acquire: if the lock is already held, skip pruning
        if not self._mutex.acquire(False):
            return
        try:
            alert_pending = bool(self.size_alert)
            while len(self) > self.capacity + self.capacity * self.threshold:
                if alert_pending:
                    alert_pending = False
                    self.size_alert(self)
                ranked = sorted(
                    dict.values(self), key=operator.itemgetter(2), reverse=True
                )
                # drop everything beyond the `capacity` most recent entries
                for stale_key, _value, _counter in ranked[self.capacity :]:
                    try:
                        del self[stale_key]
                    except KeyError:
                        # deleted elsewhere; skip
                        pass
        finally:
            self._mutex.release()
973
974
class ScopedRegistry(object):
    """A Registry that can store one or multiple instances of a single
    class on the basis of a "scope" function.

    The object implements ``__call__`` as the "getter", so by
    calling ``myregistry()`` the contained object is returned
    for the current scope.

    :param createfunc:
      a callable that returns a new object to be placed in the registry

    :param scopefunc:
      a callable that will return a key to store/retrieve an object.
    """

    def __init__(self, createfunc, scopefunc):
        """Construct a new :class:`.ScopedRegistry`.

        :param createfunc: A creation function that will generate
          a new value for the current scope, if none is present.

        :param scopefunc: A function that returns a hashable
          token representing the current scope (such as, current
          thread identifier).

        """
        self.createfunc = createfunc
        self.scopefunc = scopefunc
        self.registry = {}

    def __call__(self):
        scope = self.scopefunc()
        try:
            return self.registry[scope]
        except KeyError:
            pass
        # setdefault rather than plain assignment: if an entry for this
        # scope exists by the time createfunc() returns, it is kept
        return self.registry.setdefault(scope, self.createfunc())

    def has(self):
        """Return True if an object is present in the current scope."""
        return self.scopefunc() in self.registry

    def set(self, obj):
        """Set the value for the current scope."""
        scope = self.scopefunc()
        self.registry[scope] = obj

    def clear(self):
        """Clear the current scope, if any."""
        self.registry.pop(self.scopefunc(), None)
1029
1030
class ThreadLocalRegistry(ScopedRegistry):
    """A :class:`.ScopedRegistry` that uses a ``threading.local()``
    variable for storage.

    The value lives in the ``value`` attribute of the thread-local object,
    so no ``scopefunc`` is needed.
    """

    def __init__(self, createfunc):
        self.createfunc = createfunc
        self.registry = threading.local()

    def __call__(self):
        try:
            return self.registry.value
        except AttributeError:
            pass
        created = self.createfunc()
        self.registry.value = created
        return created

    def has(self):
        return hasattr(self.registry, "value")

    def set(self, obj):
        self.registry.value = obj

    def clear(self):
        try:
            del self.registry.value
        except AttributeError:
            # nothing stored for this thread
            pass
1059
1060
def has_dupes(sequence, target):
    """Given a sequence and search object, return True if there's more
    than one, False if zero or one of them.

    Occurrences are matched by identity (``is``), not equality.
    """
    # compare to .index version below, this version introduces less function
    # overhead and is usually the same speed. At 15000 items (way bigger than
    # a relationship-bound collection in memory usually is) it begins to
    # fall behind the other version only by microseconds.
    found_once = False
    for item in sequence:
        if item is not target:
            continue
        if found_once:
            return True
        found_once = True
    return False
1078
1079
# Alternative .index()-based version, kept for reference. Its two
# __contains__ calls, plus .index() and isinstance(), make it slower.
1082# def has_dupes(sequence, target):
1083# if target not in sequence:
1084# return False
1085# elif not isinstance(sequence, collections_abc.Sequence):
1086# return False
1087#
1088# idx = sequence.index(target)
1089# return target in sequence[idx + 1:]