Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy/orm/bulk_persistence.py: 21%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

735 statements  

1# orm/bulk_persistence.py 

2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7# mypy: ignore-errors 

8 

9 

10"""additional ORM persistence classes related to "bulk" operations, 

11specifically outside of the flush() process. 

12 

13""" 

14 

15from __future__ import annotations 

16 

17from typing import Any 

18from typing import cast 

19from typing import Dict 

20from typing import Iterable 

21from typing import Optional 

22from typing import overload 

23from typing import TYPE_CHECKING 

24from typing import TypeVar 

25from typing import Union 

26 

27from . import attributes 

28from . import context 

29from . import evaluator 

30from . import exc as orm_exc 

31from . import loading 

32from . import persistence 

33from .base import NO_VALUE 

34from .context import AbstractORMCompileState 

35from .context import FromStatement 

36from .context import ORMFromStatementCompileState 

37from .context import QueryContext 

38from .. import exc as sa_exc 

39from .. import util 

40from ..engine import Dialect 

41from ..engine import result as _result 

42from ..sql import coercions 

43from ..sql import dml 

44from ..sql import expression 

45from ..sql import roles 

46from ..sql import select 

47from ..sql import sqltypes 

48from ..sql.base import _entity_namespace_key 

49from ..sql.base import CompileState 

50from ..sql.base import Options 

51from ..sql.dml import DeleteDMLState 

52from ..sql.dml import InsertDMLState 

53from ..sql.dml import UpdateDMLState 

54from ..util import EMPTY_DICT 

55from ..util.typing import Literal 

56 

57if TYPE_CHECKING: 

58 from ._typing import DMLStrategyArgument 

59 from ._typing import OrmExecuteOptionsParameter 

60 from ._typing import SynchronizeSessionArgument 

61 from .mapper import Mapper 

62 from .session import _BindArguments 

63 from .session import ORMExecuteState 

64 from .session import Session 

65 from .session import SessionTransaction 

66 from .state import InstanceState 

67 from ..engine import Connection 

68 from ..engine import cursor 

69 from ..engine.interfaces import _CoreAnyExecuteParams 

70 

# Type variable for the mapped class; used below to tie ``Mapper[_O]`` to the
# ``InstanceState[_O]`` objects accepted by the bulk functions.
_O = TypeVar("_O", bound=object)

72 

73 

# Overload: no ORM INSERT statement supplied -> nothing is returned.
@overload
def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Literal[None] = ...,
    execution_options: Optional[OrmExecuteOptionsParameter] = ...,
) -> None: ...


# Overload: an ORM INSERT statement is supplied -> a cursor result is returned.
@overload
def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Optional[dml.Insert] = ...,
    execution_options: Optional[OrmExecuteOptionsParameter] = ...,
) -> cursor.CursorResult[Any]: ...


def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Optional[dml.Insert] = None,
    execution_options: Optional[OrmExecuteOptionsParameter] = None,
) -> Optional[cursor.CursorResult[Any]]:
    """Emit INSERT statements for a series of states or plain dictionaries.

    One round of ``persistence._collect_insert_commands()`` /
    ``persistence._emit_insert_statements()`` is run for each table in
    ``mapper.base_mapper._sorted_tables`` that is also present in
    ``mapper._pks_by_table``.

    :param mapper: mapper for the class being inserted.
    :param mappings: ``InstanceState`` objects when ``isstates`` is True
     (their ``.dict`` is used as the parameter dictionary), otherwise plain
     dictionaries.
    :param session_transaction: supplies the connection via
     ``session_transaction.connection(base_mapper)``.
    :param isstates: whether ``mappings`` holds states or dictionaries.
    :param return_defaults: turns on "bookkeeping" for the emitted
     statements; for states, ``state.key`` is assigned once all tables have
     been inserted; for dictionaries, the caller's dictionaries are used
     directly rather than copied.
    :param render_nulls: passed through to
     ``persistence._collect_insert_commands()``.
    :param use_orm_insert_stmt: optional ORM-enabled INSERT; when given, a
     result object is accumulated and returned.
    :param execution_options: passed through to
     ``persistence._emit_insert_statements()``.
    :return: the (possibly horizontally spliced) result when
     ``use_orm_insert_stmt`` is given, otherwise ``None``.
    :raises NotImplementedError: if the session has a
     ``connection_callable`` configured.

    """
    base_mapper = mapper.base_mapper

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_insert()"
        )

    if isstates:
        if TYPE_CHECKING:
            mappings = cast(Iterable[InstanceState[_O]], mappings)

        if return_defaults:
            # list of states allows us to attach .key for return_defaults case
            states = [(state, state.dict) for state in mappings]
            mappings = [dict_ for (state, dict_) in states]
        else:
            mappings = [state.dict for state in mappings]
    else:
        if TYPE_CHECKING:
            mappings = cast(Iterable[Dict[str, Any]], mappings)

        if return_defaults:
            # use dictionaries given, so that newly populated defaults
            # can be delivered back to the caller (see #11661). This is **not**
            # compatible with other use cases such as a session-executed
            # insert() construct, as this will confuse the case of
            # insert-per-subclass for joined inheritance cases (see
            # test_bulk_statements.py::BulkDMLReturningJoinedInhTest).
            #
            # So in this conditional, we have **only** called
            # session.bulk_insert_mappings() which does not have this
            # requirement
            mappings = list(mappings)
        else:
            # for all other cases we need to establish a local dictionary
            # so that the incoming dictionaries aren't mutated
            mappings = [dict(m) for m in mappings]
        # composite attribute keys are expanded in place in each dictionary
        _expand_composites(mapper, mappings)

    connection = session_transaction.connection(base_mapper)

    # only populated when use_orm_insert_stmt is given
    return_result: Optional[cursor.CursorResult[Any]] = None

    # (table, mapper) pairs to INSERT into, in base_mapper._sorted_tables
    # order, limited to tables this mapper has primary key columns for
    mappers_to_run = [
        (table, mp)
        for table, mp in base_mapper._sorted_tables.items()
        if table in mapper._pks_by_table
    ]

    if return_defaults:
        # not used by new-style bulk inserts, only used for legacy
        bookkeeping = True
    elif len(mappers_to_run) > 1:
        # if we have more than one table, mapper to run where we will be
        # either horizontally splicing, or copying values between tables,
        # we need the "bookkeeping" / deterministic returning order
        bookkeeping = True
    else:
        bookkeeping = False

    for table, super_mapper in mappers_to_run:
        # find bindparams in the statement. For bulk, we don't really know if
        # a key in the params applies to a different table since we are
        # potentially inserting for multiple tables here; looking at the
        # bindparam() is a lot more direct.   in most cases this will
        # use _generate_cache_key() which is memoized, although in practice
        # the ultimate statement that's executed is probably not the same
        # object so that memoization might not matter much.
        #
        # NOTE(review): ``mappings[0]`` raises IndexError if ``mappings`` is
        # an empty list while a statement is given - confirm callers never
        # pass an empty parameter list here.
        extra_bp_names = (
            [
                b.key
                for b in use_orm_insert_stmt._get_embedded_bindparams()
                if b.key in mappings[0]
            ]
            if use_orm_insert_stmt is not None
            else ()
        )

        # lazily re-pack each collected command; the first element (state)
        # is always None and the mapper / connection of this call are
        # substituted for the ones yielded by _collect_insert_commands()
        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                state,
                state_dict,
                params,
                mp,
                conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in persistence._collect_insert_commands(
                table,
                ((None, mapping, mapper, connection) for mapping in mappings),
                bulk=True,
                return_defaults=bookkeeping,
                render_nulls=render_nulls,
                include_bulk_keys=extra_bp_names,
            )
        )

        result = persistence._emit_insert_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=bookkeeping,
            use_orm_insert_stmt=use_orm_insert_stmt,
            execution_options=execution_options,
        )
        if use_orm_insert_stmt is not None:
            if not use_orm_insert_stmt._returning or return_result is None:
                # no RETURNING, or first table: this result replaces
                # whatever was held so far
                return_result = result
            elif result.returns_rows:
                # a subsequent table with RETURNING rows: append its columns
                # to the rows collected so far
                assert bookkeeping
                return_result = return_result.splice_horizontally(result)

    if return_defaults and isstates:
        # ``states`` was assigned in the isstates / return_defaults branch
        # above; build each state's key from the values now present in its
        # dictionary
        identity_cls = mapper._identity_class
        identity_props = [p.key for p in mapper._identity_key_props]
        for state, dict_ in states:
            state.key = (
                identity_cls,
                tuple([dict_[key] for key in identity_props]),
                None,
            )

    if use_orm_insert_stmt is not None:
        assert return_result is not None
        return return_result

252 

253 

# Overload: no ORM UPDATE statement supplied -> nothing is returned.
@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Literal[None] = ...,
    enable_check_rowcount: bool = True,
) -> None: ...


# Overload: an ORM UPDATE statement is supplied -> a result is returned.
@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = ...,
    enable_check_rowcount: bool = True,
) -> _result.Result[Any]: ...


