1# util/_collections.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Collection classes and helpers."""
9
10from __future__ import absolute_import
11
12import operator
13import types
14import weakref
15
16from .compat import binary_types
17from .compat import collections_abc
18from .compat import itertools_filterfalse
19from .compat import py2k
20from .compat import string_types
21from .compat import threading
22
23
# Module-level empty ``frozenset`` constant; being immutable, the single
# instance can be shared safely.
EMPTY_SET = frozenset()
25
26
class AbstractKeyedTuple(tuple):
    """``tuple`` base class for the keyed tuple types in this module.

    ``keys()`` reads a ``_fields`` attribute which this class does not
    define itself; it is supplied by subclasses.

    """

    __slots__ = ()

    def keys(self):
        """Return the string key names of this tuple as a new list.

        .. seealso::

            :attr:`.KeyedTuple._fields`

        """
        return [name for name in self._fields]
40
41
class KeyedTuple(AbstractKeyedTuple):
    """``tuple`` subclass that adds labeled names.

    E.g.::

        >>> k = KeyedTuple([1, 2, 3], labels=["one", "two", "three"])
        >>> k.one
        1
        >>> k.two
        2

    Result rows returned by :class:`_query.Query` that contain multiple
    ORM entities and/or column expressions make use of this
    class to return rows.

    The :class:`.KeyedTuple` exhibits similar behavior to the
    ``collections.namedtuple()`` construct provided in the Python
    standard library, however is architected very differently.
    Unlike ``collections.namedtuple()``, :class:`.KeyedTuple`
    does not rely on creation of custom subtypes in order to represent
    a new series of keys, instead each :class:`.KeyedTuple` instance
    receives its list of keys in place.   The subtype approach
    of ``collections.namedtuple()`` introduces significant complexity
    and performance overhead, which is not necessary for the
    :class:`_query.Query` object's use case.

    .. seealso::

        :ref:`ormtutorial_querying`

    """

    def __new__(cls, vals, labels=None):
        """Build the tuple from ``vals`` and attach ``labels`` as attributes.

        :param vals: iterable of values stored in the tuple.
        :param labels: optional sequence of names, positionally matched
         to the values; each name/value pair is stored in the instance
         ``__dict__``.  When omitted or empty, an empty label list is
         stored.

        """
        t = tuple.__new__(cls, vals)
        if labels:
            # Pair the labels with the items actually stored in the tuple
            # rather than with ``vals``: a one-shot iterator passed as
            # ``vals`` has already been consumed by ``tuple.__new__``.
            t.__dict__.update(zip(labels, t))
        else:
            labels = []
        # written through __dict__ because __setattr__ always raises
        t.__dict__["_labels"] = labels
        return t

    @property
    def _fields(self):
        """Return a tuple of string key names for this :class:`.KeyedTuple`.

        Labels that are ``None`` are omitted.

        This method provides compatibility with ``collections.namedtuple()``.

        .. seealso::

            :meth:`.KeyedTuple.keys`

        """
        return tuple([l for l in self._labels if l is not None])

    def __setattr__(self, key, value):
        """Reject attribute assignment with ``AttributeError``."""
        raise AttributeError("Can't set attribute: %s" % key)

    def _asdict(self):
        """Return the contents of this :class:`.KeyedTuple` as a dictionary.

        This method provides compatibility with ``collections.namedtuple()``,
        with the exception that the dictionary returned is **not** ordered.

        """
        return {key: self.__dict__[key] for key in self.keys()}
107
108
class _LW(AbstractKeyedTuple):
    """Slot-less keyed tuple base.

    Methods here read ``_real_fields``, which is not defined on this class
    itself; it is assigned on generated subclasses.

    """

    __slots__ = ()

    def __new__(cls, vals):
        return tuple.__new__(cls, vals)

    def __reduce__(self):
        # for pickling, degrade down to the regular
        # KeyedTuple, thus avoiding anonymous class pickling
        # difficulties
        values = list(self)
        return KeyedTuple, (values, self._real_fields)

    def _asdict(self):
        """Return the contents of this :class:`.KeyedTuple` as a dictionary."""

        as_dict = {}
        for field, value in zip(self._real_fields, self):
            as_dict[field] = value
        # positions labeled None are not exposed
        as_dict.pop(None, None)
        return as_dict
