1# ext/associationproxy.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""Contain the ``AssociationProxy`` class.
9
10The ``AssociationProxy`` is a Python property object which provides
11transparent proxied access to the endpoint of an association object.
12
13See the example ``examples/association/proxied_association.py``.
14
15"""
16import operator
17
18from .. import exc
19from .. import inspect
20from .. import orm
21from .. import util
22from ..orm import collections
23from ..orm import interfaces
24from ..sql import or_
25from ..sql.operators import ColumnOperators
26
27
def association_proxy(target_collection, attr, **kw):
    r"""Create an :class:`.AssociationProxy` viewing ``attr`` across the
    members of ``target_collection``.

    The returned descriptor behaves as a Python property. For a collection
    relationship it mimics the collection type of the target (list, dict or
    set), exposing the ``attr`` value of each member instead of the members
    themselves. For a scalar (one-to-one / many-to-one, i.e.
    ``uselist=False``) relationship it exposes a single plain value.

    :param target_collection: Name of the attribute being proxied to.
     This is typically a :func:`~sqlalchemy.orm.relationship` linking to a
     target collection. It may also be a scalar (many-to-one or one-to-one)
     relationship.

    :param attr: Attribute read from the associated instance(s).

     For a target collection of ``[obj1, obj2]``, the proxied list looks
     like ``[getattr(obj1, attr), getattr(obj2, attr)]``.

     For a one-to-one or otherwise ``uselist=False`` relationship the result
     is simply ``getattr(obj, attr)``.

    :param creator: optional.

     Used when new items are added through the proxy to build new
     instances of the target class. List and set collections pass a single
     ``value`` argument; dict collections pass ``key`` and ``value``. By
     default the target class constructor is called with those arguments.
     Supply a *creator* callable accepting the same arguments to build
     instances differently.

     For scalar relationships, ``creator()`` is called only when the target
     is ``None``. Otherwise assignments are forwarded via ``setattr()`` to
     the existing associated object.

     Several association proxies may map different attributes of the same
     associated object. ``creator()`` functions can then build the scalar
     relationship on demand; the unit tests contain examples.

    :param \*\*kw: Any further keyword arguments are forwarded unchanged to
     :class:`.AssociationProxy`.

    """
    proxy = AssociationProxy(target_collection, attr, **kw)
    return proxy
81
82
# Marker symbol; AssociationProxy.extension_type is set to this value so that
# inspection code can tell association proxies apart from other descriptors.
ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY")
"""Symbol indicating an :class:`.InspectionAttr` that's
    of type :class:`.AssociationProxy`.

    Is assigned to the :attr:`.InspectionAttr.extension_type`
    attribute.