def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = None,
    enable_check_rowcount: bool = True,
) -> Optional[_result.Result[Any]]:
    """Emit UPDATE statements for a series of states or plain dictionaries.

    One round of ``persistence._collect_update_commands()`` /
    ``persistence._emit_update_statements()`` is run for each table in
    ``mapper.base_mapper._sorted_tables`` whose mapper is a superclass
    mapper of ``mapper`` and which is present in ``mapper._pks_by_table``.

    :param mapper: mapper for the class being updated.
    :param mappings: ``InstanceState`` objects when ``isstates`` is True,
     otherwise plain dictionaries (which are copied before use).
    :param session_transaction: supplies the connection via
     ``session_transaction.connection(base_mapper)``.
    :param isstates: whether ``mappings`` holds states or dictionaries.
    :param update_changed_only: for states only; restrict each parameter
     dictionary to keys present in ``state.committed_state`` plus the
     primary key / version id keys.
    :param use_orm_update_stmt: optional ORM-enabled UPDATE; when given, a
     null result is returned.
    :param enable_check_rowcount: passed through to
     ``persistence._emit_update_statements()``.
    :return: ``_result.null_result()`` when ``use_orm_update_stmt`` is
     given, otherwise ``None``.
    :raises NotImplementedError: if the session has a
     ``connection_callable`` configured.

    """
    base_mapper = mapper.base_mapper

    # keys that are always kept by _changed_dict(): primary key attribute
    # keys, plus the version id attribute key if the mapper has one
    search_keys = mapper._primary_key_propkeys
    if mapper._version_id_prop:
        search_keys = {mapper._version_id_prop.key}.union(search_keys)

    def _changed_dict(mapper, state):
        # subset of state.dict limited to keys recorded in committed_state
        # or listed in search_keys
        return {
            k: v
            for k, v in state.dict.items()
            if k in state.committed_state or k in search_keys
        }

    if isstates:
        if update_changed_only:
            mappings = [_changed_dict(mapper, state) for state in mappings]
        else:
            mappings = [state.dict for state in mappings]
    else:
        # copy so that the composite expansion below does not mutate the
        # caller's dictionaries
        mappings = [dict(m) for m in mappings]
        _expand_composites(mapper, mappings)

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_update()"
        )

    connection = session_transaction.connection(base_mapper)

    # find bindparams in the statement. see _bulk_insert for similar
    # notes for the insert case
    #
    # NOTE(review): ``mappings[0]`` raises IndexError if ``mappings`` is an
    # empty list while a statement is given - confirm callers never pass an
    # empty parameter list here.
    extra_bp_names = (
        [
            b.key
            for b in use_orm_update_stmt._get_embedded_bindparams()
            if b.key in mappings[0]
        ]
        if use_orm_update_stmt is not None
        else ()
    )

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
            continue

        records = persistence._collect_update_commands(
            None,
            table,
            (
                # (state, dict, mapper, connection, version id value); the
                # state slot is always None for bulk operation
                (
                    None,
                    mapping,
                    mapper,
                    connection,
                    (
                        mapping[mapper._version_id_prop.key]
                        if mapper._version_id_prop
                        else None
                    ),
                )
                for mapping in mappings
            ),
            bulk=True,
            use_orm_update_stmt=use_orm_update_stmt,
            include_bulk_keys=extra_bp_names,
        )
        persistence._emit_update_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=False,
            use_orm_update_stmt=use_orm_update_stmt,
            enable_check_rowcount=enable_check_rowcount,
        )

    if use_orm_update_stmt is not None:
        return _result.null_result()

370 

371 

372def _expand_composites(mapper, mappings): 

373 composite_attrs = mapper.composites 

374 if not composite_attrs: 

375 return 

376 

377 composite_keys = set(composite_attrs.keys()) 

378 populators = { 

379 key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn() 

380 for key in composite_keys 

381 } 

382 for mapping in mappings: 

383 for key in composite_keys.intersection(mapping): 

384 populators[key](mapping) 

385 

386 

class ORMDMLState(AbstractORMCompileState):
    """Compile-state base shared by the ORM-enabled DML constructs.

    Provides ORM-aware resolution of the key/value pairs given to
    ``values()``-style methods, the "description" accessors for the
    statement and its RETURNING columns, and the two halves of ORM RETURNING
    support: :meth:`._setup_orm_returning` at compile time and
    :meth:`._return_orm_returning` once a result is available.

    """

    is_dml_returning = True

    # set by _setup_orm_returning() when the statement has RETURNING columns
    from_statement_ctx: Optional[ORMFromStatementCompileState] = None

    @classmethod
    def _get_orm_crud_kv_pairs(
        cls, mapper, statement, kv_iterator, needs_to_be_cacheable
    ):
        """Yield ``(column, value)`` pairs with ORM attributes resolved.

        :param mapper: mapper against which string keys are looked up.
        :param statement: the DML statement, passed to the Core resolver.
        :param kv_iterator: iterable of ``(key, value)`` pairs.
        :param needs_to_be_cacheable: when True, plain values are coerced
         into SQL expression elements; otherwise they are passed through.

        """
        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs

        for k, v in kv_iterator:
            k = coercions.expect(roles.DMLColumnRole, k)

            if isinstance(k, str):
                # string key: look for a mapped attribute of that name
                desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
                if desc is NO_VALUE:
                    # no such attribute; pass the string key through
                    yield (
                        coercions.expect(roles.DMLColumnRole, k),
                        (
                            coercions.expect(
                                roles.ExpressionElementRole,
                                v,
                                type_=sqltypes.NullType(),
                                is_crud=True,
                            )
                            if needs_to_be_cacheable
                            else v
                        ),
                    )
                else:
                    # let the attribute expand itself into column/value
                    # tuples, then apply the Core processing to those
                    yield from core_get_crud_kv_pairs(
                        statement,
                        desc._bulk_update_tuples(v),
                        needs_to_be_cacheable,
                    )
            elif "entity_namespace" in k._annotations:
                # ORM-annotated column: locate the attribute it proxies
                k_anno = k._annotations
                attr = _entity_namespace_key(
                    k_anno["entity_namespace"], k_anno["proxy_key"]
                )
                yield from core_get_crud_kv_pairs(
                    statement,
                    attr._bulk_update_tuples(v),
                    needs_to_be_cacheable,
                )
            else:
                # column with no ORM annotations; key is used as is
                yield (
                    k,
                    (
                        v
                        if not needs_to_be_cacheable
                        else coercions.expect(
                            roles.ExpressionElementRole,
                            v,
                            type_=sqltypes.NullType(),
                            is_crud=True,
                        )
                    ),
                )

    @classmethod
    def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
        """Resolve a series of value dictionaries into column-keyed dicts.

        Falls back to ``UpdateDMLState._get_multi_crud_kv_pairs()`` when the
        statement has no plugin subject with a mapper.

        """
        plugin_subject = statement._propagate_attrs["plugin_subject"]

        if not plugin_subject or not plugin_subject.mapper:
            return UpdateDMLState._get_multi_crud_kv_pairs(
                statement, kv_iterator
            )

        return [
            dict(
                cls._get_orm_crud_kv_pairs(
                    plugin_subject.mapper, statement, value_dict.items(), False
                )
            )
            for value_dict in kv_iterator
        ]

    @classmethod
    def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
        """Resolve ``(key, value)`` pairs into a list of column/value pairs.

        Falls back to ``UpdateDMLState._get_crud_kv_pairs()`` when the
        statement has no plugin subject with a mapper.

        """
        assert (
            needs_to_be_cacheable
        ), "no test coverage for needs_to_be_cacheable=False"

        plugin_subject = statement._propagate_attrs["plugin_subject"]

        if not plugin_subject or not plugin_subject.mapper:
            return UpdateDMLState._get_crud_kv_pairs(
                statement, kv_iterator, needs_to_be_cacheable
            )

        return list(
            cls._get_orm_crud_kv_pairs(
                plugin_subject.mapper,
                statement,
                kv_iterator,
                needs_to_be_cacheable,
            )
        )

    @classmethod
    def get_entity_description(cls, statement):
        """Return a dictionary describing the entity the statement targets.

        The entity is taken from the "parententity" annotation of
        ``statement.table``; the "name" entry is the alias name for an
        aliased class, otherwise the mapped class name.

        """
        ext_info = statement.table._annotations["parententity"]
        mapper = ext_info.mapper
        if ext_info.is_aliased_class:
            _label_name = ext_info.name
        else:
            _label_name = mapper.class_.__name__

        return {
            "name": _label_name,
            "type": mapper.class_,
            "expr": ext_info.entity,
            "entity": ext_info.entity,
            "table": mapper.local_table,
        }

    @classmethod
    def get_returning_column_descriptions(cls, statement):
        """Return one description dictionary per RETURNING column.

        Columns that carry no "parententity" annotation are described with
        ``"aliased": False`` and ``"entity": None``, and the column itself
        is used as the ``"expr"``.

        """

        def _ent_for_col(c):
            return c._annotations.get("parententity", None)

        def _attr_for_col(c, ent):
            if ent is None:
                return c
            proxy_key = c._annotations.get("proxy_key", None)
            if not proxy_key:
                return c
            else:
                return getattr(ent.entity, proxy_key, c)

        return [
            {
                "name": c.key,
                "type": c.type,
                "expr": _attr_for_col(c, ent),
                # ent is None for a column that is not associated with an
                # ORM entity; don't dereference it in that case
                "aliased": ent.is_aliased_class if ent is not None else False,
                "entity": ent.entity if ent is not None else None,
            }
            for c, ent in [
                (c, _ent_for_col(c)) for c in statement._all_selected_columns
            ]
        ]

    def _setup_orm_returning(
        self,
        compiler,
        orm_level_statement,
        dml_level_statement,
        dml_mapper,
        *,
        use_supplemental_cols=True,
    ):
        """establish ORM column handlers for an INSERT, UPDATE, or DELETE
        which uses explicit returning().

        called within compilation level create_for_statement.

        The _return_orm_returning() method then receives the Result
        after the statement was executed, and applies ORM loading to the
        state that we first established here.

        :param compiler: compiler passed on to
         ``ORMFromStatementCompileState.create_for_statement()``.
        :param orm_level_statement: statement whose ``_returning``,
         ``_execution_options`` and ``_with_options`` are consulted.
        :param dml_level_statement: statement that receives the new
         RETURNING / return_defaults columns; it is returned unchanged when
         ``orm_level_statement`` has no RETURNING.
        :param dml_mapper: mapper whose primary key columns are used.
        :param use_supplemental_cols: when True, columns are added through
         ``return_defaults(..., supplemental_cols=...)``; otherwise through
         ``returning()``.
        :return: the DML statement to use from here on.

        """

        if orm_level_statement._returning:
            fs = FromStatement(
                orm_level_statement._returning,
                dml_level_statement,
                _adapt_on_names=False,
            )
            fs = fs.execution_options(**orm_level_statement._execution_options)
            fs = fs.options(*orm_level_statement._with_options)
            self.select_statement = fs
            self.from_statement_ctx = fsc = (
                ORMFromStatementCompileState.create_for_statement(fs, compiler)
            )
            fsc.setup_dml_returning_compile_state(dml_mapper)

            # work on a copy with RETURNING cleared; the columns are
            # re-established below
            dml_level_statement = dml_level_statement._generate()
            dml_level_statement._returning = ()

            cols_to_return = [c for c in fsc.primary_columns if c is not None]

            # since we are splicing result sets together, make sure there
            # are columns of some kind returned in each result set
            if not cols_to_return:
                cols_to_return.extend(dml_mapper.primary_key)

            if use_supplemental_cols:
                dml_level_statement = dml_level_statement.return_defaults(
                    # this is a little weird looking, but by passing
                    # primary key as the main list of cols, this tells
                    # return_defaults to omit server-default cols (and
                    # actually all cols, due to some weird thing we should
                    # clean up in crud.py).
                    # Since we have cols_to_return, just return what we asked
                    # for (plus primary key, which ORM persistence needs since
                    # we likely set bookkeeping=True here, which is another
                    # whole thing...).   We dont want to clutter the
                    # statement up with lots of other cols the user didn't
                    # ask for.  see #9685
                    *dml_mapper.primary_key,
                    supplemental_cols=cols_to_return,
                )
            else:
                dml_level_statement = dml_level_statement.returning(
                    *cols_to_return
                )

        return dml_level_statement

    @classmethod
    def _return_orm_returning(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        """Apply ORM loading to ``result`` if ORM RETURNING was set up.

        When the compile state found on ``result.context`` has a
        ``from_statement_ctx`` whose compile options are not ``_is_star``,
        the result of ``loading.instances()`` is returned; otherwise
        ``result`` is returned unchanged.

        """
        execution_context = result.context
        compile_state = execution_context.compiled.compile_state

        if (
            compile_state.from_statement_ctx
            and not compile_state.from_statement_ctx.compile_options._is_star
        ):
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )

            querycontext = QueryContext(
                compile_state.from_statement_ctx,
                compile_state.select_statement,
                statement,
                params,
                session,
                load_options,
                execution_options,
                bind_arguments,
            )
            return loading.instances(result, querycontext)
        else:
            return result