127
128
class ImmutableContainer(object):
    """Mixin whose item/attribute mutation hooks raise ``TypeError``."""

    def _immutable(self, *arg, **kw):
        """Raise ``TypeError`` naming the concrete class; accepts any args."""
        message = "%s object is immutable" % self.__class__.__name__
        raise TypeError(message)

    __setattr__ = _immutable
    __setitem__ = _immutable
    __delitem__ = _immutable
134
135
class immutabledict(ImmutableContainer, dict):
    """A ``dict`` whose mutating methods raise ``TypeError``.

    Contents are populated once, in ``__new__``.

    """

    # every in-place mutator is routed to the raising stub
    clear = ImmutableContainer._immutable
    pop = ImmutableContainer._immutable
    popitem = ImmutableContainer._immutable
    setdefault = ImmutableContainer._immutable
    update = ImmutableContainer._immutable

    def __new__(cls, *args):
        instance = dict.__new__(cls)
        # populate via dict.__init__ directly; __init__ below is a no-op
        dict.__init__(instance, *args)
        return instance

    def __init__(self, *args):
        pass

    def __reduce__(self):
        # pickle as a plain dict copy passed back to the constructor
        return immutabledict, (dict(self),)

    def union(self, d):
        """Return an immutabledict of this dict's items updated with ``d``.

        No copy is made when either side is empty: ``self`` is returned
        if ``d`` is empty, and ``d`` itself is returned if ``self`` is empty
        and ``d`` is already an :class:`.immutabledict`.

        """
        if not d:
            return self
        if self:
            merged = immutabledict(self)
            dict.update(merged, d)
            return merged
        if isinstance(d, immutabledict):
            return d
        return immutabledict(d)

    def __repr__(self):
        inner = dict.__repr__(self)
        return "immutabledict(%s)" % inner
166
167
class Properties(object):
    """Expose a dict through both attribute and item access.

    The backing mapping lives in the ``_data`` slot; attribute reads and
    writes are routed to it.

    """

    __slots__ = ("_data",)

    def __init__(self, data):
        # bypass our own __setattr__, which would write into the mapping
        object.__setattr__(self, "_data", data)

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        # iterate over a snapshot of the values
        snapshot = list(self._data.values())
        return iter(snapshot)

    def __dir__(self):
        own_keys = [str(k) for k in self._data.keys()]
        return dir(super(Properties, self)) + own_keys

    def __add__(self, other):
        mine = list(self)
        theirs = list(other)
        return mine + theirs

    def __setitem__(self, key, obj):
        self._data[key] = obj

    def __getitem__(self, key):
        return self._data[key]

    def __delitem__(self, key):
        del self._data[key]

    def __setattr__(self, key, obj):
        self._data[key] = obj

    def __getstate__(self):
        return {"_data": self._data}

    def __setstate__(self, state):
        object.__setattr__(self, "_data", state["_data"])

    def __getattr__(self, key):
        # only reached when normal attribute lookup fails
        try:
            return self._data[key]
        except KeyError:
            raise AttributeError(key)

    def __contains__(self, key):
        return key in self._data

    def as_immutable(self):
        """Return an immutable proxy for this :class:`.Properties`."""

        return ImmutableProperties(self._data)

    def update(self, value):
        self._data.update(value)

    def get(self, key, default=None):
        """Return the value stored for ``key``, or ``default`` if absent."""
        return self[key] if key in self else default

    def keys(self):
        return [k for k in self._data]

    def values(self):
        return [v for v in self._data.values()]

    def items(self):
        return [pair for pair in self._data.items()]

    def has_key(self, key):
        return key in self._data

    def clear(self):
        self._data.clear()
245
246
class OrderedProperties(Properties):
    """A :class:`.Properties` whose backing store is a new, empty
    :class:`.OrderedDict`."""

    __slots__ = ()

    def __init__(self):
        backing = OrderedDict()
        Properties.__init__(self, backing)