"""
91
92
class AssociationProxy(interfaces.InspectionAttrInfo):
    """A descriptor that presents a read/write view of an object attribute."""

    is_attribute = True
    extension_type = ASSOCIATION_PROXY

    def __init__(
        self,
        target_collection,
        attr,
        creator=None,
        getset_factory=None,
        proxy_factory=None,
        proxy_bulk_set=None,
        info=None,
        cascade_scalar_deletes=False,
    ):
        """Construct a new :class:`.AssociationProxy`.

        The :func:`.association_proxy` function is provided as the usual
        entrypoint here, though :class:`.AssociationProxy` can be instantiated
        and/or subclassed directly.

        :param target_collection: Name of the collection we'll proxy to,
          usually created with :func:`_orm.relationship`.

        :param attr: Attribute on the collected instances we'll proxy
          for.  For example, given a target collection of [obj1, obj2], a
          list created by this proxy property would look like
          [getattr(obj1, attr), getattr(obj2, attr)]

        :param creator: Optional. When new items are added to this proxied
          collection, new instances of the class collected by the target
          collection will be created.  For list and set collections, the
          target class constructor will be called with the 'value' for the
          new instance.  For dict types, two arguments are passed:
          key and value.

          If you want to construct instances differently, supply a 'creator'
          function that takes arguments as above and returns instances.

        :param cascade_scalar_deletes: when True, indicates that setting
         the proxied value to ``None``, or deleting it via ``del``, should
         also remove the source object.  Only applies to scalar attributes.
         Normally, removing the proxied target will not remove the proxy
         source, as this object may have other state that is still to be
         kept.

         .. versionadded:: 1.3

         .. seealso::

            :ref:`cascade_scalar_deletes` - complete usage example

        :param getset_factory: Optional.  Proxied attribute access is
          automatically handled by routines that get and set values based on
          the `attr` argument for this proxy.

          If you would like to customize this behavior, you may supply a
          `getset_factory` callable that produces a tuple of `getter` and
          `setter` functions.  The factory is called with two arguments:
          the collection type sniffed from the underlying collection
          (``None`` for a scalar relationship) and the per-class
          :class:`.AssociationProxyInstance` that is serving the access.

        :param proxy_factory: Optional.  The type of collection to emulate is
          determined by sniffing the target collection.  If your collection
          type can't be determined by duck typing or you'd like to use a
          different collection implementation, you may supply a factory
          function to produce those collections.  It is called as
          ``proxy_factory(lazy_collection, creator, value_attr,
          association_proxy_instance)``.  Only applicable to non-scalar
          relationships.

        :param proxy_bulk_set: Optional, use with proxy_factory.  Called as
          ``proxy_bulk_set(proxy, values)`` to bulk-populate a custom proxy
          collection; see ``AssociationProxyInstance._set()``.

        :param info: optional, will be assigned to
         :attr:`.AssociationProxy.info` if present (a falsy value such as
         an empty dict is ignored).

         .. versionadded:: 1.0.9

        """
        self.target_collection = target_collection
        self.value_attr = attr
        self.creator = creator
        self.getset_factory = getset_factory
        self.proxy_factory = proxy_factory
        self.proxy_bulk_set = proxy_bulk_set
        self.cascade_scalar_deletes = cascade_scalar_deletes

        # Unique per proxy object.  Per-class state is stored on the class as
        # "<key>_inst"; per-object collection proxies are stored on instances
        # under "<key>" (via AssociationProxyInstance.key).
        self.key = "_%s_%s_%s" % (
            type(self).__name__,
            target_collection,
            id(self),
        )
        if info:
            self.info = info

    def __get__(self, obj, class_):
        """Descriptor read access.

        Delegates to the per-class :class:`.AssociationProxyInstance`.  For
        instance access that yields the proxied value or collection proxy.
        For class access (``obj`` is None) it yields the
        :class:`.AssociationProxyInstance` itself.  When no per-class state
        can be built (the class is not mapped yet), the descriptor itself is
        returned.
        """
        if class_ is None:
            return self
        inst = self._as_instance(class_, obj)
        if inst is not None:
            return inst.get(obj)

        # no per-class state could be built (class not inspectable / not
        # mapped yet); fall back to the descriptor itself
        return self

    def __set__(self, obj, values):
        """Descriptor write access; delegates to the per-class instance."""
        class_ = type(obj)
        return self._as_instance(class_, obj).set(obj, values)

    def __delete__(self, obj):
        """Descriptor ``del`` access; delegates to the per-class instance."""
        class_ = type(obj)
        return self._as_instance(class_, obj).delete(obj)

    def for_class(self, class_, obj=None):
        r"""Return the internal state local to a specific mapped class.

        E.g., given a class ``User``::

            class User(Base):
                # ...

                keywords = association_proxy('kws', 'keyword')

        If we access this :class:`.AssociationProxy` from
        :attr:`_orm.Mapper.all_orm_descriptors`, and we want to view the
        target class for this proxy as mapped by ``User``::

            inspect(User).all_orm_descriptors["keywords"].for_class(User).target_class

        This returns an instance of :class:`.AssociationProxyInstance` that
        is specific to the ``User`` class.   The :class:`.AssociationProxy`
        object remains agnostic of its parent class.

        :param class\_: the class that we are returning state for.

        :param obj: optional, an instance of the class that is required
         if the attribute refers to a polymorphic target, e.g. where we have
         to look at the type of the actual destination object to get the
         complete path.

        .. versionadded:: 1.3 - :class:`.AssociationProxy` no longer stores
           any state specific to a particular parent class; the state is now
           stored in per-class :class:`.AssociationProxyInstance` objects.


        """
        return self._as_instance(class_, obj)

    def _as_instance(self, class_, obj):
        """Return (building and caching on first use) the per-class state.

        The state is cached in ``class_.__dict__`` under ``key + "_inst"``.
        Returns None when ``class_`` has no mapper yet.  Non-canonical
        (ambiguous) state is resolved against ``obj``'s current target.
        """
        try:
            inst = class_.__dict__[self.key + "_inst"]
        except KeyError:
            inst = None

        # avoid exception context
        if inst is None:
            owner = self._calc_owner(class_)
            if owner is not None:
                inst = AssociationProxyInstance.for_proxy(self, owner, obj)
                setattr(class_, self.key + "_inst", inst)
            else:
                inst = None

        if inst is not None and not inst._is_canonical:
            # the AssociationProxyInstance can't be generalized
            # since the proxied attribute is not on the targeted
            # class, only on subclasses of it, which might be
            # different.  only return for the specific
            # object's current value
            return inst._non_canonical_get_for_object(obj)
        else:
            return inst

    def _calc_owner(self, target_cls):
        """Return the mapped class for ``target_cls``, or None if it cannot
        be inspected."""
        # we might be getting invoked for a subclass
        # that is not mapped yet, in some declarative situations.
        # save until we are mapped
        try:
            insp = inspect(target_cls)
        except exc.NoInspectionAvailable:
            # can't find a mapper, don't set owner. if we are a not-yet-mapped
            # subclass, we can also scan through __mro__ to find a mapped
            # class, but instead just wait for us to be called again against a
            # mapped class normally.
            return None
        else:
            return insp.mapper.class_manager.class_

    def _default_getset(self, collection_class):
        """Return the default ``(getter, setter)`` pair for ``value_attr``.

        The getter maps a ``None`` target to ``None``.  For dict
        collections the setter takes ``(obj, key, value)``; otherwise
        ``(obj, value)``.
        """
        attr = self.value_attr
        _getter = operator.attrgetter(attr)

        def getter(target):
            return _getter(target) if target is not None else None

        if collection_class is dict:

            def setter(o, k, v):
                setattr(o, attr, v)

        else:

            def setter(o, v):
                setattr(o, attr, v)

        return getter, setter

    def __repr__(self):
        return "AssociationProxy(%r, %r)" % (
            self.target_collection,
            self.value_attr,
        )
307
308
class AssociationProxyInstance(object):
    """A per-class object that serves class- and object-specific results.

    This is used by :class:`.AssociationProxy` when it is invoked
    in terms of a specific class or instance of a class, i.e. when it is
    used as a regular Python descriptor.

    When referring to the :class:`.AssociationProxy` as a normal Python
    descriptor, the :class:`.AssociationProxyInstance` is the object that
    actually serves the information.   Under normal circumstances, its presence
    is transparent::

        >>> User.keywords.scalar
        False

    In the special case that the :class:`.AssociationProxy` object is being
    accessed directly, in order to get an explicit handle to the
    :class:`.AssociationProxyInstance`, use the
    :meth:`.AssociationProxy.for_class` method::

        proxy_state = inspect(User).all_orm_descriptors["keywords"].for_class(User)

        # view if proxy object is scalar or not
        >>> proxy_state.scalar
        False

    .. versionadded:: 1.3

    """  # noqa

    def __init__(self, parent, owning_class, target_class, value_attr):
        self.parent = parent
        self.key = parent.key
        self.owning_class = owning_class
        self.target_collection = parent.target_collection
        # filled in by get() once a backing collection has been sniffed
        self.collection_class = None
        self.target_class = target_class
        self.value_attr = value_attr

    target_class = None
    """The intermediary class handled by this
    :class:`.AssociationProxyInstance`.

    Intercepted append/set/assignment events will result
    in the generation of new instances of this class.

    """

    @classmethod
    def for_proxy(cls, parent, owning_class, parent_instance):
        """Build the appropriate :class:`.AssociationProxyInstance` subclass
        for ``parent`` as mapped on ``owning_class``.

        ``parent_instance`` is accepted for interface compatibility and is
        not consulted here.

        :raises NotImplementedError: if ``parent.target_collection`` is not
          a relationship of ``owning_class``.
        :raises InvalidRequestError: if looking up ``parent.value_attr`` on
          the target class fails with an error other than AttributeError.
        """
        target_collection = parent.target_collection
        value_attr = parent.value_attr
        prop = orm.class_mapper(owning_class).get_property(target_collection)

        # this was never asserted before but this should be made clear.
        if not isinstance(prop, orm.RelationshipProperty):
            util.raise_(
                NotImplementedError(
                    "association proxy to a non-relationship "
                    "intermediary is not supported"
                ),
                replace_context=None,
            )

        target_class = prop.mapper.class_

        try:
            target_assoc = cls._cls_unwrap_target_assoc_proxy(
                target_class, value_attr
            )
        except AttributeError:
            # the proxied attribute doesn't exist on the target class;
            # return an "ambiguous" instance that will work on a per-object
            # basis
            return AmbiguousAssociationProxyInstance(
                parent, owning_class, target_class, value_attr
            )
        except Exception as err:
            util.raise_(
                exc.InvalidRequestError(
                    "Association proxy received an unexpected error when "
                    "trying to retreive attribute "
                    '"%s.%s" from '
                    'class "%s": %s'
                    % (
                        target_class.__name__,
                        parent.value_attr,
                        target_class.__name__,
                        err,
                    )
                ),
                from_=err,
            )
        else:
            return cls._construct_for_assoc(
                target_assoc, parent, owning_class, target_class, value_attr
            )

    @classmethod
    def _construct_for_assoc(
        cls, target_assoc, parent, owning_class, target_class, value_attr
    ):
        """Pick the concrete subclass for a resolved target attribute.

        A chained association proxy or an object-valued attribute yields
        :class:`.ObjectAssociationProxyInstance`.  A column-valued attribute
        yields :class:`.ColumnAssociationProxyInstance`.  Anything that is not
        an instrumented attribute yields the ambiguous variant.
        """
        if target_assoc is not None:
            return ObjectAssociationProxyInstance(
                parent, owning_class, target_class, value_attr
            )

        attr = getattr(target_class, value_attr)
        if not hasattr(attr, "_is_internal_proxy"):
            return AmbiguousAssociationProxyInstance(
                parent, owning_class, target_class, value_attr
            )
        is_object = attr._impl_uses_objects
        if is_object:
            return ObjectAssociationProxyInstance(
                parent, owning_class, target_class, value_attr
            )
        else:
            return ColumnAssociationProxyInstance(
                parent, owning_class, target_class, value_attr
            )

    def _get_property(self):
        """Return the mapped relationship named ``target_collection``."""
        return orm.class_mapper(self.owning_class).get_property(
            self.target_collection
        )

    @property
    def _comparator(self):
        return self._get_property().comparator

    def __clause_element__(self):
        raise NotImplementedError(
            "The association proxy can't be used as a plain column "
            "expression; it only works inside of a comparison expression"
        )

    @classmethod
    def _cls_unwrap_target_assoc_proxy(cls, target_class, value_attr):
        """Return the target attribute if it is itself an association proxy,
        else None.  Raises AttributeError if the attribute is missing."""
        attr = getattr(target_class, value_attr)
        if isinstance(attr, (AssociationProxy, AssociationProxyInstance)):
            return attr
        return None

    @util.memoized_property
    def _unwrap_target_assoc_proxy(self):
        return self._cls_unwrap_target_assoc_proxy(
            self.target_class, self.value_attr
        )

    @property
    def remote_attr(self):
        """The 'remote' class attribute referenced by this
        :class:`.AssociationProxyInstance`.

        .. seealso::

            :attr:`.AssociationProxyInstance.attr`

            :attr:`.AssociationProxyInstance.local_attr`

        """
        return getattr(self.target_class, self.value_attr)

    @property
    def local_attr(self):
        """The 'local' class attribute referenced by this
        :class:`.AssociationProxyInstance`.

        .. seealso::

            :attr:`.AssociationProxyInstance.attr`

            :attr:`.AssociationProxyInstance.remote_attr`

        """
        return getattr(self.owning_class, self.target_collection)

    @property
    def attr(self):
        """Return a tuple of ``(local_attr, remote_attr)``.

        This attribute was originally intended to facilitate using the
        :meth:`_query.Query.join` method to join across the two relationships
        at once, however this makes use of a deprecated calling style.

        To use :meth:`_sql.select.join` or :meth:`_orm.Query.join` with
        an association proxy, the current method is to make use of the
        :attr:`.AssociationProxyInstance.local_attr` and
        :attr:`.AssociationProxyInstance.remote_attr` attributes separately::

            stmt = (
                select(Parent).
                join(Parent.proxied.local_attr).
                join(Parent.proxied.remote_attr)
            )

        A future release may seek to provide a more succinct join pattern
        for association proxy attributes.

        .. seealso::

            :attr:`.AssociationProxyInstance.local_attr`

            :attr:`.AssociationProxyInstance.remote_attr`

        """
        return (self.local_attr, self.remote_attr)

    @util.memoized_property
    def scalar(self):
        """Return ``True`` if this :class:`.AssociationProxyInstance`
        proxies a scalar relationship on the local side."""

        scalar = not self._get_property().uselist
        if scalar:
            # scalar access goes through _scalar_get / _scalar_set
            self._initialize_scalar_accessors()
        return scalar

    @util.memoized_property
    def _value_is_scalar(self):
        return (
            not self._get_property()
            .mapper.get_property(self.value_attr)
            .uselist
        )

    @property
    def _target_is_object(self):
        # overridden as a plain class attribute by the concrete subclasses
        raise NotImplementedError()

    def _initialize_scalar_accessors(self):
        """Set ``_scalar_get`` / ``_scalar_set`` for a scalar relationship."""
        if self.parent.getset_factory:
            get, set_ = self.parent.getset_factory(None, self)
        else:
            get, set_ = self.parent._default_getset(None)
        self._scalar_get, self._scalar_set = get, set_

    def _default_getset(self, collection_class):
        """Return a default ``(getter, setter)`` pair for ``value_attr``.

        The getter maps a ``None`` target to ``None``.  For dict
        collections the setter takes ``(obj, key, value)``; otherwise
        ``(obj, value)``.
        """
        attr = self.value_attr
        _getter = operator.attrgetter(attr)

        def getter(target):
            return _getter(target) if target is not None else None

        if collection_class is dict:

            def setter(o, k, v):
                return setattr(o, attr, v)

        else:

            def setter(o, v):
                return setattr(o, attr, v)

        return getter, setter

    @property
    def info(self):
        return self.parent.info

    def get(self, obj):
        """Return the proxied value for ``obj``.

        With ``obj`` None this returns ``self``.  For a scalar relationship
        the value is read through the scalar getter.  For a collection, a
        proxy collection is built once and cached on ``obj`` under
        ``self.key`` together with ``id(obj)`` and ``id(self)``.  It is
        rebuilt if either identity no longer matches.
        """
        if obj is None:
            return self

        if self.scalar:
            target = getattr(obj, self.target_collection)
            return self._scalar_get(target)
        else:
            try:
                # If the owning instance is reborn (orm session resurrect,
                # etc.), refresh the proxy cache.
                creator_id, self_id, proxy = getattr(obj, self.key)
            except AttributeError:
                pass
            else:
                if id(obj) == creator_id and id(self) == self_id:
                    assert self.collection_class is not None
                    return proxy

            self.collection_class, proxy = self._new(
                _lazy_collection(obj, self.target_collection)
            )
            setattr(obj, self.key, (id(obj), id(self), proxy))
            return proxy

    def set(self, obj, values):
        """Assign ``values`` through the proxy on ``obj``.

        For a scalar relationship a missing target is created with
        ``creator(values)``; assigning ``None`` to a missing target is a
        no-op.  An existing target is updated through the scalar setter,
        and is detached when the value is ``None`` and
        ``cascade_scalar_deletes`` is set.  Collections are bulk-replaced
        unless ``values`` is the proxy itself.
        """
        if self.scalar:
            creator = (
                self.parent.creator
                if self.parent.creator
                else self.target_class
            )
            target = getattr(obj, self.target_collection)
            if target is None:
                if values is None:
                    return
                setattr(obj, self.target_collection, creator(values))
            else:
                self._scalar_set(target, values)
                if values is None and self.parent.cascade_scalar_deletes:
                    setattr(obj, self.target_collection, None)
        else:
            proxy = self.get(obj)
            assert self.collection_class is not None
            if proxy is not values:
                proxy._bulk_replace(self, values)

    def delete(self, obj):
        """Handle ``del obj.<proxy>``.

        For a scalar relationship the proxied attribute is first deleted
        from an existing target.  The intermediary attribute is then
        deleted from ``obj``.
        """
        # owning_class is always supplied at construction time, so the former
        # "owning_class is None" branch (which called a _calc_owner method
        # this class does not define) was unreachable and has been dropped.
        if self.scalar:
            target = getattr(obj, self.target_collection)
            if target is not None:
                delattr(target, self.value_attr)
        delattr(obj, self.target_collection)

    def _new(self, lazy_collection):
        """Build a proxy collection over ``lazy_collection``.

        Returns ``(collection_class, proxy)``, where ``collection_class`` is
        sniffed from the current collection.

        :raises ArgumentError: if the collection type cannot be determined
          and no ``proxy_factory`` was supplied.
        """
        creator = (
            self.parent.creator if self.parent.creator else self.target_class
        )
        specimen = lazy_collection()
        collection_class = util.duck_type_collection(specimen)

        if self.parent.proxy_factory:
            return (
                collection_class,
                self.parent.proxy_factory(
                    lazy_collection, creator, self.value_attr, self
                ),
            )

        if self.parent.getset_factory:
            getter, setter = self.parent.getset_factory(collection_class, self)
        else:
            getter, setter = self.parent._default_getset(collection_class)

        if collection_class is list:
            return (
                collection_class,
                _AssociationList(
                    lazy_collection, creator, getter, setter, self
                ),
            )
        elif collection_class is dict:
            return (
                collection_class,
                _AssociationDict(
                    lazy_collection, creator, getter, setter, self
                ),
            )
        elif collection_class is set:
            return (
                collection_class,
                _AssociationSet(
                    lazy_collection, creator, getter, setter, self
                ),
            )
        else:
            # self.collection_class has not been assigned yet on this path
            # (get() stores our return value), so name the sniffed class or,
            # failing that, the actual type of the backing collection.
            raise exc.ArgumentError(
                "could not guess which interface to use for "
                'collection_class "%s" backing "%s"; specify a '
                "proxy_factory and proxy_bulk_set manually"
                % (
                    (collection_class or type(specimen)).__name__,
                    self.target_collection,
                )
            )

    def _set(self, proxy, values):
        """Bulk-populate ``proxy`` with ``values``.

        Uses ``proxy_bulk_set`` when given, else ``extend`` for lists and
        ``update`` for dicts and sets.

        :raises ArgumentError: for a custom collection class without
          ``proxy_bulk_set``.
        """
        if self.parent.proxy_bulk_set:
            self.parent.proxy_bulk_set(proxy, values)
        elif self.collection_class is list:
            proxy.extend(values)
        elif self.collection_class is dict:
            proxy.update(values)
        elif self.collection_class is set:
            proxy.update(values)
        else:
            raise exc.ArgumentError(
                "no proxy_bulk_set supplied for custom "
                "collection_class implementation"
            )

    def _inflate(self, proxy):
        """Re-attach creator/getter/setter to ``proxy``.

        Used when a proxy collection is unpickled; see
        ``_AssociationCollection.__setstate__``.
        """
        creator = (
            self.parent.creator if self.parent.creator else self.target_class
        )

        if self.parent.getset_factory:
            getter, setter = self.parent.getset_factory(
                self.collection_class, self
            )
        else:
            getter, setter = self.parent._default_getset(self.collection_class)

        proxy.creator = creator
        proxy.getter = getter
        proxy.setter = setter

    def _criterion_exists(self, criterion=None, **kwargs):
        """Build an EXISTS criterion through the local relationship.

        Delegates to a chained association proxy when ``value_attr`` is
        itself one.  Column targets accept only a plain criterion.
        """
        is_has = kwargs.pop("is_has", None)

        target_assoc = self._unwrap_target_assoc_proxy
        if target_assoc is not None:
            inner = target_assoc._criterion_exists(
                criterion=criterion, **kwargs
            )
            return self._comparator._criterion_exists(inner)

        if self._target_is_object:
            prop = getattr(self.target_class, self.value_attr)
            value_expr = prop._criterion_exists(criterion, **kwargs)
        else:
            if kwargs:
                raise exc.ArgumentError(
                    "Can't apply keyword arguments to column-targeted "
                    "association proxy; use =="
                )
            elif is_has and criterion is not None:
                raise exc.ArgumentError(
                    "Non-empty has() not allowed for "
                    "column-targeted association proxy; use =="
                )

            value_expr = criterion

        return self._comparator._criterion_exists(value_expr)

    def any(self, criterion=None, **kwargs):
        """Produce a proxied 'any' expression using EXISTS.

        This expression will be a composed product
        using the :meth:`.RelationshipProperty.Comparator.any`
        and/or :meth:`.RelationshipProperty.Comparator.has`
        operators of the underlying proxied attributes.

        """
        if self._unwrap_target_assoc_proxy is None and (
            self.scalar
            and (not self._target_is_object or self._value_is_scalar)
        ):
            raise exc.InvalidRequestError(
                "'any()' not implemented for scalar " "attributes. Use has()."
            )
        return self._criterion_exists(
            criterion=criterion, is_has=False, **kwargs
        )

    def has(self, criterion=None, **kwargs):
        """Produce a proxied 'has' expression using EXISTS.

        This expression will be a composed product
        using the :meth:`.RelationshipProperty.Comparator.any`
        and/or :meth:`.RelationshipProperty.Comparator.has`
        operators of the underlying proxied attributes.

        """
        if self._unwrap_target_assoc_proxy is None and (
            not self.scalar
            or (self._target_is_object and not self._value_is_scalar)
        ):
            raise exc.InvalidRequestError(
                "'has()' not implemented for collections. " "Use any()."
            )
        return self._criterion_exists(
            criterion=criterion, is_has=True, **kwargs
        )

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.parent)
777
778
class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
    """An :class:`.AssociationProxyInstance` whose target type is unknown.

    The proxied attribute is not present on the mapped target class itself.
    Class-level operations therefore raise.  Per-object access resolves a
    concrete instance for the actual class of the related object and caches
    it in ``_lookup_cache``.
    """

    _is_canonical = False

    def _ambiguous(self):
        """Raise AttributeError explaining why the operation can't proceed."""
        details = (
            self.owning_class.__name__,
            self.target_collection,
            self.value_attr,
            self.target_class,
        )
        raise AttributeError(
            "Association proxy %s.%s refers to an attribute '%s' that is not "
            "directly mapped on class %s; therefore this operation cannot "
            "proceed since we don't know what type of object is referred "
            "towards" % details
        )

    def get(self, obj):
        if obj is not None:
            return super(AmbiguousAssociationProxyInstance, self).get(obj)
        return self

    def __eq__(self, obj):
        self._ambiguous()

    def __ne__(self, obj):
        self._ambiguous()

    def any(self, criterion=None, **kwargs):
        self._ambiguous()

    def has(self, criterion=None, **kwargs):
        self._ambiguous()

    @util.memoized_property
    def _lookup_cache(self):
        # maps a concrete target subclass to its AssociationProxyInstance.
        # e.g. for A -> A.b -> B -> B.b_attr where only the subclasses
        # B1(B) and B2(B) define "b_attr", the keys would be B1 and B2.
        return {}

    def _non_canonical_get_for_object(self, parent_instance):
        """Return the instance resolved for ``parent_instance``'s current
        target object, or ``self`` when none can be resolved."""
        if parent_instance is None:
            return self

        related = getattr(parent_instance, self.target_collection)
        if related is None:
            return self

        try:
            insp = inspect(related)
        except exc.NoInspectionAvailable:
            return self

        mapper = insp.mapper
        concrete_cls = mapper.class_
        if concrete_cls not in self._lookup_cache:
            self._populate_cache(concrete_cls, mapper)

        # falling back to "self" gives a proxy with generally only
        # instance-level functionality
        return self._lookup_cache.get(concrete_cls, self)

    def _populate_cache(self, instance_class, mapper):
        """Cache a concrete instance for ``instance_class`` if it is a
        subclass of the relationship target and defines the attribute."""
        prop = orm.class_mapper(self.owning_class).get_property(
            self.target_collection
        )

        if not mapper.isa(prop.mapper):
            return

        try:
            target_assoc = self._cls_unwrap_target_assoc_proxy(
                instance_class, self.value_attr
            )
        except AttributeError:
            # attribute missing on this subclass too; leave it uncached
            return

        self._lookup_cache[instance_class] = self._construct_for_assoc(
            target_assoc,
            self.parent,
            self.owning_class,
            instance_class,
            self.value_attr,
        )