632 

633 

634class BulkUDCompileState(ORMDMLState): 

    class default_update_options(Options):
        # Option defaults for ORM UPDATE / DELETE execution.
        # orm_pre_session_exec() populates an instance of this class from
        # the execution options and stores it back under the
        # "_sa_orm_update_options" key.

        # "auto" is resolved by orm_pre_session_exec() to "orm", "bulk" or
        # "core_only"
        _dml_strategy: DMLStrategyArgument = "auto"
        # accepted values are "auto", "evaluate", "fetch" and False
        _synchronize_session: SynchronizeSessionArgument = "auto"
        _can_use_returning: bool = False
        _is_delete_using: bool = False
        _is_update_from: bool = False
        # when true, session._autoflush() is called before synchronization
        _autoflush: bool = True
        # assigned from the statement's plugin_subject.mapper
        _subject_mapper: Optional[Mapper[Any]] = None
        # NOTE(review): _resolved_values and _matched_rows are presumably
        # filled in by the synchronization steps - confirm against the
        # subclasses
        _resolved_values = EMPTY_DICT
        # callable produced by _eval_condition_from_statement()
        _eval_condition = None
        _matched_rows = None
        # when not None, limits matched states to that identity token
        _identity_token = None
        _populate_existing: bool = False

648 

649 @classmethod 

650 def can_use_returning( 

651 cls, 

652 dialect: Dialect, 

653 mapper: Mapper[Any], 

654 *, 

655 is_multitable: bool = False, 

656 is_update_from: bool = False, 

657 is_delete_using: bool = False, 

658 is_executemany: bool = False, 

659 ) -> bool: 

660 raise NotImplementedError() 

661 

662 @classmethod 

663 def orm_pre_session_exec( 

664 cls, 

665 session, 

666 statement, 

667 params, 

668 execution_options, 

669 bind_arguments, 

670 is_pre_event, 

671 ): 

672 ( 

673 update_options, 

674 execution_options, 

675 ) = BulkUDCompileState.default_update_options.from_execution_options( 

676 "_sa_orm_update_options", 

677 { 

678 "synchronize_session", 

679 "autoflush", 

680 "populate_existing", 

681 "identity_token", 

682 "is_delete_using", 

683 "is_update_from", 

684 "dml_strategy", 

685 }, 

686 execution_options, 

687 statement._execution_options, 

688 ) 

689 bind_arguments["clause"] = statement 

690 try: 

691 plugin_subject = statement._propagate_attrs["plugin_subject"] 

692 except KeyError: 

693 assert False, "statement had 'orm' plugin but no plugin_subject" 

694 else: 

695 if plugin_subject: 

696 bind_arguments["mapper"] = plugin_subject.mapper 

697 update_options += {"_subject_mapper": plugin_subject.mapper} 

698 

699 if "parententity" not in statement.table._annotations: 

700 update_options += {"_dml_strategy": "core_only"} 

701 elif not isinstance(params, list): 

702 if update_options._dml_strategy == "auto": 

703 update_options += {"_dml_strategy": "orm"} 

704 elif update_options._dml_strategy == "bulk": 

705 raise sa_exc.InvalidRequestError( 

706 'Can\'t use "bulk" ORM insert strategy without ' 

707 "passing separate parameters" 

708 ) 

709 else: 

710 if update_options._dml_strategy == "auto": 

711 update_options += {"_dml_strategy": "bulk"} 

712 

713 sync = update_options._synchronize_session 

714 if sync is not None: 

715 if sync not in ("auto", "evaluate", "fetch", False): 

716 raise sa_exc.ArgumentError( 

717 "Valid strategies for session synchronization " 

718 "are 'auto', 'evaluate', 'fetch', False" 

719 ) 

720 if update_options._dml_strategy == "bulk" and sync == "fetch": 

721 raise sa_exc.InvalidRequestError( 

722 "The 'fetch' synchronization strategy is not available " 

723 "for 'bulk' ORM updates (i.e. multiple parameter sets)" 

724 ) 

725 

726 if not is_pre_event: 

727 if update_options._autoflush: 

728 session._autoflush() 

729 

730 if update_options._dml_strategy == "orm": 

731 if update_options._synchronize_session == "auto": 

732 update_options = cls._do_pre_synchronize_auto( 

733 session, 

734 statement, 

735 params, 

736 execution_options, 

737 bind_arguments, 

738 update_options, 

739 ) 

740 elif update_options._synchronize_session == "evaluate": 

741 update_options = cls._do_pre_synchronize_evaluate( 

742 session, 

743 statement, 

744 params, 

745 execution_options, 

746 bind_arguments, 

747 update_options, 

748 ) 

749 elif update_options._synchronize_session == "fetch": 