255
256
class ImmutableProperties(ImmutableContainer, Properties):
    """A :class:`.Properties` over an underlying dictionary that rejects
    item and attribute assignment/deletion made through it, by way of
    :class:`.ImmutableContainer`."""

    __slots__ = ()
261
262
class OrderedDict(dict):
    """A dict that returns keys/values/items in the order they were added.

    Key order is tracked in the ``_list`` slot, kept in step with the
    dict contents by the mutating methods below.

    """

    __slots__ = ("_list",)

    def __reduce__(self):
        # pickle as the ordered list of (key, value) pairs
        return OrderedDict, (self.items(),)

    def __init__(self, ____sequence=None, **kwargs):
        """Construct from an optional mapping / iterable of pairs plus
        optional keyword items."""
        self._list = []
        if ____sequence is None:
            if kwargs:
                self.update(**kwargs)
        else:
            self.update(____sequence, **kwargs)

    def clear(self):
        self._list = []
        dict.clear(self)

    def copy(self):
        return self.__copy__()

    def __copy__(self):
        return OrderedDict(self)

    def sort(self, *arg, **kw):
        """Sort the key order in place; arguments are those of
        ``list.sort()``."""
        self._list.sort(*arg, **kw)

    def update(self, ____sequence=None, **kwargs):
        """Add items from a mapping (anything with ``keys()``) or an
        iterable of ``(key, value)`` pairs, then from ``kwargs``."""
        if ____sequence is not None:
            if hasattr(____sequence, "keys"):
                for key in ____sequence.keys():
                    self.__setitem__(key, ____sequence[key])
            else:
                for key, value in ____sequence:
                    self[key] = value
        if kwargs:
            self.update(kwargs)

    def setdefault(self, key, value=None):
        """Return ``self[key]``, first storing ``value`` if ``key`` is absent.

        ``value`` defaults to ``None`` so that the call signature matches
        ``dict.setdefault()``.

        """
        if key not in self:
            self.__setitem__(key, value)
            return value
        else:
            return self.__getitem__(key)

    def __iter__(self):
        return iter(self._list)

    def keys(self):
        return list(self)

    def values(self):
        return [self[key] for key in self._list]

    def items(self):
        return [(key, self[key]) for key in self._list]

    if py2k:

        def itervalues(self):
            return iter(self.values())

        def iterkeys(self):
            return iter(self)

        def iteritems(self):
            return iter(self.items())

    def __setitem__(self, key, obj):
        if key not in self:
            try:
                self._list.append(key)
            except AttributeError:
                # work around Python pickle loads() with
                # dict subclass (seems to ignore __setstate__?)
                self._list = [key]
        dict.__setitem__(self, key, obj)

    def __delitem__(self, key):
        # dict first, so a missing key raises KeyError before _list changes
        dict.__delitem__(self, key)
        self._list.remove(key)

    def pop(self, key, *default):
        # check membership up front: with a default supplied, dict.pop
        # does not reveal whether the key was actually present
        present = key in self
        value = dict.pop(self, key, *default)
        if present:
            self._list.remove(key)
        return value

    def popitem(self):
        item = dict.popitem(self)
        self._list.remove(item[0])
        return item