869
870
class ObjectAssociationProxyInstance(AssociationProxyInstance):
    """An :class:`.AssociationProxyInstance` whose target is an object
    (a relationship, or a chained association proxy)."""

    _target_is_object = True
    _is_canonical = True

    def contains(self, obj):
        """Produce a proxied 'contains' expression using EXISTS.

        The expression is composed from the
        :meth:`.RelationshipProperty.Comparator.any`,
        :meth:`.RelationshipProperty.Comparator.has`
        and/or :meth:`.RelationshipProperty.Comparator.contains`
        operators of the underlying proxied attributes.
        """
        chained = self._unwrap_target_assoc_proxy
        if chained is not None:
            comparator = self._comparator
            if chained.scalar:
                inner = chained == obj
            else:
                inner = chained.contains(obj)
            return comparator._criterion_exists(inner)

        scalar_object = self._target_is_object and self.scalar
        if scalar_object and not self._value_is_scalar:
            comparator = self._comparator
            remote = getattr(self.target_class, self.value_attr)
            return comparator.has(remote.contains(obj))
        if scalar_object and self._value_is_scalar:
            raise exc.InvalidRequestError(
                "contains() doesn't apply to a scalar object endpoint; use =="
            )

        return self._comparator._criterion_exists(**{self.value_attr: obj})

    def __eq__(self, obj):
        # has() only works against a scalar relationship; comparison with
        # a collection is expected to fail there.
        matches = self._comparator.has(**{self.value_attr: obj})
        if obj is not None:
            return matches
        # "== None" additionally matches a missing intermediary object
        return or_(matches, self._comparator == None)  # noqa: E711

    def __ne__(self, obj):
        # as with __eq__, only valid for a scalar relationship
        comparator = self._comparator
        remote = getattr(self.target_class, self.value_attr)
        return comparator.has(remote != obj)
