1# orm/writeonly.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""Write-only collection API.
9
10This is an alternate mapped attribute style that only supports single-item
11collection mutation operations. To read the collection, a select()
12object must be executed each time.
13
14.. versionadded:: 2.0
15
16
17"""
18
19from __future__ import annotations
20
21from typing import Any
22from typing import Collection
23from typing import Dict
24from typing import Generic
25from typing import Iterable
26from typing import Iterator
27from typing import List
28from typing import NoReturn
29from typing import Optional
30from typing import overload
31from typing import Tuple
32from typing import Type
33from typing import TYPE_CHECKING
34from typing import TypeVar
35from typing import Union
36
37from sqlalchemy.sql import bindparam
38from . import attributes
39from . import interfaces
40from . import relationships
41from . import strategies
42from .base import NEVER_SET
43from .base import object_mapper
44from .base import PassiveFlag
45from .base import RelationshipDirection
46from .. import exc
47from .. import inspect
48from .. import log
49from .. import util
50from ..sql import delete
51from ..sql import insert
52from ..sql import select
53from ..sql import update
54from ..sql.dml import Delete
55from ..sql.dml import Insert
56from ..sql.dml import Update
57from ..util.typing import Literal
58
59if TYPE_CHECKING:
60 from . import QueryableAttribute
61 from ._typing import _InstanceDict
62 from .attributes import AttributeEventToken
63 from .base import LoaderCallableStatus
64 from .collections import _AdaptedCollectionProtocol
65 from .collections import CollectionAdapter
66 from .mapper import Mapper
67 from .relationships import _RelationshipOrderByArg
68 from .state import InstanceState
69 from .util import AliasedClass
70 from ..event import _Dispatch
71 from ..sql.selectable import FromClause
72 from ..sql.selectable import Select
73
# Module-wide type variable.  Note that it is used both for the mapped
# parent instance type (e.g. ``InstanceState[_T]``,
# ``AbstractCollectionWriter.instance``) and for collection item types
# (e.g. ``WriteOnlyHistory.added_plus_unchanged``, ``add_all()`` items).
_T = TypeVar("_T", bound=Any)
75
76
class WriteOnlyHistory(Generic[_T]):
    """Change record for a write-only collection.

    Overrides AttributeHistory to receive append/remove events directly:
    events are written straight into three identity sets
    (``added_items``, ``deleted_items`` and ``unchanged_items``), which
    :meth:`as_history` later converts into an :class:`.attributes.History`.
    """

    unchanged_items: util.OrderedIdentitySet
    added_items: util.OrderedIdentitySet
    deleted_items: util.OrderedIdentitySet
    _reconcile_collection: bool

    def __init__(
        self,
        attr: WriteOnlyAttributeImpl,
        state: InstanceState[_T],
        passive: PassiveFlag,
        apply_to: Optional[WriteOnlyHistory[_T]] = None,
    ) -> None:
        """Create an empty history, or one sharing the sets of ``apply_to``.

        :param attr: the owning attribute impl; only used in the error text.
        :param state: the instance state (not used by this class).
        :param passive: loader flags; when wrapping ``apply_to`` and
            ``passive`` includes ``SQL_OK``, :class:`.InvalidRequestError`
            is raised because the existing collection cannot be loaded.
        :param apply_to: an existing history whose sets are shared (the
            very same set objects, not copies).
        """
        if not apply_to:
            # brand-new, empty history
            self.unchanged_items = util.OrderedIdentitySet()
            self.added_items = util.OrderedIdentitySet()
            self.deleted_items = util.OrderedIdentitySet()
            self._reconcile_collection = False
            return

        if passive & PassiveFlag.SQL_OK:
            raise exc.InvalidRequestError(
                f"Attribute {attr} can't load the existing state from the "
                "database for this operation; full iteration is not "
                "permitted. If this is a delete operation, configure "
                f"passive_deletes=True on the {attr} relationship in "
                "order to resolve this error."
            )

        # share the source's set objects so mutations are visible in both
        source = apply_to
        self.unchanged_items = source.unchanged_items
        self.added_items = source.added_items
        self.deleted_items = source.deleted_items
        self._reconcile_collection = source._reconcile_collection

    @property
    def added_plus_unchanged(self) -> List[_T]:
        """Union of pending additions and unchanged items, as a list."""
        merged = self.added_items.union(self.unchanged_items)
        return list(merged)

    @property
    def all_items(self) -> List[_T]:
        """Union of added, unchanged and deleted items, as a list."""
        merged = self.added_items.union(self.unchanged_items)
        return list(merged.union(self.deleted_items))

    def as_history(self) -> attributes.History:
        """Convert this record into an :class:`.attributes.History`.

        When ``_reconcile_collection`` is set, additions already present in
        ``unchanged_items`` are dropped, only deletions of unchanged items
        are kept, and those deletions are taken out of the unchanged set.
        """
        if not self._reconcile_collection:
            return attributes.History(
                list(self.added_items),
                list(self.unchanged_items),
                list(self.deleted_items),
            )

        present = self.unchanged_items
        deleted = self.deleted_items.intersection(present)
        added = self.added_items.difference(present)
        unchanged = present.difference(deleted)
        return attributes.History(list(added), list(unchanged), list(deleted))

    def indexed(self, index: Union[int, slice]) -> Union[List[_T], _T]:
        """Index or slice into the pending additions."""
        pending = list(self.added_items)
        return pending[index]

    def add_added(self, value: _T) -> None:
        """Record ``value`` as appended."""
        self.added_items.add(value)

    def add_removed(self, value: _T) -> None:
        """Record ``value`` as removed.

        Removing an item that is only pending cancels the pending addition
        instead of recording a deletion.
        """
        if value not in self.added_items:
            self.deleted_items.add(value)
        else:
            self.added_items.remove(value)
148
149
class WriteOnlyAttributeImpl(
    attributes.HasCollectionAdapter, attributes.AttributeImpl
):
    """Attribute implementation for ``lazy="write_only"`` relationships.

    This impl never loads the collection itself.  Pending appends and
    removes are tracked in a :attr:`collection_history_cls` object
    (:class:`.WriteOnlyHistory` by default) stored in
    ``state.committed_state`` under this attribute's key.  Reading the
    attribute with SQL permitted returns a new :attr:`query_class` object
    (:class:`.WriteOnlyCollection`) bound to the instance.

    """

    uses_objects: bool = True
    default_accepts_scalar_loader: bool = False
    supports_population: bool = False
    # when False, set() refuses to replace the collection of an instance
    # that has an identity (see set())
    _supports_dynamic_iteration: bool = False
    collection: bool = False
    dynamic: bool = True
    # replaced per instance by __init__ when an order_by is given
    order_by: _RelationshipOrderByArg = ()
    # history class instantiated by _modified_event() and
    # _get_collection_history()
    collection_history_cls: Type[WriteOnlyHistory[Any]] = WriteOnlyHistory

    # assigned WriteOnlyCollection by __init__; instantiated by get()
    query_class: Type[WriteOnlyCollection[Any]]

    def __init__(
        self,
        class_: Union[Type[Any], AliasedClass[Any]],
        key: str,
        dispatch: _Dispatch[QueryableAttribute[Any]],
        target_mapper: Mapper[_T],
        order_by: _RelationshipOrderByArg,
        **kw: Any,
    ) -> None:
        """Set up the impl for attribute ``key`` of ``class_``.

        :param target_mapper: mapper of the related (collection item)
            class; :class:`.WriteOnlyCollection` builds its statements
            against it.
        :param order_by: relationship ordering; stored as a tuple when
            non-empty, otherwise the class-level ``()`` default applies.

        Remaining keyword arguments are passed on to ``super().__init__()``.
        """
        # NOTE(review): the positional None presumably fills the base
        # class's loader-callable slot (no lazy loader is used) -- confirm
        # against AttributeImpl.__init__
        super().__init__(class_, key, None, dispatch, **kw)
        self.target_mapper = target_mapper
        self.query_class = WriteOnlyCollection
        if order_by:
            self.order_by = tuple(order_by)

    def get(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
    ) -> Union[util.OrderedIdentitySet, WriteOnlyCollection[Any]]:
        """Return the attribute's value for ``state``.

        When ``passive`` does not include ``SQL_OK``, this returns the
        identity set of pending (added) items from the collection history.
        Otherwise it returns a new :attr:`query_class` instance bound to
        this attribute and ``state``.

        """
        if not passive & PassiveFlag.SQL_OK:
            return self._get_collection_history(
                state, PassiveFlag.PASSIVE_NO_INITIALIZE
            ).added_items
        else:
            return self.query_class(self, state)

    @overload
    def get_collection(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        user_data: Literal[None] = ...,
        passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
    ) -> CollectionAdapter: ...

    @overload
    def get_collection(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        user_data: _AdaptedCollectionProtocol = ...,
        passive: PassiveFlag = ...,
    ) -> CollectionAdapter: ...

    @overload
    def get_collection(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        user_data: Optional[_AdaptedCollectionProtocol] = ...,
        passive: PassiveFlag = ...,
    ) -> Union[
        Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
    ]: ...

    def get_collection(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        user_data: Optional[_AdaptedCollectionProtocol] = None,
        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
    ) -> Union[
        Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
    ]:
        """Return a :class:`.DynamicCollectionAdapter` over tracked items.

        Without ``SQL_OK`` in ``passive`` only the pending added items are
        included; otherwise added plus unchanged items.  ``user_data`` is
        not used.  With ``SQL_OK`` on an instance that has an identity,
        :meth:`_get_collection_history` may raise
        :class:`.InvalidRequestError` (see that method).

        """
        data: Collection[Any]
        if not passive & PassiveFlag.SQL_OK:
            data = self._get_collection_history(state, passive).added_items
        else:
            history = self._get_collection_history(state, passive)
            data = history.added_plus_unchanged
        # DynamicCollectionAdapter only mimics CollectionAdapter's interface
        return DynamicCollectionAdapter(data)  # type: ignore[return-value]

    # memoized default initiator token for append events
    @util.memoized_property
    def _append_token(self) -> attributes.AttributeEventToken:
        return attributes.AttributeEventToken(self, attributes.OP_APPEND)

    # memoized default initiator token for remove events
    @util.memoized_property
    def _remove_token(self) -> attributes.AttributeEventToken:
        return attributes.AttributeEventToken(self, attributes.OP_REMOVE)

    def fire_append_event(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        value: Any,
        initiator: Optional[AttributeEventToken],
        collection_history: Optional[WriteOnlyHistory[Any]] = None,
    ) -> None:
        """Record ``value`` as added, then dispatch append events.

        The history records the original ``value``.  Each ``append``
        listener's return value then replaces ``value`` for the next
        listener, and the final value is the one whose parent flag is set
        when ``trackparent`` is enabled.

        """
        if collection_history is None:
            collection_history = self._modified_event(state, dict_)

        collection_history.add_added(value)

        for fn in self.dispatch.append:
            value = fn(state, value, initiator or self._append_token)

        if self.trackparent and value is not None:
            self.sethasparent(attributes.instance_state(value), state, True)

    def fire_remove_event(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        value: Any,
        initiator: Optional[AttributeEventToken],
        collection_history: Optional[WriteOnlyHistory[Any]] = None,
    ) -> None:
        """Record ``value`` as removed, then dispatch remove events.

        Unlike :meth:`fire_append_event`, the parent flag is cleared
        *before* listeners run, and listener return values are ignored.

        """
        if collection_history is None:
            collection_history = self._modified_event(state, dict_)

        collection_history.add_removed(value)

        if self.trackparent and value is not None:
            self.sethasparent(attributes.instance_state(value), state, False)

        for fn in self.dispatch.remove:
            fn(state, value, initiator or self._remove_token)

    def _modified_event(
        self, state: InstanceState[Any], dict_: _InstanceDict
    ) -> WriteOnlyHistory[Any]:
        """Return the history stored for ``state``, creating and storing it
        in ``state.committed_state`` if absent, and mark the state as
        modified.

        """
        if self.key not in state.committed_state:
            state.committed_state[self.key] = self.collection_history_cls(
                self, state, PassiveFlag.PASSIVE_NO_FETCH
            )

        state._modified_event(dict_, self, NEVER_SET)

        # this is a hack to allow the entities.ComparableEntity fixture
        # to work
        dict_[self.key] = True
        return state.committed_state[self.key]  # type: ignore[no-any-return]

    def set(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        value: Any,
        initiator: Optional[AttributeEventToken] = None,
        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
        check_old: Any = None,
        pop: bool = False,
        _adapt: bool = True,
    ) -> None:
        """Replace the collection's contents with the iterable ``value``.

        The new items are diffed by identity against the current items.
        For an instance without an identity, the current items are just
        the pending additions.  Append events fire for new members in the
        order given, and remove events fire for members no longer present.

        For an instance that has an identity this raises
        :class:`.InvalidRequestError` unless ``_supports_dynamic_iteration``
        is set, because the existing contents would have to be read via
        :meth:`get`.

        Returns without doing anything when ``initiator`` shares this
        attribute's ``parent_token``, or when ``pop`` is set and ``value``
        is None.  ``check_old`` and ``_adapt`` are not used.

        """
        if initiator and initiator.parent_token is self.parent_token:
            return

        if pop and value is None:
            return

        iterable = value
        new_values = list(iterable)
        if state.has_identity:
            if not self._supports_dynamic_iteration:
                raise exc.InvalidRequestError(
                    f'Collection "{self}" does not support implicit '
                    "iteration; collection replacement operations "
                    "can't be used"
                )
            old_collection = util.IdentitySet(
                self.get(state, dict_, passive=passive)
            )

        collection_history = self._modified_event(state, dict_)
        if not state.has_identity:
            old_collection = collection_history.added_items
        else:
            old_collection = old_collection.union(
                collection_history.added_items
            )

        constants = old_collection.intersection(new_values)
        additions = util.IdentitySet(new_values).difference(constants)
        removals = old_collection.difference(constants)

        # iterate new_values (not additions) to keep the caller's order
        for member in new_values:
            if member in additions:
                self.fire_append_event(
                    state,
                    dict_,
                    member,
                    None,
                    collection_history=collection_history,
                )

        for member in removals:
            self.fire_remove_event(
                state,
                dict_,
                member,
                None,
                collection_history=collection_history,
            )

    def delete(self, *args: Any, **kwargs: Any) -> NoReturn:
        """Not supported; always raises :class:`NotImplementedError`."""
        raise NotImplementedError()

    def set_committed_value(
        self, state: InstanceState[Any], dict_: _InstanceDict, value: Any
    ) -> NoReturn:
        """Not supported; always raises :class:`NotImplementedError`."""
        raise NotImplementedError(
            "Dynamic attributes don't support collection population."
        )

    def get_history(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        passive: PassiveFlag = PassiveFlag.PASSIVE_NO_FETCH,
    ) -> attributes.History:
        """Return the pending changes for ``state`` as an
        :class:`.attributes.History` (via ``as_history()``).

        """
        c = self._get_collection_history(state, passive)
        return c.as_history()

    def get_all_pending(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        passive: PassiveFlag = PassiveFlag.PASSIVE_NO_INITIALIZE,
    ) -> List[Tuple[InstanceState[Any], Any]]:
        """Return ``(instance_state, object)`` pairs for every added,
        unchanged and deleted item tracked in the collection history.

        """
        c = self._get_collection_history(state, passive)
        return [(attributes.instance_state(x), x) for x in c.all_items]

    def _get_collection_history(
        self, state: InstanceState[Any], passive: PassiveFlag
    ) -> WriteOnlyHistory[Any]:
        """Return the collection history for ``state``.

        Uses the history stored in ``state.committed_state`` when present;
        otherwise a fresh empty one that is *not* stored.  For an instance
        that has an identity, when ``passive`` includes ``INIT_OK`` the
        history is wrapped in a new :attr:`collection_history_cls` via
        ``apply_to``.  The default :class:`.WriteOnlyHistory` constructor
        raises :class:`.InvalidRequestError` there if ``passive`` also
        includes ``SQL_OK``.

        """
        c: WriteOnlyHistory[Any]
        if self.key in state.committed_state:
            c = state.committed_state[self.key]
        else:
            c = self.collection_history_cls(
                self, state, PassiveFlag.PASSIVE_NO_FETCH
            )

        if state.has_identity and (passive & PassiveFlag.INIT_OK):
            return self.collection_history_cls(
                self, state, passive, apply_to=c
            )
        else:
            return c

    def append(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        value: Any,
        initiator: Optional[AttributeEventToken],
        passive: PassiveFlag = PassiveFlag.PASSIVE_NO_FETCH,
    ) -> None:
        """Fire an append event for ``value`` unless ``initiator`` is this
        impl itself.  ``passive`` is not used.

        """
        if initiator is not self:
            self.fire_append_event(state, dict_, value, initiator)

    def remove(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        value: Any,
        initiator: Optional[AttributeEventToken],
        passive: PassiveFlag = PassiveFlag.PASSIVE_NO_FETCH,
    ) -> None:
        """Fire a remove event for ``value`` unless ``initiator`` is this
        impl itself.  ``passive`` is not used.

        """
        if initiator is not self:
            self.fire_remove_event(state, dict_, value, initiator)

    def pop(
        self,
        state: InstanceState[Any],
        dict_: _InstanceDict,
        value: Any,
        initiator: Optional[AttributeEventToken],
        passive: PassiveFlag = PassiveFlag.PASSIVE_NO_FETCH,
    ) -> None:
        """Same as :meth:`remove`."""
        self.remove(state, dict_, value, initiator, passive=passive)
437
438
@log.class_logger
@relationships.RelationshipProperty.strategy_for(lazy="write_only")
class WriteOnlyLoader(strategies.AbstractRelationshipLoader, log.Identified):
    """Relationship loader strategy registered for ``lazy="write_only"``.

    It performs no loading; at class-setup time it registers the mapped
    attribute using :attr:`impl_class` (:class:`.WriteOnlyAttributeImpl`).
    """

    impl_class = WriteOnlyAttributeImpl

    def init_class_attribute(self, mapper: Mapper[Any]) -> None:
        """Register the write-only attribute on ``mapper``'s class.

        :raises InvalidRequestError: unless the relationship is a list-based
            one-to-many or many-to-many.
        """
        self.is_class_level = True

        prop = self.parent_property
        is_list_collection = self.uselist and prop.direction in (
            interfaces.ONETOMANY,
            interfaces.MANYTOMANY,
        )
        if not is_list_collection:
            raise exc.InvalidRequestError(
                "On relationship %s, 'dynamic' loaders cannot be used with "
                "many-to-one/one-to-one relationships and/or "
                "uselist=False." % prop
            )

        strategies._register_attribute(  # type: ignore[no-untyped-call]
            prop,
            mapper,
            useobject=True,
            impl_class=self.impl_class,
            target_mapper=prop.mapper,
            order_by=prop.order_by,
            query_class=prop.query_class,
        )
465
466
class DynamicCollectionAdapter:
    """Minimal stand-in for :class:`.CollectionAdapter`.

    Wraps an already-materialized collection so that write-only attributes
    can return an adapter-like object from ``get_collection()`` for
    internal API consistency.
    """

    data: Collection[Any]

    def __init__(self, data: Collection[Any]):
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self) -> Iterator[Any]:
        return iter(self.data)

    def __bool__(self) -> bool:
        # Defined explicitly: without it, __len__ would make an adapter
        # over an empty collection falsy.
        return True

    def _reset_empty(self) -> None:
        """No-op; this adapter has no state to reset."""
        return None
486
487
class AbstractCollectionWriter(Generic[_T]):
    """Virtual collection whose add/remove operations feed the attribute
    event system.

    Holds the parent ``instance``, the owning attribute impl (``attr``) and
    precomputed SQL fragments (``_where_criteria``, ``_from_obj``,
    ``_order_by_clauses``) used to build statements for the collection.

    """

    if not TYPE_CHECKING:
        __slots__ = ()

    instance: _T
    _from_obj: Tuple[FromClause, ...]

    def __init__(self, attr: WriteOnlyAttributeImpl, state: InstanceState[_T]):
        parent = state.obj()
        if TYPE_CHECKING:
            assert parent
        self.instance = parent
        self.attr = attr

        rel = object_mapper(parent)._props[self.attr.key]

        secondary = rel.secondary
        if secondary is None:
            self._from_obj = ()
        else:
            # This is a hack right now: the Query only knows how to make
            # subsequent joins() without a given left-hand side from
            # self._from_obj[0], yet the secondary table must be in the
            # FROM.  So the mapper selectable is purposely placed at
            # _from_obj[0] (so a user-defined join() later on doesn't
            # fail) and the secondary table at _from_obj[1].
            #
            # The official ORM-annotated selectable from
            # __clause_element__() is used here; see #7868.
            self._from_obj = (rel.mapper.__clause_element__(), secondary)

        self._where_criteria = (
            rel._with_parent(parent, alias_secondary=False),
        )

        ordering = self.attr.order_by
        self._order_by_clauses = ordering if ordering else ()

    def _add_all_impl(self, iterator: Iterable[_T]) -> None:
        """Append each item through the attribute impl, with no initiator."""
        for element in iterator:
            parent_state = attributes.instance_state(self.instance)
            parent_dict = attributes.instance_dict(self.instance)
            self.attr.append(parent_state, parent_dict, element, None)

    def _remove_impl(self, item: _T) -> None:
        """Remove ``item`` through the attribute impl, with no initiator."""
        parent_state = attributes.instance_state(self.instance)
        parent_dict = attributes.instance_dict(self.instance)
        self.attr.remove(parent_state, parent_dict, item, None)
549
550
class WriteOnlyCollection(AbstractCollectionWriter[_T]):
    """Write-only collection which can synchronize changes into the
    attribute event system.

    The :class:`.WriteOnlyCollection` is used in a mapping by
    using the ``"write_only"`` lazy loading strategy with
    :func:`_orm.relationship`. For background on this configuration,
    see :ref:`write_only_relationship`.

    The collection is never loaded implicitly and cannot be iterated.
    Use :meth:`.select` to produce a statement for the current rows, and
    :meth:`.add`, :meth:`.add_all` and :meth:`.remove` to stage changes
    that take effect on the next flush.  :meth:`.insert`,
    :meth:`.update` and :meth:`.delete` produce DML statements scoped to
    this collection.

    .. versionadded:: 2.0

    .. seealso::

        :ref:`write_only_relationship`

    """

    __slots__ = (
        "instance",
        "attr",
        "_where_criteria",
        "_from_obj",
        "_order_by_clauses",
    )

    def __iter__(self) -> NoReturn:
        """Always raises :class:`TypeError`; in-place iteration is not
        supported for write-only collections.

        """
        raise TypeError(
            "WriteOnly collections don't support iteration in-place; "
            "to query for collection items, use the select() method to "
            "produce a SQL statement and execute it with session.scalars()."
        )

    def select(self) -> Select[Tuple[_T]]:
        """Produce a :class:`_sql.Select` construct that represents the
        rows within this instance-local :class:`_orm.WriteOnlyCollection`.

        The statement selects the target entity filtered by the parent
        relationship criteria.  For relationships that use
        :paramref:`_orm.relationship.secondary`, the secondary table is
        added to the FROM list.  The relationship's ``order_by``, if any,
        is applied.

        """
        stmt = select(self.attr.target_mapper).where(*self._where_criteria)
        if self._from_obj:
            stmt = stmt.select_from(*self._from_obj)
        if self._order_by_clauses:
            stmt = stmt.order_by(*self._order_by_clauses)
        return stmt

    def insert(self) -> Insert:
        """For one-to-many collections, produce a :class:`_dml.Insert` which
        will insert new rows in terms of this instance-local
        :class:`_orm.WriteOnlyCollection`.

        Each target-table column named in the relationship's
        ``synchronize_pairs`` receives a :func:`_sql.bindparam` whose value
        is read from the parent instance when the statement is executed.

        This construct is only supported for a :class:`_orm.Relationship`
        that does **not** include the :paramref:`_orm.relationship.secondary`
        parameter. For relationships that refer to a many-to-many table,
        use ordinary bulk insert techniques to produce new objects, then
        use :meth:`_orm.WriteOnlyCollection.add_all` to associate them
        with the collection.

        :raises: :class:`.InvalidRequestError` if the relationship is not
         one-to-many.

        """
        state = inspect(self.instance)
        mapper = state.mapper
        prop = mapper._props[self.attr.key]

        if prop.direction is not RelationshipDirection.ONETOMANY:
            raise exc.InvalidRequestError(
                "Write only bulk INSERT only supported for one-to-many "
                "collections; for many-to-many, use a separate bulk "
                "INSERT along with add_all()."
            )

        values: Dict[str, Any] = {}

        for parent_col, target_col in prop.synchronize_pairs:
            # callable reading the parent's value for ``parent_col``;
            # presumably warns when that value is None (per helper name)
            getter = prop._get_attr_w_warn_on_none(
                mapper,
                state,
                state.dict,
                parent_col,
            )

            # deferred bind: the getter is invoked at execution time
            values[target_col.key] = bindparam(None, callable_=getter)

        return insert(self.attr.target_mapper).values(**values)

    def update(self) -> Update:
        """Produce a :class:`_dml.Update` which will refer to rows in terms
        of this instance-local :class:`_orm.WriteOnlyCollection`.

        """
        return update(self.attr.target_mapper).where(*self._where_criteria)

    def delete(self) -> Delete:
        """Produce a :class:`_dml.Delete` which will refer to rows in terms
        of this instance-local :class:`_orm.WriteOnlyCollection`.

        """
        return delete(self.attr.target_mapper).where(*self._where_criteria)

    def add_all(self, iterator: Iterable[_T]) -> None:
        """Add an iterable of items to this :class:`_orm.WriteOnlyCollection`.

        The given items will be persisted to the database in terms of
        the parent instance's collection on the next flush.

        """
        self._add_all_impl(iterator)

    def add(self, item: _T) -> None:
        """Add an item to this :class:`_orm.WriteOnlyCollection`.

        The given item will be persisted to the database in terms of
        the parent instance's collection on the next flush.

        """
        self._add_all_impl([item])

    def remove(self, item: _T) -> None:
        """Remove an item from this :class:`_orm.WriteOnlyCollection`.

        The given item will be removed from the parent instance's collection on
        the next flush.

        """
        self._remove_impl(item)