358
359
class OrderedSet(set):
    """A ``set`` that also records insertion order.

    Order is tracked in the ``_list`` attribute; iteration, indexing and
    ``repr()`` follow it.  The in-place operations return ``self`` so that
    they can double as the augmented-assignment operators.

    """

    def __init__(self, d=None):
        """Construct, optionally from iterable ``d``; duplicates in ``d``
        keep their first position."""
        set.__init__(self)
        if d is not None:
            self._list = unique_list(d)
            set.update(self, self._list)
        else:
            self._list = []

    def add(self, element):
        if element not in self:
            self._list.append(element)
        set.add(self, element)

    def remove(self, element):
        # set first, so a missing element raises KeyError before _list
        # is touched
        set.remove(self, element)
        self._list.remove(element)

    def insert(self, pos, element):
        """Insert ``element`` at list position ``pos`` if not yet present."""
        if element not in self:
            self._list.insert(pos, element)
        set.add(self, element)

    def discard(self, element):
        if element in self:
            self._list.remove(element)
            set.remove(self, element)

    def clear(self):
        set.clear(self)
        self._list = []

    def __getitem__(self, key):
        return self._list[key]

    def __iter__(self):
        return iter(self._list)

    def __add__(self, other):
        return self.union(other)

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self._list)

    __str__ = __repr__

    def update(self, iterable):
        for e in iterable:
            if e not in self:
                self._list.append(e)
                set.add(self, e)
        return self

    __ior__ = update

    def union(self, other):
        result = self.__class__(self)
        result.update(other)
        return result

    __or__ = union

    def intersection(self, other):
        other = set(other)
        return self.__class__(a for a in self if a in other)

    __and__ = intersection

    def symmetric_difference(self, other):
        """Return a new set of the elements in exactly one of the two.

        Elements of ``self`` come first, followed by those of ``other``,
        each group in its own iteration order.

        """
        # materialize once: ``other`` is consulted for membership and then
        # iterated in its original order, which a plain set() would lose
        incoming = list(other)
        incoming_set = set(incoming)
        result = self.__class__(a for a in self if a not in incoming_set)
        result.update(a for a in incoming if a not in self)
        return result

    __xor__ = symmetric_difference

    def difference(self, other):
        other = set(other)
        return self.__class__(a for a in self if a not in other)

    __sub__ = difference

    def intersection_update(self, other):
        other = set(other)
        set.intersection_update(self, other)
        self._list = [a for a in self._list if a in other]
        return self

    __iand__ = intersection_update

    def symmetric_difference_update(self, other):
        """In-place symmetric difference; ``other`` may be any iterable."""
        # materialize once so one-shot iterators can be traversed twice,
        # and so ``other`` need not be an OrderedSet
        incoming = list(other)
        set.symmetric_difference_update(self, incoming)
        kept = [a for a in self._list if a in self]
        appended = set()
        for a in incoming:
            # guard against duplicates within ``other``
            if a in self and a not in appended:
                appended.add(a)
                kept.append(a)
        self._list = kept
        return self

    __ixor__ = symmetric_difference_update

    def difference_update(self, other):
        set.difference_update(self, other)
        self._list = [a for a in self._list if a in self]
        return self

    __isub__ = difference_update