927
928
class ColumnAssociationProxyInstance(
    ColumnOperators, AssociationProxyInstance
):
    """An :class:`.AssociationProxyInstance` whose target is a database
    column rather than a related object.
    """

    _target_is_object = False
    _is_canonical = True

    def __eq__(self, other):
        comparison = self.remote_attr.operate(operator.eq, other)
        exists_expr = self._criterion_exists(comparison)
        if other is not None:
            return exists_expr
        # "== None" also matches when there is no related row at all
        return or_(exists_expr, self._comparator == None)  # noqa: E711

    def operate(self, op, *other, **kwargs):
        # apply the operator to the remote column, then wrap it in EXISTS
        remote_expr = self.remote_attr.operate(op, *other, **kwargs)
        return self._criterion_exists(remote_expr)
953
954
class _lazy_collection(object):
    """Callable returning ``getattr(parent, target)`` on every call.

    The owner object and attribute name are kept rather than the
    collection itself, so each call sees the attribute's current value.
    The pickled state uses the keys ``"obj"`` and ``"target"``.
    """

    def __init__(self, obj, target):
        self.parent, self.target = obj, target

    def __call__(self):
        return getattr(self.parent, self.target)

    def __getstate__(self):
        return dict(obj=self.parent, target=self.target)

    def __setstate__(self, state):
        self.parent, self.target = state["obj"], state["target"]