750 update_options = cls._do_pre_synchronize_fetch( 

751 session, 

752 statement, 

753 params, 

754 execution_options, 

755 bind_arguments, 

756 update_options, 

757 ) 

758 elif update_options._dml_strategy == "bulk": 

759 if update_options._synchronize_session == "auto": 

760 update_options += {"_synchronize_session": "evaluate"} 

761 

762 # indicators from the "pre exec" step that are then 

763 # added to the DML statement, which will also be part of the cache 

764 # key. The compile level create_for_statement() method will then 

765 # consume these at compiler time. 

766 statement = statement._annotate( 

767 { 

768 "synchronize_session": update_options._synchronize_session, 

769 "is_delete_using": update_options._is_delete_using, 

770 "is_update_from": update_options._is_update_from, 

771 "dml_strategy": update_options._dml_strategy, 

772 "can_use_returning": update_options._can_use_returning, 

773 } 

774 ) 

775 

776 return ( 

777 statement, 

778 util.immutabledict(execution_options).union( 

779 {"_sa_orm_update_options": update_options} 

780 ), 

781 ) 

782 

783 @classmethod 

784 def orm_setup_cursor_result( 

785 cls, 

786 session, 

787 statement, 

788 params, 

789 execution_options, 

790 bind_arguments, 

791 result, 

792 ): 

793 # this stage of the execution is called after the 

794 # do_orm_execute event hook. meaning for an extension like 

795 # horizontal sharding, this step happens *within* the horizontal 

796 # sharding event handler which calls session.execute() re-entrantly 

797 # and will occur for each backend individually. 

798 # the sharding extension then returns its own merged result from the 

799 # individual ones we return here. 

800 

801 update_options = execution_options["_sa_orm_update_options"] 

802 if update_options._dml_strategy == "orm": 

803 if update_options._synchronize_session == "evaluate": 

804 cls._do_post_synchronize_evaluate( 

805 session, statement, result, update_options 

806 ) 

807 elif update_options._synchronize_session == "fetch": 

808 cls._do_post_synchronize_fetch( 

809 session, statement, result, update_options 

810 ) 

811 elif update_options._dml_strategy == "bulk": 

812 if update_options._synchronize_session == "evaluate": 

813 cls._do_post_synchronize_bulk_evaluate( 

814 session, params, result, update_options 

815 ) 

816 return result 

817 

818 return cls._return_orm_returning( 

819 session, 

820 statement, 

821 params, 

822 execution_options, 

823 bind_arguments, 

824 result, 

825 ) 

826 

827 @classmethod 

828 def _adjust_for_extra_criteria(cls, global_attributes, ext_info): 

829 """Apply extra criteria filtering. 

830 

831 For all distinct single-table-inheritance mappers represented in the 

832 table being updated or deleted, produce additional WHERE criteria such 

833 that only the appropriate subtypes are selected from the total results. 

834 

835 Additionally, add WHERE criteria originating from LoaderCriteriaOptions 

836 collected from the statement. 

837 

838 """ 

839 

840 return_crit = () 

841 

842 adapter = ext_info._adapter if ext_info.is_aliased_class else None 

843 

844 if ( 

845 "additional_entity_criteria", 

846 ext_info.mapper, 

847 ) in global_attributes: 

848 return_crit += tuple( 

849 ae._resolve_where_criteria(ext_info) 

850 for ae in global_attributes[ 

851 ("additional_entity_criteria", ext_info.mapper) 

852 ] 

853 if ae.include_aliases or ae.entity is ext_info 

854 ) 

855 

856 if ext_info.mapper._single_table_criterion is not None: 

857 return_crit += (ext_info.mapper._single_table_criterion,) 

858 

859 if adapter: 

860 return_crit = tuple(adapter.traverse(crit) for crit in return_crit) 

861 

862 return return_crit 

863 

864 @classmethod 

865 def _interpret_returning_rows(cls, result, mapper, rows): 

866 """return rows that indicate PK cols in mapper.primary_key position 

867 for RETURNING rows. 

868 

869 Prior to 2.0.36, this method seemed to be written for some kind of 

870 inheritance scenario but the scenario was unused for actual joined 

871 inheritance, and the function instead seemed to perform some kind of 

872 partial translation that would remove non-PK cols if the PK cols 

873 happened to be first in the row, but not otherwise. The joined 

874 inheritance walk feature here seems to have never been used as it was 

875 always skipped by the "local_table" check. 

876 

877 As of 2.0.36 the function strips away non-PK cols and provides the 

878 PK cols for the table in mapper PK order. 

879 

880 """ 

881 

882 try: 

883 if mapper.local_table is not mapper.base_mapper.local_table: 

884 # TODO: dive more into how a local table PK is used for fetch 

885 # sync, not clear if this is correct as it depends on the 

886 # downstream routine to fetch rows using 

887 # local_table.primary_key order 

888 pk_keys = result._tuple_getter(mapper.local_table.primary_key) 

889 else: 

890 pk_keys = result._tuple_getter(mapper.primary_key) 

891 except KeyError: 

892 # can't use these rows, they don't have PK cols in them 

893 # this is an unusual case where the user would have used 

894 # .return_defaults() 

895 return [] 

896 

897 return [pk_keys(row) for row in rows] 

898 

899 @classmethod 

900 def _get_matched_objects_on_criteria(cls, update_options, states): 

901 mapper = update_options._subject_mapper 

902 eval_condition = update_options._eval_condition 

903 

904 raw_data = [ 

905 (state.obj(), state, state.dict) 

906 for state in states 

907 if state.mapper.isa(mapper) and not state.expired 

908 ] 

909 

910 identity_token = update_options._identity_token 

911 if identity_token is not None: 

912 raw_data = [ 

913 (obj, state, dict_) 

914 for obj, state, dict_ in raw_data 

915 if state.identity_token == identity_token 

916 ] 

917 

918 result = [] 

919 for obj, state, dict_ in raw_data: 

920 evaled_condition = eval_condition(obj) 

921 

922 # caution: don't use "in ()" or == here, _EXPIRE_OBJECT 

923 # evaluates as True for all comparisons 

924 if ( 

925 evaled_condition is True 

926 or evaled_condition is evaluator._EXPIRED_OBJECT 

927 ): 

928 result.append( 

929 ( 

930 obj, 

931 state, 

932 dict_, 

933 evaled_condition is evaluator._EXPIRED_OBJECT, 

934 ) 

935 ) 

936 return result 

937 

938 @classmethod 

939 def _eval_condition_from_statement(cls, update_options, statement): 

940 mapper = update_options._subject_mapper 

941 target_cls = mapper.class_ 

942 

943 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls) 

944 crit = () 

945 if statement._where_criteria: 

946 crit += statement._where_criteria 

947 

948 global_attributes = {} 

949 for opt in statement._with_options: 

950 if opt._is_criteria_option: 

951 opt.get_global_criteria(global_attributes) 

952 

953 if global_attributes: 

954 crit += cls._adjust_for_extra_criteria(global_attributes, mapper) 

955 

956 if crit: 

957 eval_condition = evaluator_compiler.process(*crit) 

958 else: 

959 # workaround for mypy https://github.com/python/mypy/issues/14027 

960 def _eval_condition(obj): 

961 return True 

962 

963 eval_condition = _eval_condition 

964 

965 return eval_condition 

966 

967 @classmethod 

968 def _do_pre_synchronize_auto( 

969 cls, 

970 session, 

971 statement, 

972 params, 

973 execution_options, 

974 bind_arguments, 

975 update_options, 

976 ): 

977 """setup auto sync strategy 

978 

979 

980 "auto" checks if we can use "evaluate" first, then falls back 

981 to "fetch" 

982 

983 evaluate is vastly more efficient for the common case 

984 where session is empty, only has a few objects, and the UPDATE 

985 statement can potentially match thousands/millions of rows. 

986 

987 OTOH more complex criteria that fails to work with "evaluate" 

988 we would hope usually correlates with fewer net rows. 

989 

990 """ 

991 

992 try: 

993 eval_condition = cls._eval_condition_from_statement( 

994 update_options, statement 

995 ) 

996 

997 except evaluator.UnevaluatableError: 

998 pass 

999 else: 

1000 return update_options + { 

1001 "_eval_condition": eval_condition, 

1002 "_synchronize_session": "evaluate", 

1003 } 

1004 

1005 update_options += {"_synchronize_session": "fetch"} 

1006 return cls._do_pre_synchronize_fetch( 

1007 session, 

1008 statement, 

1009 params, 

1010 execution_options, 

1011 bind_arguments, 

1012 update_options, 

1013 ) 

1014 