465
466
class IdentitySet(object):
    """A set that considers only object id() for uniqueness.

    This strategy has edge cases for builtin types- it's possible to have
    two 'foo' strings in one of these sets, for example.  Use sparingly.

    Members are held in the ``_members`` dict, keyed by ``id(obj)``.

    """

    def __init__(self, iterable=None):
        self._members = dict()
        if iterable:
            self.update(iterable)

    def add(self, value):
        self._members[id(value)] = value

    def __contains__(self, value):
        return id(value) in self._members

    def remove(self, value):
        """Remove ``value``; raises ``KeyError`` if it is not a member."""
        del self._members[id(value)]

    def discard(self, value):
        """Remove ``value`` if it is a member; otherwise do nothing."""
        try:
            self.remove(value)
        except KeyError:
            pass

    def pop(self):
        """Remove and return a member; raises ``KeyError`` when empty."""
        try:
            pair = self._members.popitem()
            return pair[1]
        except KeyError:
            raise KeyError("pop from an empty set")

    def clear(self):
        self._members.clear()

    def __cmp__(self, other):
        raise TypeError("cannot compare sets using cmp()")

    def __eq__(self, other):
        if isinstance(other, IdentitySet):
            return self._members == other._members
        else:
            return False

    def __ne__(self, other):
        if isinstance(other, IdentitySet):
            return self._members != other._members
        else:
            return True

    def issubset(self, iterable):
        """Return True if every member is also present in ``iterable``."""
        other = self.__class__(iterable)

        if len(self) > len(other):
            return False
        return all(key in other._members for key in self._members)

    def __le__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.issubset(other)

    def __lt__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return len(self) < len(other) and self.issubset(other)

    def issuperset(self, iterable):
        """Return True if every object in ``iterable`` is a member."""
        other = self.__class__(iterable)

        if len(self) < len(other):
            return False
        return all(key in self._members for key in other._members)

    def __ge__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.issuperset(other)

    def __gt__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return len(self) > len(other) and self.issuperset(other)

    def union(self, iterable):
        result = self.__class__()
        members = self._members
        result._members.update(members)
        result._members.update((id(obj), obj) for obj in iterable)
        return result

    def __or__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.union(other)

    def update(self, iterable):
        self._members.update((id(obj), obj) for obj in iterable)

    def __ior__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.update(other)
        return self

    def difference(self, iterable):
        result = self.__class__()
        members = self._members
        other = {id(obj) for obj in iterable}
        result._members.update(
            ((k, v) for k, v in members.items() if k not in other)
        )
        return result

    def __sub__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.difference(other)

    def difference_update(self, iterable):
        self._members = self.difference(iterable)._members

    def __isub__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.difference_update(other)
        return self

    def intersection(self, iterable):
        result = self.__class__()
        members = self._members
        other = {id(obj) for obj in iterable}
        result._members.update(
            (k, v) for k, v in members.items() if k in other
        )
        return result

    def __and__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.intersection(other)

    def intersection_update(self, iterable):
        self._members = self.intersection(iterable)._members

    def __iand__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.intersection_update(other)
        return self

    def symmetric_difference(self, iterable):
        result = self.__class__()
        members = self._members
        other = {id(obj): obj for obj in iterable}
        result._members.update(
            ((k, v) for k, v in members.items() if k not in other)
        )
        result._members.update(
            ((k, v) for k, v in other.items() if k not in members)
        )
        return result

    def __xor__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.symmetric_difference(other)

    def symmetric_difference_update(self, iterable):
        self._members = self.symmetric_difference(iterable)._members

    def __ixor__(self, other):
        if not isinstance(other, IdentitySet):
            return NotImplemented
        # must be the in-place variant: symmetric_difference() only
        # returns a new set and would leave ``self`` unchanged
        self.symmetric_difference_update(other)
        return self

    def copy(self):
        return type(self)(iter(self._members.values()))

    __copy__ = copy

    def __len__(self):
        return len(self._members)

    def __iter__(self):
        return iter(self._members.values())

    def __hash__(self):
        raise TypeError("set objects are unhashable")

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, list(self._members.values()))
672
673
class WeakSequence(object):
    """A list-backed sequence that holds its elements by weak reference.

    ``_storage`` is a list of ``weakref.ref`` objects; each ref is created
    with a callback that removes it from ``_storage``.

    """

    def __init__(self, __elements=()):
        # adapted from weakref.WeakKeyDictionary, prevent reference
        # cycles in the collection itself
        def _remove(item, selfref=weakref.ref(self)):
            self = selfref()
            if self is not None:
                self._storage.remove(item)

        self._remove = _remove
        self._storage = [
            weakref.ref(element, _remove) for element in __elements
        ]

    def append(self, item):
        self._storage.append(weakref.ref(item, self._remove))

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        # skip refs whose target is already gone
        return (
            obj for obj in (ref() for ref in self._storage) if obj is not None
        )

    def __getitem__(self, index):
        """Return the object at ``index``, dereferencing the stored ref.

        Raises ``IndexError`` when ``index`` is out of range.

        """
        try:
            obj = self._storage[index]
        except IndexError:
            # list indexing raises IndexError (never KeyError)
            raise IndexError("Index %s out of range" % index)
        else:
            return obj()
706
707
class OrderedIdentitySet(IdentitySet):
    """An :class:`.IdentitySet` whose member store is an
    :class:`.OrderedDict`."""

    def __init__(self, iterable=None):
        IdentitySet.__init__(self)
        # swap the plain dict created by the base class for an ordered one
        self._members = OrderedDict()
        if not iterable:
            return
        for member in iterable:
            self.add(member)
715
716
class PopulateDict(dict):
    """A dict that fills in missing keys by calling a creation function.

    Unlike ``collections.defaultdict``, the creation function is passed
    the missing key.

    """

    def __init__(self, creator):
        self.creator = creator

    def __missing__(self, key):
        # create, store, and return the value for the missing key
        created = self.creator(key)
        self[key] = created
        return created