969
970
class _AssociationCollection(object):
    def __init__(self, lazy_collection, creator, getter, setter, parent):
        """Set up a proxy collection.

        Concrete proxies subclass this as _AssociationList, _AssociationSet
        or _AssociationDict.

        lazy_collection
          Zero-argument callable returning the real collection of entities
          (usually an attribute managed by a SQLAlchemy relationship()).

        creator
          Callable that makes a new target entity from a value; it is
          expected that ``getter(creator(somevalue)) == somevalue``.

        getter
          Callable mapping an associated object to its proxied value.

        setter
          Callable storing a value onto an associated object.

        parent
          The owning AssociationProxyInstance.
        """
        self.lazy_collection = lazy_collection
        self.creator = creator
        self.getter = getter
        self.setter = setter
        self.parent = parent

    @property
    def col(self):
        """The underlying collection, fetched fresh on every access."""
        return self.lazy_collection()

    def __len__(self):
        underlying = self.col
        return len(underlying)

    def __bool__(self):
        underlying = self.col
        return bool(underlying)

    # Python 2 name for __bool__
    __nonzero__ = __bool__

    def __getstate__(self):
        # creator/getter/setter are deliberately not pickled
        return dict(parent=self.parent, lazy_collection=self.lazy_collection)

    def __setstate__(self, state):
        self.parent = state["parent"]
        self.lazy_collection = state["lazy_collection"]
        # let the owning proxy re-attach creator/getter/setter
        self.parent._inflate(self)

    def _bulk_replace(self, assoc_proxy, values):
        """Empty the collection, then repopulate it via ``assoc_proxy``."""
        self.clear()
        assoc_proxy._set(self, values)
