1# orm/writeonly.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""Write-only collection API.
9
10This is an alternate mapped attribute style that only supports single-item
11collection mutation operations. To read the collection, a select()
12object must be executed each time.
13
14.. versionadded:: 2.0
15
16
17"""
18
19from __future__ import annotations
20
21from typing import Any
22from typing import Collection
23from typing import Dict
24from typing import Generic
25from typing import Iterable
26from typing import Iterator
27from typing import List
28from typing import NoReturn
29from typing import Optional
30from typing import overload
31from typing import Tuple
32from typing import Type
33from typing import TYPE_CHECKING
34from typing import TypeVar
35from typing import Union
36
37from sqlalchemy.sql import bindparam
38from . import attributes
39from . import interfaces
40from . import relationships
41from . import strategies
42from .base import NEVER_SET
43from .base import object_mapper
44from .base import PassiveFlag
45from .base import RelationshipDirection
46from .. import exc
47from .. import inspect
48from .. import log
49from .. import util
50from ..sql import delete
51from ..sql import insert
52from ..sql import select
53from ..sql import update
54from ..sql.dml import Delete
55from ..sql.dml import Insert
56from ..sql.dml import Update
57from ..util.typing import Literal
58
59if TYPE_CHECKING:
60 from . import QueryableAttribute
61 from ._typing import _InstanceDict
62 from .attributes import AttributeEventToken
63 from .base import LoaderCallableStatus
64 from .collections import _AdaptedCollectionProtocol
65 from .collections import CollectionAdapter
66 from .mapper import Mapper
67 from .relationships import _RelationshipOrderByArg
68 from .state import InstanceState
69 from .util import AliasedClass
70 from ..event import _Dispatch
71 from ..sql.selectable import FromClause
72 from ..sql.selectable import Select
73
# Module-level type variable used to parameterize the generic classes in
# this module (WriteOnlyHistory, AbstractCollectionWriter, and subclasses).
_T = TypeVar("_T", bound=Any)
75
76
class WriteOnlyHistory(Generic[_T]):
    """Pending-change record for a write-only collection attribute.

    Append and remove events are recorded directly through
    :meth:`add_added` / :meth:`add_removed`; :meth:`as_history` turns the
    accumulated identity sets into an :class:`.attributes.History`.
    """

    unchanged_items: util.OrderedIdentitySet
    added_items: util.OrderedIdentitySet
    deleted_items: util.OrderedIdentitySet
    _reconcile_collection: bool

    def __init__(
        self,
        attr: WriteOnlyAttributeImpl,
        state: InstanceState[_T],
        passive: PassiveFlag,
        apply_to: Optional[WriteOnlyHistory[_T]] = None,
    ) -> None:
        if not apply_to:
            # Start from a blank slate.
            self.added_items = util.OrderedIdentitySet()
            self.deleted_items = util.OrderedIdentitySet()
            self.unchanged_items = util.OrderedIdentitySet()
            self._reconcile_collection = False
            return

        # Deriving from an existing history is refused whenever the caller
        # allows SQL, since this class never loads the stored collection.
        if passive & PassiveFlag.SQL_OK:
            raise exc.InvalidRequestError(
                f"Attribute {attr} can't load the existing state from the "
                "database for this operation; full iteration is not "
                "permitted. If this is a delete operation, configure "
                f"passive_deletes=True on the {attr} relationship in "
                "order to resolve this error."
            )

        # The new history shares (does not copy) the source's sets.
        self.added_items = apply_to.added_items
        self.deleted_items = apply_to.deleted_items
        self.unchanged_items = apply_to.unchanged_items
        self._reconcile_collection = apply_to._reconcile_collection

    @property
    def added_plus_unchanged(self) -> List[_T]:
        """Pending additions followed by unchanged items, as a list."""
        combined = self.added_items.union(self.unchanged_items)
        return list(combined)

    @property
    def all_items(self) -> List[_T]:
        """Added, unchanged and deleted items merged into one list."""
        live = self.added_items.union(self.unchanged_items)
        return list(live.union(self.deleted_items))

    def as_history(self) -> attributes.History:
        """Return the recorded changes as ``History(added, unchanged,
        deleted)``.

        When ``_reconcile_collection`` is set, additions already present in
        ``unchanged_items`` are not reported as added, only deletions that
        are present in ``unchanged_items`` are reported, and those deletions
        are excluded from the unchanged list.
        """
        if not self._reconcile_collection:
            return attributes.History(
                list(self.added_items),
                list(self.unchanged_items),
                list(self.deleted_items),
            )

        baseline = self.unchanged_items
        really_added = self.added_items.difference(baseline)
        really_deleted = self.deleted_items.intersection(baseline)
        still_present = baseline.difference(really_deleted)
        return attributes.History(
            list(really_added), list(still_present), list(really_deleted)
        )

    def indexed(self, index: Union[int, slice]) -> Union[List[_T], _T]:
        """Index or slice into the pending added items only."""
        snapshot = list(self.added_items)
        return snapshot[index]

    def add_added(self, value: _T) -> None:
        """Record ``value`` as a pending addition."""
        self.added_items.add(value)

    def add_removed(self, value: _T) -> None:
        """Record ``value`` as removed.

        A removal of something still pending addition simply cancels the
        addition; anything else is recorded as a deletion.
        """
        if value not in self.added_items:
            self.deleted_items.add(value)
            return
        self.added_items.remove(value)
148
149
class WriteOnlyAttributeImpl(
    attributes.HasCollectionAdapter, attributes.AttributeImpl
):
    """Attribute implementation backing a write-only relationship collection.

    Pending changes are recorded in a ``collection_history_cls`` instance
    stored under ``state.committed_state[self.key]``.  Reading the attribute
    with SQL permitted returns a ``query_class`` object (set to
    :class:`.WriteOnlyCollection` in ``__init__``) instead of a loaded list.

    """

    uses_objects: bool = True
    default_accepts_scalar_loader: bool = False
    supports_population: bool = False
    # When False, set() on an object that has an identity raises rather
    # than iterating the existing collection.
    _supports_dynamic_iteration: bool = False
    collection: bool = False
    dynamic: bool = True
    # Class-level default; replaced on the instance only when __init__ is
    # given a non-empty order_by.
    order_by: _RelationshipOrderByArg = ()
    # History class instantiated by _modified_event and
    # _get_collection_history.
    collection_history_cls: Type[WriteOnlyHistory[Any]] = WriteOnlyHistory

    query_class: Type[WriteOnlyCollection[Any]]

    def __init__(
        self,
        class_: Union[Type[Any], AliasedClass[Any]],
        key: str,
        dispatch: _Dispatch[QueryableAttribute[Any]],
        target_mapper: Mapper[_T],
        order_by: _RelationshipOrderByArg,
        **kw: Any,
    ) -> None:
        # NOTE(review): the positional None is forwarded to
        # AttributeImpl.__init__ as its third argument — presumably the
        # loader callable slot; confirm against AttributeImpl.
        super().__init__(class_, key, None, dispatch, **kw)
        self.target_mapper = target_mapper
        # get() instantiates this when SQL is permitted.
        self.query_class = WriteOnlyCollection
        if order_by:
            self.order_by = tuple(order_by)

    def get(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
    ) -> Union[util.OrderedIdentitySet, WriteOnlyCollection[Any]]:
        """Return the attribute value for ``state``.

        Without ``SQL_OK`` in ``passive``, returns the pending
        ``added_items`` set of the collection history (looked up with
        ``PASSIVE_NO_INITIALIZE``).  Otherwise returns a new
        ``self.query_class(self, state)`` object; no SQL is run here.
        """
        if not passive & PassiveFlag.SQL_OK:
            return self._get_collection_history(
                state, PassiveFlag.PASSIVE_NO_INITIALIZE
            ).added_items
        else:
            return self.query_class(self, state)

    @overload
    def get_collection(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        user_data: Literal[None] = ...,
        passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
    ) -> CollectionAdapter: ...

    @overload
    def get_collection(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        user_data: _AdaptedCollectionProtocol = ...,
        passive: PassiveFlag = ...,
    ) -> CollectionAdapter: ...

    @overload
    def get_collection(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        user_data: Optional[_AdaptedCollectionProtocol] = ...,
        passive: PassiveFlag = ...,
    ) -> Union[
        Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
    ]: ...

    def get_collection(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        user_data: Optional[_AdaptedCollectionProtocol] = None,
        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
    ) -> Union[
        Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
    ]:
        """Wrap the known collection contents in a
        :class:`.DynamicCollectionAdapter`.

        Without ``SQL_OK``, only the pending added items are exposed.  With
        it, added plus unchanged items are exposed.  With the base
        WriteOnlyHistory, that path raises InvalidRequestError for objects
        that have an identity when ``passive`` also includes ``INIT_OK``,
        because ``_get_collection_history`` then builds a history with
        ``apply_to``.  ``user_data`` is not used.
        """
        data: Collection[Any]
        if not passive & PassiveFlag.SQL_OK:
            data = self._get_collection_history(state, passive).added_items
        else:
            history = self._get_collection_history(state, passive)
            data = history.added_plus_unchanged
        return DynamicCollectionAdapter(data)  # type: ignore[return-value]

    @util.memoized_property
    def _append_token(  # type:ignore[override]
        self,
    ) -> attributes.AttributeEventToken:
        """Default initiator token passed to append listeners."""
        return attributes.AttributeEventToken(self, attributes.OP_APPEND)

    @util.memoized_property
    def _remove_token(  # type:ignore[override]
        self,
    ) -> attributes.AttributeEventToken:
        """Default initiator token passed to remove listeners."""
        return attributes.AttributeEventToken(self, attributes.OP_REMOVE)

    def fire_append_event(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        value: Any,
        initiator: Optional[AttributeEventToken],
        collection_history: Optional[WriteOnlyHistory[Any]] = None,
    ) -> None:
        """Record ``value`` as added, then run append listeners.

        Order matters: the history is updated first with the ORIGINAL
        ``value``.  Each listener's return value then replaces ``value``,
        and the final value is what gets its parent flag set.
        """
        if collection_history is None:
            collection_history = self._modified_event(state, dict_)

        collection_history.add_added(value)

        for fn in self.dispatch.append:
            value = fn(state, value, initiator or self._append_token)

        if self.trackparent and value is not None:
            self.sethasparent(attributes.instance_state(value), state, True)

    def fire_remove_event(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        value: Any,
        initiator: Optional[AttributeEventToken],
        collection_history: Optional[WriteOnlyHistory[Any]] = None,
    ) -> None:
        """Record ``value`` as removed, clear its parent flag, then run
        remove listeners.  Listener return values are ignored."""
        if collection_history is None:
            collection_history = self._modified_event(state, dict_)

        collection_history.add_removed(value)

        if self.trackparent and value is not None:
            self.sethasparent(attributes.instance_state(value), state, False)

        for fn in self.dispatch.remove:
            fn(state, value, initiator or self._remove_token)

    def _modified_event(
        self, state: InstanceState[Any], dict_: _InstanceDict
    ) -> WriteOnlyHistory[Any]:
        """Ensure a history object is stored for this key, flag the state
        as modified, and return that stored history."""
        # Lazily create and store the history the first time this
        # attribute is modified on this state.
        if self.key not in state.committed_state:
            state.committed_state[self.key] = self.collection_history_cls(
                self, state, PassiveFlag.PASSIVE_NO_FETCH
            )

        state._modified_event(dict_, self, NEVER_SET)

        # this is a hack to allow the entities.ComparableEntity fixture
        # to work
        dict_[self.key] = True
        return state.committed_state[self.key]  # type: ignore[no-any-return]

    def set(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        value: Any,
        initiator: Optional[AttributeEventToken] = None,
        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
        check_old: Any = None,
        pop: bool = False,
        _adapt: bool = True,
    ) -> None:
        """Replace the collection with the iterable ``value``.

        Differences are computed by identity against the old collection.
        Append events are fired for new members in the order they appear
        in ``value``; remove events are fired afterwards for members that
        are no longer present.  ``check_old`` and ``_adapt`` are accepted
        for signature compatibility and are unused here.

        :raises InvalidRequestError: if ``state`` has an identity and this
         implementation does not support dynamic iteration.
        """
        # Ignore events whose initiator belongs to this same attribute.
        # NOTE(review): presumably guards against backref re-entry.
        if initiator and initiator.parent_token is self.parent_token:
            return

        if pop and value is None:
            return

        iterable = value
        # Materialize once; new_values is iterated several times below.
        new_values = list(iterable)
        if state.has_identity:
            if not self._supports_dynamic_iteration:
                raise exc.InvalidRequestError(
                    f'Collection "{self}" does not support implicit '
                    "iteration; collection replacement operations "
                    "can't be used"
                )
            old_collection = util.IdentitySet(
                self.get(state, dict_, passive=passive)
            )

        collection_history = self._modified_event(state, dict_)
        if not state.has_identity:
            # Without an identity, the only "old" members are those
            # still pending addition.
            old_collection = collection_history.added_items
        else:
            old_collection = old_collection.union(
                collection_history.added_items
            )

        constants = old_collection.intersection(new_values)
        additions = util.IdentitySet(new_values).difference(constants)
        removals = old_collection.difference(constants)

        for member in new_values:
            if member in additions:
                self.fire_append_event(
                    state,
                    dict_,
                    member,
                    None,
                    collection_history=collection_history,
                )

        for member in removals:
            self.fire_remove_event(
                state,
                dict_,
                member,
                None,
                collection_history=collection_history,
            )

    def delete(self, *args: Any, **kwargs: Any) -> NoReturn:
        """Not supported for write-only attributes; always raises
        NotImplementedError."""
        raise NotImplementedError()

    def set_committed_value(
        self, state: InstanceState[Any], dict_: _InstanceDict, value: Any
    ) -> NoReturn:
        """Not supported: write-only attributes cannot be populated;
        always raises NotImplementedError."""
        raise NotImplementedError(
            "Dynamic attributes don't support collection population."
        )

    def get_history(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        passive: PassiveFlag = PassiveFlag.PASSIVE_NO_FETCH,
    ) -> attributes.History:
        """Return the recorded changes as an :class:`.attributes.History`."""
        c = self._get_collection_history(state, passive)
        return c.as_history()

    def get_all_pending(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        passive: PassiveFlag = PassiveFlag.PASSIVE_NO_INITIALIZE,
    ) -> List[Tuple[InstanceState[Any], Any]]:
        """Return ``(instance_state, object)`` pairs for every added,
        unchanged and deleted item in the history."""
        c = self._get_collection_history(state, passive)
        return [(attributes.instance_state(x), x) for x in c.all_items]

    def _get_collection_history(
        self, state: InstanceState[Any], passive: PassiveFlag
    ) -> WriteOnlyHistory[Any]:
        """Return the history for ``state``.

        If none is stored yet, a fresh empty history is created but NOT
        stored in ``committed_state`` (only ``_modified_event`` stores one).
        For objects with an identity and ``INIT_OK`` in ``passive``, a new
        history is built from the found one via ``apply_to``.
        """
        c: WriteOnlyHistory[Any]
        if self.key in state.committed_state:
            c = state.committed_state[self.key]
        else:
            c = self.collection_history_cls(
                self, state, PassiveFlag.PASSIVE_NO_FETCH
            )

        if state.has_identity and (passive & PassiveFlag.INIT_OK):
            return self.collection_history_cls(
                self, state, passive, apply_to=c
            )
        else:
            return c

    def append(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        value: Any,
        initiator: Optional[AttributeEventToken],
        passive: PassiveFlag = PassiveFlag.PASSIVE_NO_FETCH,
    ) -> None:
        """Append ``value`` unless this impl is itself the initiator."""
        if initiator is not self:
            self.fire_append_event(state, dict_, value, initiator)

    def remove(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        value: Any,
        initiator: Optional[AttributeEventToken],
        passive: PassiveFlag = PassiveFlag.PASSIVE_NO_FETCH,
    ) -> None:
        """Remove ``value`` unless this impl is itself the initiator."""
        if initiator is not self:
            self.fire_remove_event(state, dict_, value, initiator)

    def pop(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        value: Any,
        initiator: Optional[AttributeEventToken],
        passive: PassiveFlag = PassiveFlag.PASSIVE_NO_FETCH,
    ) -> None:
        """Same as :meth:`remove`."""
        self.remove(state, dict_, value, initiator, passive=passive)
441
442
@log.class_logger
@relationships.RelationshipProperty.strategy_for(lazy="write_only")
class WriteOnlyLoader(strategies.AbstractRelationshipLoader, log.Identified):
    """Loader strategy registered for ``lazy="write_only"``.

    Installs ``impl_class`` as the class-level attribute implementation for
    the relationship.
    """

    impl_class = WriteOnlyAttributeImpl

    def init_class_attribute(self, mapper: Mapper[Any]) -> None:
        self.is_class_level = True

        prop = self.parent_property
        collection_directions = (
            interfaces.ONETOMANY,
            interfaces.MANYTOMANY,
        )
        # Only list-style one-to-many / many-to-many collections qualify.
        if not (self.uselist and prop.direction in collection_directions):
            raise exc.InvalidRequestError(
                "On relationship %s, 'dynamic' loaders cannot be used with "
                "many-to-one/one-to-one relationships and/or "
                "uselist=False." % prop
            )

        strategies._register_attribute(  # type: ignore[no-untyped-call]
            prop,
            mapper,
            useobject=True,
            impl_class=self.impl_class,
            target_mapper=prop.mapper,
            order_by=prop.order_by,
            query_class=prop.query_class,
        )
469
470
class DynamicCollectionAdapter:
    """Minimal stand-in for CollectionAdapter wrapping a plain collection.

    Supplies only iteration, ``len()``, truthiness and a no-op
    ``_reset_empty()``, so write-only attributes can return something
    shaped like a collection adapter for internal API consistency.
    """

    data: Collection[Any]

    def __init__(self, data: Collection[Any]):
        self.data = data

    def __len__(self) -> int:
        size = len(self.data)
        return size

    def __bool__(self) -> bool:
        # Deliberately truthy even when the wrapped collection is empty.
        return True

    def __iter__(self) -> Iterator[Any]:
        source = self.data
        return iter(source)

    def _reset_empty(self) -> None:
        # Nothing to reset; present only for interface compatibility.
        return None
490
491
class AbstractCollectionWriter(Generic[_T]):
    """Virtual collection which includes append/remove methods that synchronize
    into the attribute event system.

    At construction time the parent instance's relationship criteria are
    computed once and kept on the object: the WHERE criteria
    (``_where_criteria``), any FROM objects needed for a ``secondary``
    table (``_from_obj``), and the configured ORDER BY
    (``_order_by_clauses``).

    """

    if not TYPE_CHECKING:
        __slots__ = ()

    instance: _T
    _from_obj: Tuple[FromClause, ...]

    def __init__(self, attr: WriteOnlyAttributeImpl, state: InstanceState[_T]):
        instance = state.obj()
        if TYPE_CHECKING:
            assert instance
        self.instance = instance
        self.attr = attr

        mapper = object_mapper(instance)
        prop = mapper._props[self.attr.key]

        if prop.secondary is not None:
            # this is a hack right now. The Query only knows how to
            # make subsequent joins() without a given left-hand side
            # from self._from_obj[0]. We need to ensure prop.secondary
            # is in the FROM. So we purposely put the mapper selectable
            # in _from_obj[0] to ensure a user-defined join() later on
            # doesn't fail, and secondary is then in _from_obj[1].

            # note also, we are using the official ORM-annotated selectable
            # from __clause_element__(), see #7868
            self._from_obj = (prop.mapper.__clause_element__(), prop.secondary)
        else:
            self._from_obj = ()

        self._where_criteria = (
            prop._with_parent(instance, alias_secondary=False),
        )

        # An empty/unset order_by normalizes to an empty tuple.
        self._order_by_clauses = self.attr.order_by or ()

    def _add_all_impl(self, iterator: Iterable[_T]) -> None:
        """Append each item to the parent's collection via the attribute
        implementation (no initiator)."""
        # The parent's state and dict are the same for every item, so
        # look them up once instead of once per item.
        state = attributes.instance_state(self.instance)
        dict_ = attributes.instance_dict(self.instance)
        for item in iterator:
            self.attr.append(state, dict_, item, None)

    def _remove_impl(self, item: _T) -> None:
        """Remove ``item`` from the parent's collection via the attribute
        implementation (no initiator)."""
        self.attr.remove(
            attributes.instance_state(self.instance),
            attributes.instance_dict(self.instance),
            item,
            None,
        )
553
554
class WriteOnlyCollection(AbstractCollectionWriter[_T]):
    """Write-only collection which can synchronize changes into the
    attribute event system.

    The :class:`.WriteOnlyCollection` is used in a mapping by
    using the ``"write_only"`` lazy loading strategy with
    :func:`_orm.relationship`. For background on this configuration,
    see :ref:`write_only_relationship`.

    .. versionadded:: 2.0

    .. seealso::

        :ref:`write_only_relationship`

    """

    __slots__ = (
        "instance",
        "attr",
        "_where_criteria",
        "_from_obj",
        "_order_by_clauses",
    )

    def __iter__(self) -> NoReturn:
        # Reading the collection always requires executing select(), so
        # in-place iteration is rejected outright.
        raise TypeError(
            "WriteOnly collections don't support iteration in-place; "
            "to query for collection items, use the select() method to "
            "produce a SQL statement and execute it with session.scalars()."
        )

    def select(self) -> Select[_T]:
        """Produce a :class:`_sql.Select` construct that represents the
        rows within this instance-local :class:`_orm.WriteOnlyCollection`.

        The statement selects the target entity filtered by the parent
        instance's relationship criteria, adds the ``secondary`` FROM
        objects when present, and applies the relationship's ``order_by``.

        """
        stmt = select(self.attr.target_mapper).where(*self._where_criteria)
        if self._from_obj:
            stmt = stmt.select_from(*self._from_obj)
        if self._order_by_clauses:
            stmt = stmt.order_by(*self._order_by_clauses)
        return stmt

    def insert(self) -> Insert:
        """For one-to-many collections, produce a :class:`_dml.Insert` which
        will insert new rows in terms of this instance-local
        :class:`_orm.WriteOnlyCollection`.

        This construct is only supported for a :class:`_orm.Relationship`
        that does **not** include the :paramref:`_orm.relationship.secondary`
        parameter. For relationships that refer to a many-to-many table,
        use ordinary bulk insert techniques to produce new objects, then
        use :meth:`_orm.WriteOnlyCollection.add_all` to associate them
        with the collection.

        :raises sqlalchemy.exc.InvalidRequestError: if the relationship is
         not one-to-many.

        """

        state = inspect(self.instance)
        mapper = state.mapper
        prop = mapper._props[self.attr.key]

        if prop.direction is not RelationshipDirection.ONETOMANY:
            raise exc.InvalidRequestError(
                "Write only bulk INSERT only supported for one-to-many "
                "collections; for many-to-many, use a separate bulk "
                "INSERT along with add_all()."
            )

        dict_: Dict[str, Any] = {}

        # For each synchronize pair, the left column's value is read from
        # the parent instance and bound to the right column (by key) of the
        # target table.  bindparam(callable_=...) defers reading the value
        # until the statement is executed.
        for local_col, remote_col in prop.synchronize_pairs:
            fn = prop._get_attr_w_warn_on_none(
                mapper,
                state,
                state.dict,
                local_col,
            )

            dict_[remote_col.key] = bindparam(None, callable_=fn)

        return insert(self.attr.target_mapper).values(**dict_)

    def update(self) -> Update:
        """Produce a :class:`_dml.Update` which will refer to rows in terms
        of this instance-local :class:`_orm.WriteOnlyCollection`.

        """
        return update(self.attr.target_mapper).where(*self._where_criteria)

    def delete(self) -> Delete:
        """Produce a :class:`_dml.Delete` which will refer to rows in terms
        of this instance-local :class:`_orm.WriteOnlyCollection`.

        """
        return delete(self.attr.target_mapper).where(*self._where_criteria)

    def add_all(self, iterator: Iterable[_T]) -> None:
        """Add an iterable of items to this :class:`_orm.WriteOnlyCollection`.

        The given items will be persisted to the database in terms of
        the parent instance's collection on the next flush.

        """
        self._add_all_impl(iterator)

    def add(self, item: _T) -> None:
        """Add an item to this :class:`_orm.WriteOnlyCollection`.

        The given item will be persisted to the database in terms of
        the parent instance's collection on the next flush.

        """
        self._add_all_impl([item])

    def remove(self, item: _T) -> None:
        """Remove an item from this :class:`_orm.WriteOnlyCollection`.

        The given item will be removed from the parent instance's collection on
        the next flush.

        """
        self._remove_impl(item)