731
732
class WeakPopulateDict(dict):
    """Variant of PopulateDict that takes a bound method as the creator.

    The method's function and a weak reference to its ``__self__`` are
    stored separately, so the dict does not create a reference cycle.

    """

    def __init__(self, creator_method):
        owner = creator_method.__self__
        self.creator = creator_method.__func__
        self.weakself = weakref.ref(owner)

    def __missing__(self, key):
        # call the underlying function with the dereferenced owner
        created = self.creator(self.weakself(), key)
        self[key] = created
        return created
747
748
# Define collections that are capable of storing
# ColumnElement objects as hashable keys/elements.
# At this point, these are mostly historical, things
# used to be more complicated.
column_set = set
column_dict = dict
ordered_column_set = OrderedSet


# Cache mapping a key to ``operator.itemgetter(key)``; entries are created
# on first lookup by PopulateDict.__missing__.
_getters = PopulateDict(operator.itemgetter)

# Same kind of cache, but each entry is a ``property`` whose getter is
# ``operator.itemgetter(idx)``.
_property_getters = PopulateDict(
    lambda idx: property(operator.itemgetter(idx))
)
763
764
def unique_list(seq, hashfunc=None):
    """Return the items of ``seq`` as a list with duplicates removed.

    The first occurrence of each item is kept and order is preserved.

    :param seq: iterable of items.
    :param hashfunc: optional callable returning the hashable key used to
     decide uniqueness; when omitted, the items themselves are the keys.
     It is called once per item.

    """
    seen = set()
    seen_add = seen.add
    if not hashfunc:
        # set.add() returns None, so "not seen_add(x)" is always true and
        # serves only to record x as seen
        return [x for x in seq if x not in seen and not seen_add(x)]

    result = []
    for x in seq:
        key = hashfunc(x)
        if key not in seen:
            seen_add(key)
            result.append(x)
    return result
776
777
class UniqueAppender(object):
    """Append items to a wrapped collection at most once each.

    Repeated ``append()`` calls with the same object are ignored.  Sameness
    is decided by ``id()`` of the item, i.e. identity, not equality.
    """

    def __init__(self, data, via=None):
        """Wrap ``data``; items are added through the method named by
        ``via`` if given, else ``data.append``, else ``data.add``."""
        self.data = data
        self._unique = {}
        if via:
            appender_name = via
        elif hasattr(data, "append"):
            appender_name = "append"
        elif hasattr(data, "add"):
            appender_name = "add"
        else:
            # no usable method: _data_appender is left unset
            return
        self._data_appender = getattr(data, appender_name)

    def append(self, item):
        key = id(item)
        if key in self._unique:
            return
        self._data_appender(item)
        self._unique[key] = True

    def __iter__(self):
        return iter(self.data)
803
804
def coerce_generator_arg(arg):
    """Expand a lone generator argument into a list.

    If ``arg`` holds exactly one element and that element is a generator,
    return a list of the generator's items; otherwise return ``arg`` itself.

    """
    lone_generator = len(arg) == 1 and isinstance(
        arg[0], types.GeneratorType
    )
    return list(arg[0]) if lone_generator else arg
810
811
def to_list(x, default=None):
    """Coerce ``x`` to a list.

    ``None`` gives ``default``; an existing list is returned as-is; a
    string/bytes value or any non-iterable is wrapped in a one-element
    list; any other iterable is copied into a new list.

    """
    if x is None:
        return default
    if isinstance(x, list):
        return x
    scalar_types = string_types + binary_types
    if isinstance(x, scalar_types) or not isinstance(
        x, collections_abc.Iterable
    ):
        return [x]
    return list(x)
823
824
def has_intersection(set_, iterable):
    r"""return True if any items of set\_ are present in iterable.

    Items of iterable whose ``__hash__`` is falsy (e.g. ``None``, as on
    unhashable types) are filtered out first, so they are never hashed.

    """
    # TODO: optimize, write in C, etc.
    hashable_items = [item for item in iterable if item.__hash__]
    common = set_.intersection(hashable_items)
    return bool(common)