@classmethod
def _do_pre_synchronize_evaluate(
    cls,
    session,
    statement,
    params,
    execution_options,
    bind_arguments,
    update_options,
):
    """Set up the "evaluate" synchronize strategy.

    Returns ``update_options`` extended with an ``_eval_condition`` entry
    built by ``_eval_condition_from_statement()``.

    Raises ``InvalidRequestError`` (chained from the original
    ``UnevaluatableError``) when the criteria cannot be evaluated in
    Python.

    """
    try:
        condition = cls._eval_condition_from_statement(
            update_options, statement
        )
    except evaluator.UnevaluatableError as err:
        message = (
            'Could not evaluate current criteria in Python: "%s". '
            "Specify 'fetch' or False for the "
            "synchronize_session execution option." % err
        )
        raise sa_exc.InvalidRequestError(message) from err

    return update_options + {"_eval_condition": condition}

1040 

1041 @classmethod 

1042 def _get_resolved_values(cls, mapper, statement): 

1043 if statement._multi_values: 

1044 return [] 

1045 elif statement._ordered_values: 

1046 return list(statement._ordered_values) 

1047 elif statement._values: 

1048 return list(statement._values.items()) 

1049 else: 

1050 return [] 

1051 

1052 @classmethod 

1053 def _resolved_keys_as_propnames(cls, mapper, resolved_values): 

1054 values = [] 

1055 for k, v in resolved_values: 

1056 if mapper and isinstance(k, expression.ColumnElement): 

1057 try: 

1058 attr = mapper._columntoproperty[k] 

1059 except orm_exc.UnmappedColumnError: 

1060 pass 

1061 else: 

1062 values.append((attr.key, v)) 

1063 else: 

1064 raise sa_exc.InvalidRequestError( 

1065 "Attribute name not found, can't be " 

1066 "synchronized back to objects: %r" % k 

1067 ) 

1068 return values 

1069 

@classmethod
def _do_pre_synchronize_fetch(
    cls,
    session,
    statement,
    params,
    execution_options,
    bind_arguments,
    update_options,
):
    """Set up the "fetch" synchronize strategy.

    Builds a SELECT of the subject mapper's primary key columns plus its
    identity token, using the statement's WHERE criteria and options, and
    executes it through the session with a per-execution event hook.  The
    hook cancels the SELECT (by returning a null result) for binds where
    ``cls.can_use_returning()`` reports that RETURNING is usable.

    Returns ``update_options`` extended with ``_matched_rows`` and
    ``_can_use_returning``.

    Raises ``InvalidRequestError`` if binds disagree about RETURNING
    support, or if multiple parameter sets are used on a bind that cannot
    use RETURNING for them.

    """
    mapper = update_options._subject_mapper

    pk_and_token_cols = mapper.primary_key + (mapper.select_identity_token,)
    select_stmt = (
        select(*pk_and_token_cols)
        .select_from(mapper)
        .options(*statement._with_options)
    )
    select_stmt._where_criteria = statement._where_criteria

    # The pre-fetch SELECT is run conditionally: the "bind" is tested for
    # RETURNING support from inside a do_orm_execute-style event, and when
    # RETURNING is available that event cancels the SELECT from actually
    # being run.
    #
    # Why not simply call can_use_returning() up front instead of going
    # through the whole execute phase with an event?  Because this has to
    # integrate with extensions such as the horizontal sharding extension,
    # which "multiplexes" an individual statement across multiple engines
    # and uses do_orm_execute() to do so.

    can_use_returning = None

    def skip_for_returning(orm_context: ORMExecuteState) -> Any:
        nonlocal can_use_returning

        bind = orm_context.session.get_bind(**orm_context.bind_arguments)
        bind_supports_returning = cls.can_use_returning(
            bind.dialect,
            mapper,
            is_update_from=update_options._is_update_from,
            is_delete_using=update_options._is_delete_using,
            is_executemany=orm_context.is_executemany,
        )

        if can_use_returning is None:
            # first bind seen: it establishes the answer for all binds
            if orm_context.is_executemany and not bind_supports_returning:
                raise sa_exc.InvalidRequestError(
                    "For synchronize_session='fetch', can't use multiple "
                    "parameter sets in ORM mode, which this backend does not "
                    "support with RETURNING"
                )
            can_use_returning = bind_supports_returning
        elif can_use_returning != bind_supports_returning:
            raise sa_exc.InvalidRequestError(
                "For synchronize_session='fetch', can't mix multiple "
                "backends where some support RETURNING and others "
                "don't"
            )

        # a null result cancels the SELECT; None lets it proceed
        return _result.null_result() if bind_supports_returning else None

    result = session.execute(
        select_stmt,
        params,
        execution_options=execution_options,
        bind_arguments=bind_arguments,
        _add_event=skip_for_returning,
    )
    matched_rows = result.fetchall()

    return update_options + {
        "_matched_rows": matched_rows,
        "_can_use_returning": can_use_returning,
    }

1150 

1151 

@CompileState.plugin_for("orm", "insert")
class BulkORMInsert(ORMDMLState, InsertDMLState):
    """Compile state for INSERT statements that carry the "orm" plugin.

    Handles option resolution before execution, dispatch to the "raw",
    "bulk" or "orm" strategy at execution time, and compile-time setup of
    the statement for the "bulk" and "orm" strategies.

    """

    class default_insert_options(Options):
        # strategy requested for the INSERT; "auto" is resolved to "orm" or
        # "bulk" in orm_pre_session_exec() depending on whether separate
        # parameters were passed
        _dml_strategy: DMLStrategyArgument = "auto"
        _render_nulls: bool = False
        _return_defaults: bool = False
        _subject_mapper: Optional[Mapper[Any]] = None
        _autoflush: bool = True
        _populate_existing: bool = False

    select_statement: Optional[FromStatement] = None

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_pre_event,
    ):
        """Resolve insert options and annotate the statement.

        Returns a tuple of the statement annotated with the resolved
        ``dml_strategy`` and the execution options including the
        ``_sa_orm_insert_options`` entry.  ``bind_arguments`` is updated in
        place with ``clause`` and, when a plugin subject is present,
        ``mapper``.

        Raises ``InvalidRequestError`` when the "bulk" strategy is
        requested without separate parameters.

        """
        (
            insert_options,
            execution_options,
        ) = BulkORMInsert.default_insert_options.from_execution_options(
            "_sa_orm_insert_options",
            {"dml_strategy", "autoflush", "populate_existing", "render_nulls"},
            execution_options,
            statement._execution_options,
        )
        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            # explicit raise rather than "assert False" so the invariant
            # is still enforced when assertions are stripped with -O
            raise AssertionError(
                "statement had 'orm' plugin but no plugin_subject"
            )
        else:
            if plugin_subject:
                bind_arguments["mapper"] = plugin_subject.mapper
                insert_options += {"_subject_mapper": plugin_subject.mapper}

        if not params:
            if insert_options._dml_strategy == "auto":
                insert_options += {"_dml_strategy": "orm"}
            elif insert_options._dml_strategy == "bulk":
                raise sa_exc.InvalidRequestError(
                    'Can\'t use "bulk" ORM insert strategy without '
                    "passing separate parameters"
                )
        else:
            if insert_options._dml_strategy == "auto":
                insert_options += {"_dml_strategy": "bulk"}

        if insert_options._dml_strategy != "raw":
            # for ORM object loading, like ORMContext, we have to disable
            # result set adapt_to_context, because we will be generating a
            # new statement with specific columns that's cached inside of
            # an ORMFromStatementCompileState, which we will re-use for
            # each result.
            if not execution_options:
                execution_options = context._orm_load_exec_options
            else:
                execution_options = execution_options.union(
                    context._orm_load_exec_options
                )

        if not is_pre_event and insert_options._autoflush:
            session._autoflush()

        statement = statement._annotate(
            {"dml_strategy": insert_options._dml_strategy}
        )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_insert_options": insert_options}
            ),
        )

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Insert,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        """Execute the INSERT according to the resolved DML strategy.

        "raw" executes on the connection and returns the result as is.
        "bulk" runs ``_bulk_insert()``; "orm" executes on the connection.
        For the latter two, when the statement has RETURNING the result is
        passed through ``_return_orm_returning()``, otherwise it is
        returned directly.

        Raises ``ArgumentError`` for an unknown strategy and
        ``InvalidRequestError`` for a bulk INSERT with a "post values"
        clause against a mapper with multiple persistence tables.

        """
        insert_options = execution_options.get(
            "_sa_orm_insert_options", cls.default_insert_options
        )

        if insert_options._dml_strategy not in (
            "raw",
            "bulk",
            "orm",
            "auto",
        ):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM insert strategy "
                "are 'raw', 'orm', 'bulk', 'auto'"
            )

        result: _result.Result[Any]

        if insert_options._dml_strategy == "raw":
            result = conn.execute(
                statement, params or {}, execution_options=execution_options
            )
            return result

        if insert_options._dml_strategy == "bulk":
            mapper = insert_options._subject_mapper

            # checked before the mapper is dereferenced below
            assert mapper is not None

            if (
                statement._post_values_clause is not None
                and mapper._multiple_persistence_tables
            ):
                raise sa_exc.InvalidRequestError(
                    "bulk INSERT with a 'post values' clause "
                    "(typically upsert) not supported for multi-table "
                    f"mapper {mapper}"
                )

            assert session._transaction is not None
            result = _bulk_insert(
                mapper,
                cast(
                    "Iterable[Dict[str, Any]]",
                    # a single parameter dictionary is treated as a
                    # one-element list of parameter sets
                    [params] if isinstance(params, dict) else params,
                ),
                session._transaction,
                isstates=False,
                return_defaults=insert_options._return_defaults,
                render_nulls=insert_options._render_nulls,
                use_orm_insert_stmt=statement,
                execution_options=execution_options,
            )
        elif insert_options._dml_strategy == "orm":
            result = conn.execute(
                statement, params or {}, execution_options=execution_options
            )
        else:
            # "auto" is expected to have been resolved to a concrete
            # strategy by orm_pre_session_exec()
            raise AssertionError()

        if not bool(statement._returning):
            return result

        if insert_options._populate_existing:
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )
            load_options += {"_populate_existing": True}
            execution_options = execution_options.union(
                {"_sa_orm_load_options": load_options}
            )

        return cls._return_orm_returning(
            session,
            statement,
            params,
            execution_options,
            bind_arguments,
            result,
        )

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
        """Create the compile state, applying strategy setup at top level.

        Only a top-level statement (no compiler, or an empty compiler
        stack) receives the "bulk" / "orm" setup; the strategy is read from
        the statement's ``dml_strategy`` annotation and defaults to "raw",
        for which no additional setup is performed.

        """
        self = cast(
            BulkORMInsert,
            super().create_for_statement(statement, compiler, **kw),
        )

        if compiler is not None:
            toplevel = not compiler.stack
        else:
            toplevel = True
        if not toplevel:
            return self

        mapper = statement._propagate_attrs["plugin_subject"]
        dml_strategy = statement._annotations.get("dml_strategy", "raw")
        if dml_strategy == "bulk":
            self._setup_for_bulk_insert(compiler)
        elif dml_strategy == "orm":
            self._setup_for_orm_insert(compiler, mapper)

        return self

    @classmethod
    def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
        """Return the dictionary re-keyed by column key where possible.

        Each key is looked up with ``mapper.c.get()``; when a column is
        found its ``.key`` replaces the original key, otherwise the
        original key is kept.

        """
        return {
            col.key if col is not None else k: v
            for col, k, v in (
                (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
            )
        }

    def _setup_for_orm_insert(self, compiler, mapper):
        """Set up ``self.statement`` for the "orm" INSERT strategy."""
        statement = orm_level_statement = cast(dml.Insert, self.statement)

        statement = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            statement,
            dml_mapper=mapper,
            use_supplemental_cols=False,
        )
        self.statement = statement

    def _setup_for_bulk_insert(self, compiler):
        """establish an INSERT statement within the context of
        bulk insert.

        This method will be within the "conn.execute()" call that is invoked
        by persistence._emit_insert_statement().

        The statement is cloned and pointed at the table given by the
        ``_emit_insert_table`` annotation, and ``_dict_parameters`` is
        narrowed to columns of that table.

        Raises ``CompileError`` when RETURNING * is in use.

        """
        statement = orm_level_statement = cast(dml.Insert, self.statement)
        an = statement._annotations

        emit_insert_table, emit_insert_mapper = (
            an["_emit_insert_table"],
            an["_emit_insert_mapper"],
        )

        statement = statement._clone()

        statement.table = emit_insert_table
        if self._dict_parameters:
            # keep only the parameters that target the table being emitted
            self._dict_parameters = {
                col: val
                for col, val in self._dict_parameters.items()
                if col.table is emit_insert_table
            }

        statement = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            statement,
            dml_mapper=emit_insert_mapper,
            use_supplemental_cols=True,
        )

        if (
            self.from_statement_ctx is not None
            and self.from_statement_ctx.compile_options._is_star
        ):
            raise sa_exc.CompileError(
                "Can't use RETURNING * with bulk ORM INSERT. "
                "Please use a different INSERT form, such as INSERT..VALUES "
                "or INSERT with a Core Connection"
            )

        self.statement = statement