1024
1025
class _AssociationList(_AssociationCollection):
    """Generic, converting, list-to-list proxy.

    Proxied values are read from the underlying members through
    ``getter``; assigning to an existing position goes through ``setter``;
    new members are built with ``creator``.

    Methods that mirror a ``list`` method and have no docstring of their
    own receive ``list``'s docstring from the loop at the end of this class
    body, so they are documented with ``#`` comments instead.
    """

    def _create(self, value):
        """Build a new underlying member for ``value`` via ``creator``."""
        return self.creator(value)

    def _get(self, object_):
        """Return the proxied value of underlying member ``object_``."""
        return self.getter(object_)

    def _set(self, object_, value):
        """Store ``value`` on existing underlying member ``object_``."""
        return self.setter(object_, value)

    def __getitem__(self, index):
        if not isinstance(index, slice):
            return self._get(self.col[index])
        else:
            return [self._get(member) for member in self.col[index]]

    # Integer index: update the existing member in place via the setter.
    # Slice: follows ``list`` semantics.  Bounds are normalized with
    # ``slice.indices()``, so negative or out-of-range bounds and negative
    # steps behave as for a list; a contiguous slice may grow or shrink
    # the list, while an extended slice needs a sequence of the same size.
    def __setitem__(self, index, value):
        if not isinstance(index, slice):
            self._set(self.col[index], value)
            return

        # materialize first: ``value`` may be a one-shot iterator, or this
        # very proxy, whose contents change as we mutate it below
        values = list(value)
        start, stop, step = index.indices(len(self))

        if step == 1:
            # drop the old span, then insert the new members one at a time
            # starting at the same position
            if stop > start:
                del self.col[start:stop]
            for offset, item in enumerate(values):
                self.insert(start + offset, item)
        else:
            positions = range(start, stop, step)
            if len(values) != len(positions):
                raise ValueError(
                    "attempt to assign sequence of size %s to "
                    "extended slice of size %s" % (len(values), len(positions))
                )
            # extended slices keep their members and update them in place
            for i, item in zip(positions, values):
                self._set(self.col[i], item)

    def __delitem__(self, index):
        del self.col[index]

    def __contains__(self, value):
        for member in self.col:
            # testlib.pragma exempt:__eq__
            if self._get(member) == value:
                return True
        return False

    # __getslice__ / __setslice__ / __delslice__ are only consulted by
    # Python 2; Python 3 routes slices through __getitem__ and friends.
    def __getslice__(self, start, end):
        return [self._get(member) for member in self.col[start:end]]

    def __setslice__(self, start, end, values):
        members = [self._create(v) for v in values]
        self.col[start:end] = members

    def __delslice__(self, start, end):
        del self.col[start:end]

    def __iter__(self):
        """Iterate over proxied values.

        For the actual domain objects, iterate over .col instead or
        just use the underlying collection directly from its property
        on the parent.
        """

        for member in self.col:
            yield self._get(member)

    def append(self, value):
        self.col.append(self._create(value))

    def count(self, value):
        return sum(1 for v in self if v == value)

    # Snapshot when extending with ourselves (``p.extend(p)`` / ``p += p``);
    # otherwise iteration would see its own appends and never finish.
    def extend(self, values):
        if values is self:
            values = list(self)
        for v in values:
            self.append(v)

    def insert(self, index, value):
        self.col[index:index] = [self._create(value)]

    # routed through _get() like every other read in this class
    def pop(self, index=-1):
        return self._get(self.col.pop(index))

    def remove(self, value):
        for i, val in enumerate(self):
            if val == value:
                del self.col[i]
                return
        raise ValueError("value not in list")

    def reverse(self):
        """Not supported, use reversed(mylist)"""

        raise NotImplementedError

    def sort(self):
        """Not supported, use sorted(mylist)"""

        raise NotImplementedError

    def clear(self):
        del self.col[0 : len(self.col)]

    def __eq__(self, other):
        return list(self) == other

    def __ne__(self, other):
        return list(self) != other

    def __lt__(self, other):
        return list(self) < other

    def __le__(self, other):
        return list(self) <= other

    def __gt__(self, other):
        return list(self) > other

    def __ge__(self, other):
        return list(self) >= other

    def __cmp__(self, other):
        return util.cmp(list(self), other)

    def __add__(self, iterable):
        try:
            other = list(iterable)
        except TypeError:
            return NotImplemented
        return list(self) + other

    def __radd__(self, iterable):
        try:
            other = list(iterable)
        except TypeError:
            return NotImplemented
        return other + list(self)

    def __mul__(self, n):
        if not isinstance(n, int):
            return NotImplemented
        return list(self) * n

    __rmul__ = __mul__

    def __iadd__(self, iterable):
        self.extend(iterable)
        return self

    def __imul__(self, n):
        # unlike a regular list *=, proxied __imul__ will generate unique
        # backing objects for each copy. *= on proxied lists is a bit of
        # a stretch anyhow, and this interpretation of the __imul__ contract
        # is more plausibly useful than copying the backing objects.
        if not isinstance(n, int):
            return NotImplemented
        # as with list, any non-positive factor empties the list
        if n <= 0:
            self.clear()
        elif n > 1:
            self.extend(list(self) * (n - 1))
        return self

    def index(self, item, *args):
        return list(self).index(item, *args)

    def copy(self):
        return list(self)

    def __repr__(self):
        return repr(list(self))

    def __hash__(self):
        raise TypeError("%s objects are unhashable" % type(self).__name__)

    # borrow list's docstrings for the undocumented list-like methods above
    for func_name, func in list(locals().items()):
        if (
            callable(func)
            and func.__name__ == func_name
            and not func.__doc__
            and hasattr(list, func_name)
        ):
            func.__doc__ = getattr(list, func_name).__doc__
    del func_name, func