834
835
def to_set(x):
    """Coerce ``x`` to a set.

    ``None`` gives a new empty set; an existing set is returned as-is;
    anything else is passed through ``to_list()`` and turned into a set.

    """
    if x is None:
        return set()
    if isinstance(x, set):
        return x
    return set(to_list(x))
843
844
def to_column_set(x):
    """Coerce ``x`` to a ``column_set``.

    ``None`` gives a new empty ``column_set``; an existing ``column_set``
    is returned as-is; anything else is passed through ``to_list()`` first.

    """
    if x is None:
        return column_set()
    if isinstance(x, column_set):
        return x
    return column_set(to_list(x))
852
853
def update_copy(d, _new=None, **kw):
    """Return a copy of ``d`` updated with ``_new`` and then ``kw``.

    ``d`` itself is not modified; ``_new`` is skipped when falsy.

    """
    result = d.copy()
    if _new:
        result.update(_new)
    result.update(**kw)
    return result
862
863
def flatten_iterator(x):
    """Yield the elements of ``x``, recursively flattening any element
    that is itself iterable.

    ``str`` elements, and elements without ``__iter__``, are yielded as-is.

    """
    for elem in x:
        if isinstance(elem, str) or not hasattr(elem, "__iter__"):
            yield elem
            continue
        for nested in flatten_iterator(elem):
            yield nested
875
876
class LRUCache(dict):
    """Dictionary with 'squishy' removal of least
    recently used items.

    Each stored entry is a ``[key, value, counter]`` list; the counter is
    refreshed on reads, and pruning discards the entries with the lowest
    counters once the size exceeds ``capacity + capacity * threshold``.

    Note that either get() or [] should be used here, but
    generally its not safe to do an "in" check first as the dictionary
    can change subsequent to that call.

    """

    __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex"

    def __init__(self, capacity=100, threshold=0.5, size_alert=None):
        self.capacity = capacity
        self.threshold = threshold
        self.size_alert = size_alert
        self._counter = 0
        self._mutex = threading.Lock()

    def _inc_counter(self):
        """Advance the usage counter and return its new value."""
        self._counter += 1
        return self._counter

    def get(self, key, default=None):
        entry = dict.get(self, key, default)
        if entry is default:
            return default
        # mark as most recently used
        entry[2] = self._inc_counter()
        return entry[1]

    def __getitem__(self, key):
        entry = dict.__getitem__(self, key)
        entry[2] = self._inc_counter()
        return entry[1]

    def values(self):
        return [entry[1] for entry in dict.values(self)]

    def setdefault(self, key, value):
        if key in self:
            return self[key]
        self[key] = value
        return value

    def __setitem__(self, key, value):
        entry = dict.get(self, key)
        if entry is not None:
            # existing entry: replace the value, counter is left as-is
            entry[1] = value
        else:
            entry = [key, value, self._inc_counter()]
            dict.__setitem__(self, key, entry)
        self._manage_size()

    @property
    def size_threshold(self):
        return self.capacity + self.capacity * self.threshold

    def _manage_size(self):
        """Prune down to ``capacity`` entries once over the threshold.

        Does nothing if the mutex cannot be acquired immediately.

        """
        if not self._mutex.acquire(False):
            return
        try:
            # size_alert is invoked at most once per call
            alert_pending = bool(self.size_alert)
            while len(self) > self.capacity + self.capacity * self.threshold:
                if alert_pending:
                    alert_pending = False
                    self.size_alert(self)
                newest_first = sorted(
                    dict.values(self), key=operator.itemgetter(2), reverse=True
                )
                for stale in newest_first[self.capacity :]:
                    try:
                        del self[stale[0]]
                    except KeyError:
                        # deleted elsewhere; skip
                        continue
        finally:
            self._mutex.release()