1410 

1411 

@CompileState.plugin_for("orm", "update")
class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
    """Compile state for UPDATE statements that carry the "orm" plugin.

    Covers compile-time setup for the "bulk" and "orm" strategies,
    execution dispatch, the RETURNING capability check, and the
    post-execution synchronization of objects present in the session.

    """

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        """Create the compile state according to ``dml_strategy``.

        The strategy is read from the statement's ``dml_strategy``
        annotation, defaulting to "unspecified".

        """
        self = cls.__new__(cls)

        dml_strategy = statement._annotations.get(
            "dml_strategy", "unspecified"
        )

        toplevel = not compiler.stack

        if toplevel and dml_strategy == "bulk":
            self._setup_for_bulk_update(statement, compiler)
        elif (
            dml_strategy == "core_only"
            or dml_strategy == "unspecified"
            and "parententity" not in statement.table._annotations
        ):
            # plain Core UPDATE; no ORM-level processing
            UpdateDMLState.__init__(self, statement, compiler, **kw)
        elif not toplevel or dml_strategy in ("orm", "unspecified"):
            self._setup_for_orm_update(statement, compiler)

        return self

    def _setup_for_orm_update(self, statement, compiler, **kw):
        """Set up an ORM-enabled UPDATE against a mapped entity.

        Produces a new statement against the mapper's local table (when
        the annotated parent entity is the mapper itself), with extra
        criteria applied and, for the "fetch" strategy with RETURNING
        available, the primary key columns added via ``return_defaults()``.
        The result is stored in ``self.statement``.

        """
        orm_level_statement = statement

        toplevel = not compiler.stack

        ext_info = statement.table._annotations["parententity"]

        self.mapper = mapper = ext_info.mapper

        self._resolved_values = self._get_resolved_values(mapper, statement)

        self._init_global_attributes(
            statement,
            compiler,
            toplevel=toplevel,
            process_criteria_for_toplevel=toplevel,
        )

        if statement._values:
            self._resolved_values = dict(self._resolved_values)

        new_stmt = statement._clone()

        if new_stmt.table._annotations["parententity"] is mapper:
            new_stmt.table = mapper.local_table

        # note if the statement has _multi_values, these
        # are passed through to the new statement, which will then raise
        # InvalidRequestError because UPDATE doesn't support multi_values
        # right now.
        if statement._ordered_values:
            new_stmt._ordered_values = self._resolved_values
        elif statement._values:
            new_stmt._values = self._resolved_values

        new_crit = self._adjust_for_extra_criteria(
            self.global_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # if we are against a lambda statement we might not be the
        # topmost object that received per-execute annotations

        # do this first as we need to determine if there is
        # UPDATE..FROM

        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)

        use_supplemental_cols = False

        if not toplevel:
            synchronize_session = None
        else:
            synchronize_session = compiler._annotations.get(
                "synchronize_session", None
            )
        can_use_returning = compiler._annotations.get(
            "can_use_returning", None
        )
        if can_use_returning is not False:
            # even though pre_exec has determined basic
            # can_use_returning for the dialect, if we are to use
            # RETURNING we need to run can_use_returning() at this level
            # unconditionally, since the multi-table (UPDATE..FROM) status
            # in self.is_multitable is only established by
            # UpdateDMLState.__init__() above
            can_use_returning = (
                synchronize_session == "fetch"
                and self.can_use_returning(
                    compiler.dialect, mapper, is_multitable=self.is_multitable
                )
            )

        if synchronize_session == "fetch" and can_use_returning:
            use_supplemental_cols = True

            # NOTE: we might want to RETURNING the actual columns to be
            # synchronized also.  however this is complicated and difficult
            # to align against the behavior of "evaluate".  Additionally,
            # in a large number (if not the majority) of cases, we have the
            # "evaluate" answer, usually a fixed value, in memory already and
            # there's no need to re-fetch the same value
            # over and over again.   so perhaps if it could be RETURNING just
            # the elements that were based on a SQL expression and not
            # a constant.   For now it doesn't quite seem worth it
            new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)

        if toplevel:
            new_stmt = self._setup_orm_returning(
                compiler,
                orm_level_statement,
                new_stmt,
                dml_mapper=mapper,
                use_supplemental_cols=use_supplemental_cols,
            )

        self.statement = new_stmt

    def _setup_for_bulk_update(self, statement, compiler, **kw):
        """establish an UPDATE statement within the context of
        bulk update.

        This method will be within the "conn.execute()" call that is invoked
        by persistence._emit_update_statement().

        Raises ``InvalidRequestError`` when the statement uses ordered
        values.

        """
        statement = cast(dml.Update, statement)
        an = statement._annotations

        # both annotations are looked up (so a missing one raises
        # KeyError); only the table is used here
        emit_update_table, _ = (
            an["_emit_update_table"],
            an["_emit_update_mapper"],
        )

        statement = statement._clone()
        statement.table = emit_update_table

        UpdateDMLState.__init__(self, statement, compiler, **kw)

        if self._ordered_values:
            raise sa_exc.InvalidRequestError(
                "bulk ORM UPDATE does not support ordered_values() for "
                "custom UPDATE statements with bulk parameter sets.  Use a "
                "non-bulk UPDATE statement or use values()."
            )

        if self._dict_parameters:
            # keep only the parameters that target the table being emitted
            self._dict_parameters = {
                col: val
                for col, val in self._dict_parameters.items()
                if col.table is emit_update_table
            }
        self.statement = statement

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Update,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        """Execute the UPDATE according to the resolved DML strategy.

        The "bulk" strategy runs ``_bulk_update()`` and passes the result
        through ``orm_setup_cursor_result()``; every other valid strategy
        is delegated to the superclass implementation.

        Raises ``ArgumentError`` for an unknown strategy and
        ``InvalidRequestError`` when a bulk UPDATE combines additional
        WHERE criteria with the "evaluate" synchronize strategy.

        """

        update_options = execution_options.get(
            "_sa_orm_update_options", cls.default_update_options
        )

        if update_options._populate_existing:
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )
            load_options += {"_populate_existing": True}
            execution_options = execution_options.union(
                {"_sa_orm_load_options": load_options}
            )

        if update_options._dml_strategy not in (
            "orm",
            "auto",
            "bulk",
            "core_only",
        ):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM UPDATE strategy "
                "are 'orm', 'auto', 'bulk', 'core_only'"
            )

        result: _result.Result[Any]

        if update_options._dml_strategy == "bulk":
            # rowcount checking is only enabled when the statement has no
            # WHERE criteria of its own
            enable_check_rowcount = not statement._where_criteria

            assert update_options._synchronize_session != "fetch"

            if (
                statement._where_criteria
                and update_options._synchronize_session == "evaluate"
            ):
                raise sa_exc.InvalidRequestError(
                    "bulk synchronize of persistent objects not supported "
                    "when using bulk update with additional WHERE "
                    "criteria right now.  add synchronize_session=None "
                    "execution option to bypass synchronize of persistent "
                    "objects."
                )
            mapper = update_options._subject_mapper
            assert mapper is not None
            assert session._transaction is not None
            result = _bulk_update(
                mapper,
                cast(
                    "Iterable[Dict[str, Any]]",
                    # a single parameter dictionary is treated as a
                    # one-element list of parameter sets
                    [params] if isinstance(params, dict) else params,
                ),
                session._transaction,
                isstates=False,
                update_changed_only=False,
                use_orm_update_stmt=statement,
                enable_check_rowcount=enable_check_rowcount,
            )
            return cls.orm_setup_cursor_result(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                result,
            )
        else:
            return super().orm_execute_statement(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                conn,
            )

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        """Return whether RETURNING may be used for this UPDATE.

        ``False`` is returned when the dialect does not report
        ``update_returning`` or the mapper's local table does not have
        ``implicit_returning``.  Otherwise the answer comes from
        ``dialect.update_executemany_returning`` for executemany, from
        ``dialect.update_returning_multifrom`` when ``is_update_from`` is
        set, and is ``True`` in the remaining case.

        Raises ``CompileError`` when the statement is multi-table, the
        ``is_update_from`` hint was not given and the dialect does not
        report ``update_returning_multifrom``.

        """
        # normal answer for "should we use RETURNING" at all.
        normal_answer = (
            dialect.update_returning and mapper.local_table.implicit_returning
        )
        if not normal_answer:
            return False

        if is_executemany:
            return dialect.update_executemany_returning

        # these workarounds are currently hypothetical for UPDATE,
        # unlike DELETE where they impact MariaDB
        if is_update_from:
            return dialect.update_returning_multifrom

        elif is_multitable and not dialect.update_returning_multifrom:
            raise sa_exc.CompileError(
                f'Dialect "{dialect.name}" does not support RETURNING '
                "with UPDATE..FROM; for synchronize_session='fetch', "
                "please add the additional execution option "
                "'is_update_from=True' to the statement to indicate that "
                "a separate SELECT should be used for this backend."
            )

        return True

    @classmethod
    def _do_post_synchronize_bulk_evaluate(
        cls, session, params, result, update_options
    ):
        """Apply bulk UPDATE parameter values to objects in the session.

        For each parameter set, the identity key is built from the primary
        key values in the set; when a matching state is present in the
        identity map, its unmodified, currently loaded attributes receive
        the new values and attributes with pending changes are expired.

        """
        if not params:
            return

        if isinstance(params, dict):
            # a single parameter dictionary is one parameter set, as in
            # orm_execute_statement(); iterating it directly would yield
            # its keys rather than parameter sets
            params = [params]

        mapper = update_options._subject_mapper
        pk_keys = [prop.key for prop in mapper._identity_key_props]

        identity_map = session.identity_map

        for param in params:
            identity_key = mapper.identity_key_from_primary_key(
                (param[key] for key in pk_keys),
                update_options._identity_token,
            )
            state = identity_map.fast_get_state(identity_key)
            if not state:
                continue

            evaluated_keys = set(param).difference(pk_keys)

            dict_ = state.dict
            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = param[key]

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            to_expire = evaluated_keys.intersection(dict_).difference(
                to_evaluate
            )
            if to_expire:
                state._expire_attributes(dict_, to_expire)

    @classmethod
    def _do_post_synchronize_evaluate(
        cls, session, statement, result, update_options
    ):
        """Synchronize session objects matched by in-Python evaluation."""
        matched_objects = cls._get_matched_objects_on_criteria(
            update_options,
            session.identity_map.all_states(),
        )

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            result.context.compiled_parameters[0],
            # the fourth element of each matched entry is not needed here
            [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
            result.prefetch_cols(),
            result.postfetch_cols(),
        )

    @classmethod
    def _do_post_synchronize_fetch(
        cls, session, statement, result, update_options
    ):
        """Synchronize session objects identified by fetched primary keys.

        Rows come from the result's ``returned_defaults_rows`` when
        present, otherwise from the ``_matched_rows`` collected before
        execution.  Only identities present in the session's identity map
        are synchronized.

        """
        target_mapper = update_options._subject_mapper

        returned_defaults_rows = result.returned_defaults_rows
        if returned_defaults_rows:
            pk_rows = cls._interpret_returning_rows(
                result, target_mapper, returned_defaults_rows
            )
            # append the identity token so these rows have the same layout
            # as the pre-fetched rows (primary key columns, then token)
            matched_rows = [
                tuple(row) + (update_options._identity_token,)
                for row in pk_rows
            ]
        else:
            matched_rows = update_options._matched_rows

        objs = [
            session.identity_map[identity_key]
            for identity_key in [
                target_mapper.identity_key_from_primary_key(
                    list(primary_key),
                    identity_token=identity_token,
                )
                for primary_key, identity_token in [
                    (row[0:-1], row[-1]) for row in matched_rows
                ]
                if update_options._identity_token is None
                or identity_token == update_options._identity_token
            ]
            if identity_key in session.identity_map
        ]

        if not objs:
            return

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            result.context.compiled_parameters[0],
            [
                (
                    obj,
                    attributes.instance_state(obj),
                    attributes.instance_dict(obj),
                )
                for obj in objs
            ],
            result.prefetch_cols(),
            result.postfetch_cols(),
        )

    @classmethod
    def _apply_update_set_values_to_objects(
        cls,
        session,
        update_options,
        statement,
        effective_params,
        matched_objects,
        prefetch_cols,
        postfetch_cols,
    ):
        """apply values to objects derived from an update statement, e.g.
        UPDATE..SET <values>

        ``matched_objects`` is an iterable of ``(obj, state, dict_)``
        tuples.  SET values that the evaluator can compile are computed in
        Python for unmodified, loaded attributes; values for prefetch
        columns are copied from ``effective_params``; attributes for
        postfetch columns are expired on every matched object.

        """

        mapper = update_options._subject_mapper
        target_cls = mapper.class_
        evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
        resolved_values = cls._get_resolved_values(mapper, statement)
        resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
            mapper, resolved_values
        )
        value_evaluators = {}
        for key, value in resolved_keys_as_propnames:
            try:
                _evaluator = evaluator_compiler.process(
                    coercions.expect(roles.ExpressionElementRole, value)
                )
            except evaluator.UnevaluatableError:
                # values that can't be evaluated in Python are left out;
                # loaded attributes for them are expired per object below
                pass
            else:
                value_evaluators[key] = _evaluator

        evaluated_keys = set(value_evaluators)
        attrib = {k for k, v in resolved_keys_as_propnames}

        states = set()

        to_prefetch = {
            c
            for c in prefetch_cols
            if c.key in effective_params
            and c in mapper._columntoproperty
            and c.key not in evaluated_keys
        }
        # loop-invariant: the same postfetch attributes are expired on
        # every matched object, so this name must not be rebound inside
        # the loop below
        postfetch_to_expire = {
            mapper._columntoproperty[c].key
            for c in postfetch_cols
            if c in mapper._columntoproperty
        }.difference(evaluated_keys)

        prefetch_transfer = [
            (mapper._columntoproperty[c].key, c.key) for c in to_prefetch
        ]

        for obj, state, dict_ in matched_objects:

            dict_.update(
                {
                    col_to_prop: effective_params[c_key]
                    for col_to_prop, c_key in prefetch_transfer
                }
            )

            state._expire_attributes(state.dict, postfetch_to_expire)

            to_evaluate = state.unmodified.intersection(evaluated_keys)

            for key in to_evaluate:
                if key in dict_:
                    # only run eval for attributes that are present.
                    dict_[key] = value_evaluators[key](obj)

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            modified_to_expire = attrib.intersection(dict_).difference(
                to_evaluate
            )
            if modified_to_expire:
                state._expire_attributes(dict_, modified_to_expire)

            states.add(state)
        session._register_altered(states)