1231
1232
# Sentinel used as the default of _AssociationDict.pop() so that "no default
# given" can be told apart from an explicit default such as None.
_NotProvided = util.symbol("_NotProvided")
1234
1235
class _AssociationDict(_AssociationCollection):
    """Generic, converting, dict-to-dict proxy.

    Keys are those of the underlying collection.  Values are read through
    ``getter``, updated in place through ``setter``, and new members are
    built with ``creator(key, value)``.

    Methods that mirror a ``dict`` method and have no docstring of their
    own receive ``dict``'s docstring from the loop at the end of this class
    body, so they are documented with ``#`` comments instead.
    """

    def _create(self, key, value):
        """Build a new underlying member for ``key``/``value``."""
        return self.creator(key, value)

    def _get(self, object_):
        """Return the proxied value of underlying member ``object_``."""
        return self.getter(object_)

    def _set(self, object_, key, value):
        """Store ``value`` on existing underlying member ``object_``."""
        return self.setter(object_, key, value)

    def __getitem__(self, key):
        return self._get(self.col[key])

    # an existing member is updated in place; otherwise a new one is created
    def __setitem__(self, key, value):
        if key in self.col:
            self._set(self.col[key], key, value)
        else:
            self.col[key] = self._create(key, value)

    def __delitem__(self, key):
        del self.col[key]

    def __contains__(self, key):
        # testlib.pragma exempt:__hash__
        return key in self.col

    def has_key(self, key):
        # testlib.pragma exempt:__hash__
        return key in self.col

    def __iter__(self):
        return iter(self.col.keys())

    def clear(self):
        self.col.clear()

    def __eq__(self, other):
        return dict(self) == other

    def __ne__(self, other):
        return dict(self) != other

    def __lt__(self, other):
        return dict(self) < other

    def __le__(self, other):
        return dict(self) <= other

    def __gt__(self, other):
        return dict(self) > other

    def __ge__(self, other):
        return dict(self) >= other

    def __cmp__(self, other):
        return util.cmp(dict(self), other)

    def __repr__(self):
        return repr(dict(self.items()))

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def setdefault(self, key, default=None):
        if key not in self.col:
            self.col[key] = self._create(key, default)
            return default
        else:
            return self[key]

    def keys(self):
        return self.col.keys()

    if util.py2k:

        def iteritems(self):
            return ((key, self._get(self.col[key])) for key in self.col)

        def itervalues(self):
            return (self._get(self.col[key]) for key in self.col)

        def iterkeys(self):
            return self.col.iterkeys()

        def values(self):
            return [self._get(member) for member in self.col.values()]

        def items(self):
            return [(k, self._get(self.col[k])) for k in self]

    else:

        def items(self):
            return ((key, self._get(self.col[key])) for key in self.col)

        def values(self):
            return (self._get(self.col[key]) for key in self.col)

    # A supplied default is returned as-is for a missing key: it is not an
    # underlying member, so it must not be passed through the getter.
    def pop(self, key, default=_NotProvided):
        if default is not _NotProvided and key not in self.col:
            return default
        return self._get(self.col.pop(key))

    def popitem(self):
        item = self.col.popitem()
        return (item[0], self._get(item[1]))

    # Mirrors dict.update(): at most one positional argument (a mapping,
    # i.e. anything with .keys(), or an iterable of key/value pairs),
    # followed by any keyword items.
    def update(self, *a, **kw):
        if len(a) > 1:
            raise TypeError(
                "update expected at most 1 arguments, got %i" % len(a)
            )
        elif len(a) == 1:
            seq_or_map = a[0]
            # discern dict from sequence - took the advice from
            # https://www.voidspace.org.uk/python/articles/duck_typing.shtml
            # still not perfect :(
            if hasattr(seq_or_map, "keys"):
                for key in seq_or_map.keys():
                    self[key] = seq_or_map[key]
            else:
                for pair in seq_or_map:
                    # guard only the unpacking, so a ValueError raised while
                    # storing the value is not misreported as a bad pair
                    try:
                        key, value = pair
                    except ValueError as err:
                        util.raise_(
                            ValueError(
                                "dictionary update sequence "
                                "requires 2-element tuples"
                            ),
                            replace_context=err,
                        )
                    self[key] = value

        # iterate items(): iterating ``kw`` itself would yield only keys
        for key, value in kw.items():
            self[key] = value

    def _bulk_replace(self, assoc_proxy, values):
        """Make this dict's contents equal to the mapping ``values``.

        Every key of ``values`` is assigned (updating existing members in
        place, creating the others); keys absent from ``values`` are then
        deleted.  ``None`` or an empty mapping empties the dict.
        """
        values = values or {}
        removals = set(self).difference(values)

        for key, member in values.items():
            self[key] = member

        for key in removals:
            del self[key]

    def copy(self):
        return dict(self.items())

    def __hash__(self):
        raise TypeError("%s objects are unhashable" % type(self).__name__)

    # borrow dict's docstrings for the undocumented dict-like methods above
    for func_name, func in list(locals().items()):
        if (
            callable(func)
            and func.__name__ == func_name
            and not func.__doc__
            and hasattr(dict, func_name)
        ):
            func.__doc__ = getattr(dict, func_name).__doc__
    del func_name, func