956
957
# LRU cache (capacity 100) of the classes generated by
# lightweight_named_tuple(), keyed by (name,) + tuple(fields).
_lw_tuples = LRUCache(100)
959
960
def lightweight_named_tuple(name, fields):
    """Return a ``_LW`` subclass named ``name`` with one property per field.

    Each non-``None`` entry of ``fields`` becomes a property reading the
    tuple position it occupies.  Generated classes are cached in
    ``_lw_tuples`` under ``(name,) + tuple(fields)``.

    """
    hash_ = (name,) + tuple(fields)
    tp_cls = _lw_tuples.get(hash_)
    if tp_cls:
        return tp_cls

    namespace = {}
    for idx, field in enumerate(fields):
        if field is not None:
            namespace[field] = _property_getters[idx]
    # assigned last so it always wins over a same-named field
    namespace["__slots__"] = ()

    tp_cls = type(name, (_LW,), namespace)

    # _real_fields keeps the None placeholders; _fields drops them
    tp_cls._real_fields = fields
    tp_cls._fields = tuple([f for f in fields if f is not None])

    _lw_tuples[hash_] = tp_cls
    return tp_cls
985
986
class ScopedRegistry(object):
    """A Registry that can store one or multiple instances of a single
    class on the basis of a "scope" function.

    The object implements ``__call__`` as the "getter", so by
    calling ``myregistry()`` the contained object is returned
    for the current scope.

    :param createfunc:
      a callable that returns a new object to be placed in the registry

    :param scopefunc:
      a callable that will return a key to store/retrieve an object.
    """

    def __init__(self, createfunc, scopefunc):
        """Construct a new :class:`.ScopedRegistry`.

        :param createfunc:  A creation function that will generate
          a new value for the current scope, if none is present.

        :param scopefunc:  A function that returns a hashable
          token representing the current scope (such as, current
          thread identifier).

        """
        self.createfunc = createfunc
        self.scopefunc = scopefunc
        self.registry = {}

    def __call__(self):
        """Return the object for the current scope, creating it if absent."""
        scope_key = self.scopefunc()
        try:
            return self.registry[scope_key]
        except KeyError:
            created = self.createfunc()
            # setdefault returns an entry stored in the meantime, if any
            return self.registry.setdefault(scope_key, created)

    def has(self):
        """Return True if an object is present in the current scope."""

        scope_key = self.scopefunc()
        return scope_key in self.registry

    def set(self, obj):
        """Set the value for the current scope."""

        scope_key = self.scopefunc()
        self.registry[scope_key] = obj

    def clear(self):
        """Clear the current scope, if any."""

        scope_key = self.scopefunc()
        try:
            del self.registry[scope_key]
        except KeyError:
            pass
1041
1042
class ThreadLocalRegistry(ScopedRegistry):
    """A :class:`.ScopedRegistry` that uses a ``threading.local()``
    variable for storage.

    The current object is kept in the ``value`` attribute of that local.

    """

    def __init__(self, createfunc):
        self.createfunc = createfunc
        self.registry = threading.local()

    def __call__(self):
        """Return the stored object, creating and storing it if absent."""
        try:
            return self.registry.value
        except AttributeError:
            created = self.createfunc()
            self.registry.value = created
            return created

    def has(self):
        """Return True if an object is currently stored."""
        return hasattr(self.registry, "value")

    def set(self, obj):
        """Store ``obj``."""
        self.registry.value = obj

    def clear(self):
        """Remove the stored object, if any."""
        try:
            del self.registry.value
        except AttributeError:
            pass
1071
1072
def has_dupes(sequence, target):
    """Return True if ``target`` occurs more than once in ``sequence``.

    Occurrences are matched by identity (``is``), not equality.  Returns
    False when the target is found zero times or once.

    """
    # compare to .index version below, this version introduces less function
    # overhead and is usually the same speed.  At 15000 items (way bigger than
    # a relationship-bound collection in memory usually is) it begins to
    # fall behind the other version only by microseconds.
    seen_once = False
    for item in sequence:
        if item is not target:
            continue
        if seen_once:
            return True
        seen_once = True
    return False
1090
1091
1092# .index version. the two __contains__ calls as well
1093# as .index() and isinstance() slow this down.
1094# def has_dupes(sequence, target):
1095# if target not in sequence:
1096# return False
1097# elif not isinstance(sequence, collections_abc.Sequence):
1098# return False
1099#
1100# idx = sequence.index(target)
1101# return target in sequence[idx + 1:]