1897 

1898 

1899@CompileState.plugin_for("orm", "delete") 

1900class BulkORMDelete(BulkUDCompileState, DeleteDMLState): 

@classmethod
def create_for_statement(cls, statement, compiler, **kw):
    """Create the compile state for an ORM-enabled DELETE.

    A "core_only" statement, or one with an unspecified strategy whose
    table carries no "parententity" annotation, is initialized as a plain
    Core DELETE.  Otherwise a new statement is built against the mapper's
    local table (when the annotated parent entity is the mapper itself),
    with extra criteria applied and, when RETURNING can be used for the
    "fetch" strategy, the primary key columns added via
    ``return_defaults()``.

    """
    self = cls.__new__(cls)

    dml_strategy = statement._annotations.get("dml_strategy", "unspecified")

    is_plain_core = dml_strategy == "core_only" or (
        dml_strategy == "unspecified"
        and "parententity" not in statement.table._annotations
    )
    if is_plain_core:
        DeleteDMLState.__init__(self, statement, compiler, **kw)
        return self

    toplevel = not compiler.stack

    orm_level_statement = statement

    parent_entity = statement.table._annotations["parententity"]
    self.mapper = mapper = parent_entity.mapper

    self._init_global_attributes(
        statement,
        compiler,
        toplevel=toplevel,
        process_criteria_for_toplevel=toplevel,
    )

    delete_stmt = statement._clone()

    if delete_stmt.table._annotations["parententity"] is mapper:
        delete_stmt.table = mapper.local_table

    extra_criteria = cls._adjust_for_extra_criteria(
        self.global_attributes, mapper
    )
    if extra_criteria:
        delete_stmt = delete_stmt.where(*extra_criteria)

    # the DML state has to be initialized at this point, since it is what
    # determines whether this is a DELETE..FROM
    DeleteDMLState.__init__(self, delete_stmt, compiler, **kw)

    # per-execute annotations are only consulted for the topmost statement
    synchronize_session = (
        compiler._annotations.get("synchronize_session", None)
        if toplevel
        else None
    )

    can_use_returning = compiler._annotations.get("can_use_returning", None)
    if can_use_returning is not False:
        # pre_exec already made the basic per-dialect decision, however
        # if RETURNING is to be used, can_use_returning() is run again at
        # this level unconditionally because is_delete_using was not known
        # at the pre_exec level
        can_use_returning = (
            synchronize_session == "fetch"
            and self.can_use_returning(
                compiler.dialect,
                mapper,
                is_multitable=self.is_multitable,
                is_delete_using=compiler._annotations.get(
                    "is_delete_using", False
                ),
            )
        )

    use_supplemental_cols = False
    if can_use_returning:
        use_supplemental_cols = True
        delete_stmt = delete_stmt.return_defaults(
            *delete_stmt.table.primary_key
        )

    if toplevel:
        delete_stmt = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            delete_stmt,
            dml_mapper=mapper,
            use_supplemental_cols=use_supplemental_cols,
        )

    self.statement = delete_stmt

    return self