1409
1410
class _AssociationSet(_AssociationCollection):
    """Generic, converting, set-to-set proxy.

    Membership tests and removals compare proxied values (read through
    ``getter``) by scanning the underlying collection; new members are
    built with ``creator``.

    Methods that mirror a ``set`` method and have no docstring of their
    own receive ``set``'s docstring from the loop at the end of this class
    body, so they are documented with ``#`` comments instead.
    """

    def _create(self, value):
        """Build a new underlying member for ``value`` via ``creator``."""
        return self.creator(value)

    def _get(self, object_):
        """Return the proxied value of underlying member ``object_``."""
        return self.getter(object_)

    def __len__(self):
        return len(self.col)

    def __bool__(self):
        return bool(self.col)

    __nonzero__ = __bool__

    def __contains__(self, value):
        for member in self.col:
            # testlib.pragma exempt:__eq__
            if self._get(member) == value:
                return True
        return False

    def __iter__(self):
        """Iterate over proxied values.

        For the actual domain objects, iterate over .col instead or just use
        the underlying collection directly from its property on the parent.

        """
        for member in self.col:
            yield self._get(member)

    # the membership check avoids calling creator() for a value that is
    # already present
    def add(self, value):
        if value not in self:
            self.col.add(self._create(value))

    # for discard and remove, choosing a more expensive check strategy rather
    # than call self.creator()
    def discard(self, value):
        for member in self.col:
            if self._get(member) == value:
                self.col.discard(member)
                break

    def remove(self, value):
        for member in self.col:
            if self._get(member) == value:
                self.col.discard(member)
                return
        raise KeyError(value)

    def pop(self):
        if not self.col:
            raise KeyError("pop from an empty set")
        member = self.col.pop()
        return self._get(member)

    def update(self, other):
        for value in other:
            self.add(value)

    def _bulk_replace(self, assoc_proxy, values):
        """Make this set's proxied values equal to ``values``.

        Values already present are kept, missing ones are added and all
        others are removed.  ``None`` or an empty iterable empties the set.
        """
        # materialize once: ``values`` may be a one-shot iterator and it is
        # consulted twice below
        new_values = list(values or ())
        removals = set(self).difference(new_values)

        for member in new_values:
            # add() is a no-op for values that are already present
            self.add(member)

        for member in removals:
            self.remove(member)

    def __ior__(self, other):
        if not collections._set_binops_check_strict(self, other):
            return NotImplemented
        for value in other:
            self.add(value)
        return self

    def _set(self):
        return set(iter(self))

    def union(self, other):
        return set(self).union(other)

    __or__ = union

    def difference(self, other):
        return set(self).difference(other)

    __sub__ = difference

    # snapshot ``other``: it may be this proxy itself, and discarding while
    # iterating it would break the iteration
    def difference_update(self, other):
        for value in list(other):
            self.discard(value)

    def __isub__(self, other):
        if not collections._set_binops_check_strict(self, other):
            return NotImplemented
        # snapshot ``other``: ``s -= s`` would otherwise mutate the set
        # while iterating it
        for value in list(other):
            self.discard(value)
        return self

    def intersection(self, other):
        return set(self).intersection(other)

    __and__ = intersection

    def intersection_update(self, other):
        want, have = self.intersection(other), set(self)

        remove, add = have - want, want - have

        for value in remove:
            self.remove(value)
        for value in add:
            self.add(value)

    def __iand__(self, other):
        if not collections._set_binops_check_strict(self, other):
            return NotImplemented
        want, have = self.intersection(other), set(self)

        remove, add = have - want, want - have

        for value in remove:
            self.remove(value)
        for value in add:
            self.add(value)
        return self

    def symmetric_difference(self, other):
        return set(self).symmetric_difference(other)

    __xor__ = symmetric_difference

    def symmetric_difference_update(self, other):
        want, have = self.symmetric_difference(other), set(self)

        remove, add = have - want, want - have

        for value in remove:
            self.remove(value)
        for value in add:
            self.add(value)

    def __ixor__(self, other):
        if not collections._set_binops_check_strict(self, other):
            return NotImplemented
        want, have = self.symmetric_difference(other), set(self)

        remove, add = have - want, want - have

        for value in remove:
            self.remove(value)
        for value in add:
            self.add(value)
        return self

    def issubset(self, other):
        return set(self).issubset(other)

    def issuperset(self, other):
        return set(self).issuperset(other)

    def clear(self):
        self.col.clear()

    def copy(self):
        return set(self)

    def __eq__(self, other):
        return set(self) == other

    def __ne__(self, other):
        return set(self) != other

    def __lt__(self, other):
        return set(self) < other

    def __le__(self, other):
        return set(self) <= other

    def __gt__(self, other):
        return set(self) > other

    def __ge__(self, other):
        return set(self) >= other

    def __repr__(self):
        return repr(set(self))

    def __hash__(self):
        raise TypeError("%s objects are unhashable" % type(self).__name__)

    # borrow set's docstrings for the undocumented set-like methods above
    for func_name, func in list(locals().items()):
        if (
            callable(func)
            and func.__name__ == func_name
            and not func.__doc__
            and hasattr(set, func_name)
        ):
            func.__doc__ = getattr(set, func_name).__doc__
    del func_name, func