1992 

@classmethod
def orm_execute_statement(
    cls,
    session: Session,
    statement: dml.Delete,
    params: _CoreAnyExecuteParams,
    execution_options: OrmExecuteOptionsParameter,
    bind_arguments: _BindArguments,
    conn: Connection,
) -> _result.Result:
    """Validate the DML strategy, then delegate to the superclass.

    Raises ``InvalidRequestError`` for the "bulk" strategy, which is not
    supported for DELETE, and ``ArgumentError`` for any strategy other
    than "orm", "auto" or "core_only".

    """
    update_options = execution_options.get(
        "_sa_orm_update_options", cls.default_update_options
    )
    strategy = update_options._dml_strategy

    if strategy == "bulk":
        raise sa_exc.InvalidRequestError(
            "Bulk ORM DELETE not supported right now. "
            "Statement may be invoked at the "
            "Core level using "
            "session.connection().execute(stmt, parameters)"
        )

    if strategy not in ("orm", "auto", "core_only"):
        raise sa_exc.ArgumentError(
            "Valid strategies for ORM DELETE strategy are 'orm', 'auto', "
            "'core_only'"
        )

    return super().orm_execute_statement(
        session, statement, params, execution_options, bind_arguments, conn
    )

2024 

@classmethod
def can_use_returning(
    cls,
    dialect: Dialect,
    mapper: Mapper[Any],
    *,
    is_multitable: bool = False,
    is_update_from: bool = False,
    is_delete_using: bool = False,
    is_executemany: bool = False,
) -> bool:
    """Decide whether RETURNING may be used for this ORM DELETE.

    :param dialect: consulted for ``delete_returning`` and
     ``delete_returning_multifrom``.
    :param mapper: its ``local_table.implicit_returning`` must be truthy
     for RETURNING to be considered at all.
    :param is_multitable: the statement was found to be multi-table.
    :param is_delete_using: the caller hinted that this is DELETE..USING.
    :raises CompileError: if the statement is multi-table, no
     ``is_delete_using`` hint was given, and the dialect lacks
     ``delete_returning_multifrom``.

    ``is_update_from`` and ``is_executemany`` are accepted but not
    consulted here.
    """
    # baseline: both the dialect and the mapped table must allow RETURNING
    returning_available = (
        dialect.delete_returning and mapper.local_table.implicit_returning
    )
    if not returning_available:
        return False

    # MariaDB supports DELETE...RETURNING but not
    # DELETE...USING...RETURNING, hence the special handling below.
    if is_delete_using:
        # the hint was given up front; answer with the extra dialect
        # feature flag (True for PG, False for MariaDB)
        return dialect.delete_returning_multifrom

    if not is_multitable or dialect.delete_returning_multifrom:
        return True

    # no hint was passed, yet compilation found this to be a
    # DELETE..USING on a backend that can't combine it with RETURNING.
    # no pre-SELECT took place, so it is too late to recover; ask for
    # the hint to be supplied up front.
    raise sa_exc.CompileError(
        f'Dialect "{dialect.name}" does not support RETURNING '
        "with DELETE..USING; for synchronize_session='fetch', "
        "please add the additional execution option "
        "'is_delete_using=True' to the statement to indicate that "
        "a separate SELECT should be used for this backend."
    )

2065 

@classmethod
def _do_post_synchronize_evaluate(
    cls, session, statement, result, update_options
):
    """Synchronize the session after a DELETE, "evaluate" strategy.

    States from the session's identity map that match the criteria are
    either expired (when reported as partially expired) or collected
    and handed to ``session._remove_newly_deleted()`` in one call.
    """
    matches = cls._get_matched_objects_on_criteria(
        update_options,
        session.identity_map.all_states(),
    )

    deleted_states = []
    for _, state, dict_, is_partially_expired in matches:
        if not is_partially_expired:
            deleted_states.append(state)
            continue
        state._expire(dict_, session.identity_map._modified)

    # skip the session call entirely when nothing was collected
    if deleted_states:
        session._remove_newly_deleted(deleted_states)

2085 

@classmethod
def _do_post_synchronize_fetch(
    cls, session, statement, result, update_options
):
    """Synchronize the session after a DELETE, "fetch" strategy.

    Primary key rows come from the statement's RETURNING rows when the
    result has any, otherwise from ``update_options._matched_rows``.
    Every row whose identity key is present in the session's identity
    map has its state passed to ``session._remove_newly_deleted()``.
    """
    mapper = update_options._subject_mapper
    returning_rows = result.returned_defaults_rows

    if not returning_rows:
        rows = update_options._matched_rows
    else:
        # append the identity token so both row sources share the
        # layout (pk_col_1, ..., pk_col_n, identity_token)
        rows = [
            (*pk_row, update_options._identity_token)
            for pk_row in cls._interpret_returning_rows(
                result, mapper, returning_rows
            )
        ]

    for row in rows:
        pk_values = row[0:-1]
        token = row[-1]

        # TODO: inline this and call remove_newly_deleted
        # once
        key = mapper.identity_key_from_primary_key(
            list(pk_values),
            identity_token=token,
        )
        if key not in session.identity_map:
            continue

        instance = session.identity_map[key]
        session._remove_newly_deleted([attributes.instance_state(instance)])