Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/sqlalchemy/orm/bulk_persistence.py: 22%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

724 statements  

1# orm/bulk_persistence.py 

2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7# mypy: ignore-errors 

8 

9 

10"""additional ORM persistence classes related to "bulk" operations, 

11specifically outside of the flush() process. 

12 

13""" 

14 

15from __future__ import annotations 

16 

17from typing import Any 

18from typing import cast 

19from typing import Dict 

20from typing import Iterable 

21from typing import Optional 

22from typing import overload 

23from typing import TYPE_CHECKING 

24from typing import TypeVar 

25from typing import Union 

26 

27from . import attributes 

28from . import context 

29from . import evaluator 

30from . import exc as orm_exc 

31from . import loading 

32from . import persistence 

33from .base import NO_VALUE 

34from .context import AbstractORMCompileState 

35from .context import FromStatement 

36from .context import ORMFromStatementCompileState 

37from .context import QueryContext 

38from .. import exc as sa_exc 

39from .. import util 

40from ..engine import Dialect 

41from ..engine import result as _result 

42from ..sql import coercions 

43from ..sql import dml 

44from ..sql import expression 

45from ..sql import roles 

46from ..sql import select 

47from ..sql import sqltypes 

48from ..sql.base import _entity_namespace_key 

49from ..sql.base import CompileState 

50from ..sql.base import Options 

51from ..sql.dml import DeleteDMLState 

52from ..sql.dml import InsertDMLState 

53from ..sql.dml import UpdateDMLState 

54from ..util import EMPTY_DICT 

55from ..util.typing import Literal 

56from ..util.typing import TupleAny 

57from ..util.typing import Unpack 

58 

59if TYPE_CHECKING: 

60 from ._typing import DMLStrategyArgument 

61 from ._typing import OrmExecuteOptionsParameter 

62 from ._typing import SynchronizeSessionArgument 

63 from .mapper import Mapper 

64 from .session import _BindArguments 

65 from .session import ORMExecuteState 

66 from .session import Session 

67 from .session import SessionTransaction 

68 from .state import InstanceState 

69 from ..engine import Connection 

70 from ..engine import cursor 

71 from ..engine.interfaces import _CoreAnyExecuteParams 

72 

# Type variable for the mapped class; links Mapper[_O] to InstanceState[_O]
# in the _bulk_insert() / _bulk_update() signatures below.
_O = TypeVar("_O", bound=object)

74 

75 

@overload
def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Literal[None] = ...,
    execution_options: Optional[OrmExecuteOptionsParameter] = ...,
) -> None: ...


# typing overload: when an ORM-level Insert is supplied, a CursorResult
# is returned.
@overload
def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Optional[dml.Insert] = ...,
    execution_options: Optional[OrmExecuteOptionsParameter] = ...,
) -> cursor.CursorResult[Any]: ...


def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Optional[dml.Insert] = None,
    execution_options: Optional[OrmExecuteOptionsParameter] = None,
) -> Optional[cursor.CursorResult[Any]]:
    """INSERT a batch of rows for ``mapper``, one statement per table.

    Used by the legacy bulk-save path (``use_orm_insert_stmt`` is None, in
    which case None is returned) and by an ORM-enabled ``insert()``
    executed with multiple parameter sets (the per-table results are
    combined and returned).

    :param mapper: mapper whose tables receive the rows; only tables present
        in ``mapper._pks_by_table`` are inserted into.
    :param mappings: InstanceState objects when ``isstates`` is true,
        otherwise dictionaries, which are copied before use.
    :param session_transaction: supplies the Connection for the base mapper.
    :param isstates: whether ``mappings`` contains InstanceState objects.
    :param return_defaults: forces "bookkeeping" mode; together with
        ``isstates`` it also assigns ``state.key`` after the inserts.
    :param render_nulls: passed to persistence._collect_insert_commands().
    :param use_orm_insert_stmt: the ORM-level Insert being executed, if any.
    :param execution_options: passed to persistence._emit_insert_statements().
    :returns: the (possibly horizontally spliced) result when
        ``use_orm_insert_stmt`` is given, otherwise None.
    :raises NotImplementedError: if the session has a connection_callable
        (per-instance sharding).
    """
    base_mapper = mapper.base_mapper

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_insert()"
        )

    if isstates:
        if return_defaults:
            # keep (state, dict) pairs so that state.key can be assigned
            # once the INSERTs have populated the identity values
            states = [(state, state.dict) for state in mappings]
            mappings = [dict_ for (state, dict_) in states]
        else:
            mappings = [state.dict for state in mappings]
    else:
        # copy so in-place changes (e.g. _expand_composites) don't reach
        # the caller's dictionaries
        mappings = [dict(m) for m in mappings]
        _expand_composites(mapper, mappings)

    connection = session_transaction.connection(base_mapper)

    return_result: Optional[cursor.CursorResult[Any]] = None

    # (table, mapper) pairs in base_mapper._sorted_tables order, limited
    # to tables that belong to this mapper
    mappers_to_run = [
        (table, mp)
        for table, mp in base_mapper._sorted_tables.items()
        if table in mapper._pks_by_table
    ]

    if return_defaults:
        # not used by new-style bulk inserts, only used for legacy
        bookkeeping = True
    elif len(mappers_to_run) > 1:
        # if we have more than one table, mapper to run where we will be
        # either horizontally splicing, or copying values between tables,
        # we need the "bookkeeping" / deterministic returning order
        bookkeeping = True
    else:
        bookkeeping = False

    for table, super_mapper in mappers_to_run:
        # find bindparams in the statement. For bulk, we don't really know if
        # a key in the params applies to a different table since we are
        # potentially inserting for multiple tables here; looking at the
        # bindparam() is a lot more direct. in most cases this will
        # use _generate_cache_key() which is memoized, although in practice
        # the ultimate statement that's executed is probably not the same
        # object so that memoization might not matter much.
        # NOTE(review): mappings[0] raises IndexError for an empty batch
        # when use_orm_insert_stmt is given; presumably callers never pass
        # an empty list here — confirm.
        extra_bp_names = (
            [
                b.key
                for b in use_orm_insert_stmt._get_embedded_bindparams()
                if b.key in mappings[0]
            ]
            if use_orm_insert_stmt is not None
            else ()
        )

        # normalize each collected record: no InstanceState, and always the
        # outer mapper / connection
        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                state,
                state_dict,
                params,
                mp,
                conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in persistence._collect_insert_commands(
                table,
                ((None, mapping, mapper, connection) for mapping in mappings),
                bulk=True,
                return_defaults=bookkeeping,
                render_nulls=render_nulls,
                include_bulk_keys=extra_bp_names,
            )
        )

        result = persistence._emit_insert_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=bookkeeping,
            use_orm_insert_stmt=use_orm_insert_stmt,
            execution_options=execution_options,
        )
        if use_orm_insert_stmt is not None:
            # the first table's result (or every result when there is no
            # RETURNING) becomes return_result; later row-returning results
            # are spliced onto it column-wise
            if not use_orm_insert_stmt._returning or return_result is None:
                return_result = result
            elif result.returns_rows:
                assert bookkeeping
                return_result = return_result.splice_horizontally(result)

    if return_defaults and isstates:
        # build identity keys from the now-populated instance dicts
        identity_cls = mapper._identity_class
        identity_props = [p.key for p in mapper._identity_key_props]
        for state, dict_ in states:
            state.key = (
                identity_cls,
                tuple([dict_[key] for key in identity_props]),
                None,
            )

    if use_orm_insert_stmt is not None:
        assert return_result is not None
        return return_result

232 

233 

@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Literal[None] = ...,
    enable_check_rowcount: bool = True,
) -> None: ...


# typing overload: with an ORM-level Update, a Result is returned.
@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = ...,
    enable_check_rowcount: bool = True,
) -> _result.Result[Unpack[TupleAny]]: ...


def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = None,
    enable_check_rowcount: bool = True,
) -> Optional[_result.Result[Unpack[TupleAny]]]:
    """Emit UPDATE statements for a batch of mappings or instance states.

    One UPDATE is emitted per table in ``base_mapper._sorted_tables`` for
    which ``mapper.isa(super_mapper)`` holds and the table is in
    ``mapper._pks_by_table``.

    :param mapper: mapper being updated.
    :param mappings: InstanceState objects when ``isstates`` is true,
        otherwise dictionaries, which are copied before use.
    :param session_transaction: supplies the Connection for the base mapper.
    :param isstates: whether ``mappings`` contains InstanceState objects.
    :param update_changed_only: with ``isstates``, send only attributes
        present in ``state.committed_state`` plus the primary-key / version
        id attributes.
    :param use_orm_update_stmt: the ORM-level Update being executed, if any.
    :param enable_check_rowcount: passed to
        persistence._emit_update_statements().
    :returns: ``_result.null_result()`` when ``use_orm_update_stmt`` is
        given, otherwise None.
    :raises NotImplementedError: if the session has a connection_callable.
    """
    base_mapper = mapper.base_mapper

    # attribute keys that are always kept in "changed only" dictionaries
    search_keys = mapper._primary_key_propkeys
    if mapper._version_id_prop:
        search_keys = {mapper._version_id_prop.key}.union(search_keys)

    # note: the ``mapper`` parameter is unused; kept as-is
    def _changed_dict(mapper, state):
        return {
            k: v
            for k, v in state.dict.items()
            if k in state.committed_state or k in search_keys
        }

    if isstates:
        if update_changed_only:
            mappings = [_changed_dict(mapper, state) for state in mappings]
        else:
            mappings = [state.dict for state in mappings]
    else:
        # copy so in-place changes (e.g. _expand_composites) don't reach
        # the caller's dictionaries
        mappings = [dict(m) for m in mappings]
        _expand_composites(mapper, mappings)

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_update()"
        )

    connection = session_transaction.connection(base_mapper)

    # find bindparams in the statement. see _bulk_insert for similar
    # notes for the insert case
    # NOTE(review): mappings[0] raises IndexError for an empty batch when
    # use_orm_update_stmt is given — presumably never empty here; confirm.
    extra_bp_names = (
        [
            b.key
            for b in use_orm_update_stmt._get_embedded_bindparams()
            if b.key in mappings[0]
        ]
        if use_orm_update_stmt is not None
        else ()
    )

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
            continue

        records = persistence._collect_update_commands(
            None,
            table,
            (
                (
                    None,
                    mapping,
                    mapper,
                    connection,
                    # current version id value; a mapping lacking the
                    # version key raises KeyError here
                    (
                        mapping[mapper._version_id_prop.key]
                        if mapper._version_id_prop
                        else None
                    ),
                )
                for mapping in mappings
            ),
            bulk=True,
            use_orm_update_stmt=use_orm_update_stmt,
            include_bulk_keys=extra_bp_names,
        )
        persistence._emit_update_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=False,
            use_orm_update_stmt=use_orm_update_stmt,
            enable_check_rowcount=enable_check_rowcount,
        )

    if use_orm_update_stmt is not None:
        return _result.null_result()

350 

351 

352def _expand_composites(mapper, mappings): 

353 composite_attrs = mapper.composites 

354 if not composite_attrs: 

355 return 

356 

357 composite_keys = set(composite_attrs.keys()) 

358 populators = { 

359 key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn() 

360 for key in composite_keys 

361 } 

362 for mapping in mappings: 

363 for key in composite_keys.intersection(mapping): 

364 populators[key](mapping) 

365 

366 

class ORMDMLState(AbstractORMCompileState):
    """Compile-state base shared by ORM-enabled INSERT, UPDATE and DELETE.

    Provides helpers that translate ORM attribute keys in SET / VALUES
    clauses into column/value pairs, describe the DML target entity and
    its RETURNING columns, and set up / consume ORM-level RETURNING.
    """

    is_dml_returning = True
    # compile state of the FromStatement built by _setup_orm_returning();
    # stays None when the ORM-level statement has no RETURNING
    from_statement_ctx: Optional[ORMFromStatementCompileState] = None

    @classmethod
    def _get_orm_crud_kv_pairs(
        cls, mapper, statement, kv_iterator, needs_to_be_cacheable
    ):
        """Yield (column, value) pairs for an ORM-enabled SET/VALUES clause.

        Each key is first coerced via ``roles.DMLColumnRole``; then:

        * a string key naming an attribute of ``mapper`` is expanded via
          that attribute's ``_bulk_update_tuples()`` and handed to the Core
          ``UpdateDMLState._get_crud_kv_pairs``; an unknown string key is
          passed through;
        * a column key annotated with an "entity_namespace" is resolved to
          its attribute and expanded the same way;
        * any other key is passed through.

        Passed-through values are coerced to ``ExpressionElementRole`` with
        a NullType only when ``needs_to_be_cacheable`` is true.
        """
        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs

        for k, v in kv_iterator:
            k = coercions.expect(roles.DMLColumnRole, k)

            if isinstance(k, str):
                desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
                if desc is NO_VALUE:
                    yield (
                        coercions.expect(roles.DMLColumnRole, k),
                        (
                            coercions.expect(
                                roles.ExpressionElementRole,
                                v,
                                type_=sqltypes.NullType(),
                                is_crud=True,
                            )
                            if needs_to_be_cacheable
                            else v
                        ),
                    )
                else:
                    yield from core_get_crud_kv_pairs(
                        statement,
                        desc._bulk_update_tuples(v),
                        needs_to_be_cacheable,
                    )
            elif "entity_namespace" in k._annotations:
                k_anno = k._annotations
                attr = _entity_namespace_key(
                    k_anno["entity_namespace"], k_anno["proxy_key"]
                )
                yield from core_get_crud_kv_pairs(
                    statement,
                    attr._bulk_update_tuples(v),
                    needs_to_be_cacheable,
                )
            else:
                yield (
                    k,
                    (
                        v
                        if not needs_to_be_cacheable
                        else coercions.expect(
                            roles.ExpressionElementRole,
                            v,
                            type_=sqltypes.NullType(),
                            is_crud=True,
                        )
                    ),
                )

    @classmethod
    def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
        """Translate every dict of a multi-VALUES list.

        Falls back to ``UpdateDMLState._get_multi_crud_kv_pairs`` when the
        statement's plugin subject has no mapper.
        """
        plugin_subject = statement._propagate_attrs["plugin_subject"]

        if not plugin_subject or not plugin_subject.mapper:
            return UpdateDMLState._get_multi_crud_kv_pairs(
                statement, kv_iterator
            )

        return [
            dict(
                cls._get_orm_crud_kv_pairs(
                    plugin_subject.mapper, statement, value_dict.items(), False
                )
            )
            for value_dict in kv_iterator
        ]

    @classmethod
    def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
        """Return a list of (column, value) pairs for a SET/VALUES clause.

        Falls back to ``UpdateDMLState._get_crud_kv_pairs`` when the
        statement's plugin subject has no mapper.
        """
        assert (
            needs_to_be_cacheable
        ), "no test coverage for needs_to_be_cacheable=False"

        plugin_subject = statement._propagate_attrs["plugin_subject"]

        if not plugin_subject or not plugin_subject.mapper:
            return UpdateDMLState._get_crud_kv_pairs(
                statement, kv_iterator, needs_to_be_cacheable
            )

        return list(
            cls._get_orm_crud_kv_pairs(
                plugin_subject.mapper,
                statement,
                kv_iterator,
                needs_to_be_cacheable,
            )
        )

    @classmethod
    def get_entity_description(cls, statement):
        """Describe the DML target entity.

        ``name`` is the alias name for an aliased class, otherwise the
        mapped class name; ``table`` is the mapper's local table.
        """
        ext_info = statement.table._annotations["parententity"]
        mapper = ext_info.mapper
        if ext_info.is_aliased_class:
            _label_name = ext_info.name
        else:
            _label_name = mapper.class_.__name__

        return {
            "name": _label_name,
            "type": mapper.class_,
            "expr": ext_info.entity,
            "entity": ext_info.entity,
            "table": mapper.local_table,
        }

    @classmethod
    def get_returning_column_descriptions(cls, statement):
        """Return one description dict per RETURNING column.

        Each dict has the keys ``name``, ``type``, ``expr``, ``aliased`` and
        ``entity``.  A column with no "parententity" annotation (for
        example a plain SQL expression in RETURNING) is reported with
        ``aliased=False`` and ``entity=None`` and its ``expr`` is the
        column itself, rather than raising AttributeError.
        """

        def _ent_for_col(c):
            return c._annotations.get("parententity", None)

        def _attr_for_col(c, ent):
            if ent is None:
                return c
            proxy_key = c._annotations.get("proxy_key", None)
            if not proxy_key:
                return c
            else:
                return getattr(ent.entity, proxy_key, c)

        return [
            {
                "name": c.key,
                "type": c.type,
                "expr": _attr_for_col(c, ent),
                # columns without an ORM entity are neither aliased nor
                # tied to an entity; guard against ent being None
                "aliased": ent.is_aliased_class if ent is not None else False,
                "entity": ent.entity if ent is not None else None,
            }
            for c, ent in [
                (c, _ent_for_col(c)) for c in statement._all_selected_columns
            ]
        ]

    def _setup_orm_returning(
        self,
        compiler,
        orm_level_statement,
        dml_level_statement,
        dml_mapper,
        *,
        use_supplemental_cols=True,
    ):
        """establish ORM column handlers for an INSERT, UPDATE, or DELETE
        which uses explicit returning().

        called within compilation level create_for_statement.

        The _return_orm_returning() method then receives the Result
        after the statement was executed, and applies ORM loading to the
        state that we first established here.

        """

        if orm_level_statement._returning:
            fs = FromStatement(
                orm_level_statement._returning,
                dml_level_statement,
                _adapt_on_names=False,
            )
            fs = fs.execution_options(**orm_level_statement._execution_options)
            fs = fs.options(*orm_level_statement._with_options)
            self.select_statement = fs
            self.from_statement_ctx = fsc = (
                ORMFromStatementCompileState.create_for_statement(fs, compiler)
            )
            fsc.setup_dml_returning_compile_state(dml_mapper)

            # work on a copy with the user-level RETURNING cleared; it is
            # re-established below from the compiled column list
            dml_level_statement = dml_level_statement._generate()
            dml_level_statement._returning = ()

            cols_to_return = [c for c in fsc.primary_columns if c is not None]

            # since we are splicing result sets together, make sure there
            # are columns of some kind returned in each result set
            if not cols_to_return:
                cols_to_return.extend(dml_mapper.primary_key)

            if use_supplemental_cols:
                dml_level_statement = dml_level_statement.return_defaults(
                    # this is a little weird looking, but by passing
                    # primary key as the main list of cols, this tells
                    # return_defaults to omit server-default cols (and
                    # actually all cols, due to some weird thing we should
                    # clean up in crud.py).
                    # Since we have cols_to_return, just return what we asked
                    # for (plus primary key, which ORM persistence needs since
                    # we likely set bookkeeping=True here, which is another
                    # whole thing...). We dont want to clutter the
                    # statement up with lots of other cols the user didn't
                    # ask for. see #9685
                    *dml_mapper.primary_key,
                    supplemental_cols=cols_to_return,
                )
            else:
                dml_level_statement = dml_level_statement.returning(
                    *cols_to_return
                )

        return dml_level_statement

    @classmethod
    def _return_orm_returning(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        """Apply ORM loading to a DML result that used ORM RETURNING.

        When the compile state has a FromStatement context that is not a
        "star" select, the rows are loaded via ``loading.instances()``
        using a QueryContext built from the execution options; otherwise
        ``result`` is returned unchanged.
        """
        execution_context = result.context
        compile_state = execution_context.compiled.compile_state

        if (
            compile_state.from_statement_ctx
            and not compile_state.from_statement_ctx.compile_options._is_star
        ):
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )

            querycontext = QueryContext(
                compile_state.from_statement_ctx,
                compile_state.select_statement,
                params,
                session,
                load_options,
                execution_options,
                bind_arguments,
            )
            return loading.instances(result, querycontext)
        else:
            return result

611 

612 

613class BulkUDCompileState(ORMDMLState): 

    class default_update_options(Options):
        # Defaults for ORM UPDATE/DELETE execution.  orm_pre_session_exec()
        # builds an instance from the "_sa_orm_update_options" execution
        # option plus public options, and refines it with "+=".

        # "auto" is resolved to "orm", "bulk" or "core_only" in
        # orm_pre_session_exec()
        _dml_strategy: DMLStrategyArgument = "auto"
        # validated against "auto", "evaluate", "fetch", False
        _synchronize_session: SynchronizeSessionArgument = "auto"
        # copied into the statement annotations as "can_use_returning";
        # presumably set by the "fetch" pre-sync step — confirm
        _can_use_returning: bool = False
        # read from the "is_delete_using" / "is_update_from" options
        _is_delete_using: bool = False
        _is_update_from: bool = False
        # when true, session._autoflush() runs before execution
        _autoflush: bool = True
        # mapper of the statement's plugin subject, when there is one
        _subject_mapper: Optional[Mapper[Any]] = None
        # NOTE(review): not used in the visible code; presumably filled in
        # by subclasses with the resolved SET values — confirm
        _resolved_values = EMPTY_DICT
        # Python predicate produced by the "evaluate" pre-sync step
        _eval_condition = None
        # NOTE(review): not used in the visible code; presumably rows
        # matched by the "fetch" pre-sync step — confirm
        _matched_rows = None
        # when set, only states with this identity_token are synchronized
        _identity_token = None

626 

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        """Return whether RETURNING can be used for this DML on ``dialect``.

        Abstract here and must be implemented by subclasses.  It is
        consulted by ``_do_pre_synchronize_fetch`` when deciding whether the
        pre-fetch SELECT can be skipped in favor of RETURNING.

        :raises NotImplementedError: always, on this base class.
        """
        raise NotImplementedError()

639 

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_pre_event,
    ):
        """Resolve ORM UPDATE/DELETE options before the statement executes.

        Merges ``execution_options`` and ``statement._execution_options``
        into a ``default_update_options`` instance, records the statement
        (and the subject mapper, when present) in ``bind_arguments``, and
        chooses the DML strategy:

        * "core_only" when the statement's table has no "parententity"
          annotation;
        * "orm" in place of "auto" when ``params`` is not a list;
        * "bulk" in place of "auto" when ``params`` is a list.

        The synchronize_session value is then validated.  When
        ``is_pre_event`` is false, the session is autoflushed (if enabled),
        the pre-synchronization step for the chosen strategy runs, and the
        resolved options are annotated onto the statement.

        :returns: ``(statement, execution_options)``, where the execution
            options carry the resolved options under
            "_sa_orm_update_options".
        :raises sa_exc.InvalidRequestError: "bulk" was requested without a
            list of parameters, or "fetch" was combined with "bulk".
        :raises sa_exc.ArgumentError: unknown synchronize_session value.
        """
        (
            update_options,
            execution_options,
        ) = BulkUDCompileState.default_update_options.from_execution_options(
            "_sa_orm_update_options",
            {
                "synchronize_session",
                "autoflush",
                "identity_token",
                "is_delete_using",
                "is_update_from",
                "dml_strategy",
            },
            execution_options,
            statement._execution_options,
        )
        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            # NOTE(review): this invariant check is skipped under python -O
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            if plugin_subject:
                bind_arguments["mapper"] = plugin_subject.mapper
                update_options += {"_subject_mapper": plugin_subject.mapper}

        if "parententity" not in statement.table._annotations:
            update_options += {"_dml_strategy": "core_only"}
        elif not isinstance(params, list):
            # a single parameter set (or none): ORM UPDATE/DELETE with WHERE
            if update_options._dml_strategy == "auto":
                update_options += {"_dml_strategy": "orm"}
            elif update_options._dml_strategy == "bulk":
                # NOTE(review): message says "insert" although this class
                # handles UPDATE/DELETE; left unchanged because callers or
                # tests may match on the text
                raise sa_exc.InvalidRequestError(
                    'Can\'t use "bulk" ORM insert strategy without '
                    "passing separate parameters"
                )
        else:
            # a list of parameter sets: bulk (executemany-style) strategy
            if update_options._dml_strategy == "auto":
                update_options += {"_dml_strategy": "bulk"}

        sync = update_options._synchronize_session
        if sync is not None:
            if sync not in ("auto", "evaluate", "fetch", False):
                raise sa_exc.ArgumentError(
                    "Valid strategies for session synchronization "
                    "are 'auto', 'evaluate', 'fetch', False"
                )
            if update_options._dml_strategy == "bulk" and sync == "fetch":
                raise sa_exc.InvalidRequestError(
                    "The 'fetch' synchronization strategy is not available "
                    "for 'bulk' ORM updates (i.e. multiple parameter sets)"
                )

        if not is_pre_event:
            if update_options._autoflush:
                session._autoflush()

            if update_options._dml_strategy == "orm":
                if update_options._synchronize_session == "auto":
                    update_options = cls._do_pre_synchronize_auto(
                        session,
                        statement,
                        params,
                        execution_options,
                        bind_arguments,
                        update_options,
                    )
                elif update_options._synchronize_session == "evaluate":
                    update_options = cls._do_pre_synchronize_evaluate(
                        session,
                        statement,
                        params,
                        execution_options,
                        bind_arguments,
                        update_options,
                    )
                elif update_options._synchronize_session == "fetch":
                    update_options = cls._do_pre_synchronize_fetch(
                        session,
                        statement,
                        params,
                        execution_options,
                        bind_arguments,
                        update_options,
                    )
            elif update_options._dml_strategy == "bulk":
                # bulk only supports "evaluate" (fetch was rejected above)
                if update_options._synchronize_session == "auto":
                    update_options += {"_synchronize_session": "evaluate"}

            # indicators from the "pre exec" step that are then
            # added to the DML statement, which will also be part of the cache
            # key. The compile level create_for_statement() method will then
            # consume these at compiler time.
            statement = statement._annotate(
                {
                    "synchronize_session": update_options._synchronize_session,
                    "is_delete_using": update_options._is_delete_using,
                    "is_update_from": update_options._is_update_from,
                    "dml_strategy": update_options._dml_strategy,
                    "can_use_returning": update_options._can_use_returning,
                }
            )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_update_options": update_options}
            ),
        )

759 

760 @classmethod 

761 def orm_setup_cursor_result( 

762 cls, 

763 session, 

764 statement, 

765 params, 

766 execution_options, 

767 bind_arguments, 

768 result, 

769 ): 

770 # this stage of the execution is called after the 

771 # do_orm_execute event hook. meaning for an extension like 

772 # horizontal sharding, this step happens *within* the horizontal 

773 # sharding event handler which calls session.execute() re-entrantly 

774 # and will occur for each backend individually. 

775 # the sharding extension then returns its own merged result from the 

776 # individual ones we return here. 

777 

778 update_options = execution_options["_sa_orm_update_options"] 

779 if update_options._dml_strategy == "orm": 

780 if update_options._synchronize_session == "evaluate": 

781 cls._do_post_synchronize_evaluate( 

782 session, statement, result, update_options 

783 ) 

784 elif update_options._synchronize_session == "fetch": 

785 cls._do_post_synchronize_fetch( 

786 session, statement, result, update_options 

787 ) 

788 elif update_options._dml_strategy == "bulk": 

789 if update_options._synchronize_session == "evaluate": 

790 cls._do_post_synchronize_bulk_evaluate( 

791 session, params, result, update_options 

792 ) 

793 return result 

794 

795 return cls._return_orm_returning( 

796 session, 

797 statement, 

798 params, 

799 execution_options, 

800 bind_arguments, 

801 result, 

802 ) 

803 

804 @classmethod 

805 def _adjust_for_extra_criteria(cls, global_attributes, ext_info): 

806 """Apply extra criteria filtering. 

807 

808 For all distinct single-table-inheritance mappers represented in the 

809 table being updated or deleted, produce additional WHERE criteria such 

810 that only the appropriate subtypes are selected from the total results. 

811 

812 Additionally, add WHERE criteria originating from LoaderCriteriaOptions 

813 collected from the statement. 

814 

815 """ 

816 

817 return_crit = () 

818 

819 adapter = ext_info._adapter if ext_info.is_aliased_class else None 

820 

821 if ( 

822 "additional_entity_criteria", 

823 ext_info.mapper, 

824 ) in global_attributes: 

825 return_crit += tuple( 

826 ae._resolve_where_criteria(ext_info) 

827 for ae in global_attributes[ 

828 ("additional_entity_criteria", ext_info.mapper) 

829 ] 

830 if ae.include_aliases or ae.entity is ext_info 

831 ) 

832 

833 if ext_info.mapper._single_table_criterion is not None: 

834 return_crit += (ext_info.mapper._single_table_criterion,) 

835 

836 if adapter: 

837 return_crit = tuple(adapter.traverse(crit) for crit in return_crit) 

838 

839 return return_crit 

840 

    @classmethod
    def _interpret_returning_rows(cls, mapper, rows):
        """translate from local inherited table columns to base mapper
        primary key columns.

        Joined inheritance mappers always establish the primary key in terms of
        the base table. When we UPDATE a sub-table, we can only get
        RETURNING for the sub-table's columns.

        Here, we create a lookup from the local sub table's primary key
        columns to the base table PK columns so that we can get identity
        key values from RETURNING that's against the joined inheritance
        sub-table.

        the complexity here is to support more than one level deep of
        inheritance, where we have to link columns to each other across
        the inheritance hierarchy.

        Returns a list of tuples ordered like ``mapper.base_mapper.primary_key``,
        or ``rows`` unchanged when the early return below applies.

        """

        # NOTE(review): this guard returns early exactly when mapper has its
        # own (joined-inheritance) table.  Past it, every mp in the loop
        # below appears to either be the root (break) or share its parent's
        # table (continue), so the column-translation body never seems to
        # run and local_pk_to_base_pk stays an identity map.  Confirm
        # whether the guard's condition is intended.
        if mapper.local_table is not mapper.base_mapper.local_table:
            return rows

        # this starts as a mapping of
        # local_pk_col: local_pk_col.
        # we will then iteratively rewrite the "value" of the dict with
        # each successive superclass column
        local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key}

        for mp in mapper.iterate_to_root():
            if mp.inherits is None:
                break
            elif mp.local_table is mp.inherits.local_table:
                continue

            t_to_e = dict(mp._table_to_equated[mp.inherits.local_table])
            col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]}
            # only existing keys are reassigned, so mutating the dict while
            # iterating its items() does not change its size
            for pk, super_ in local_pk_to_base_pk.items():
                local_pk_to_base_pk[pk] = col_to_col[super_]

        # base pk column -> position of the corresponding local pk column;
        # assumes each row starts with the local table's pk columns in
        # local_table.primary_key order — TODO confirm against callers
        lookup = {
            local_pk_to_base_pk[lpk]: idx
            for idx, lpk in enumerate(mapper.local_table.primary_key)
        }
        primary_key_convert = [
            lookup[bpk] for bpk in mapper.base_mapper.primary_key
        ]
        return [tuple(row[idx] for idx in primary_key_convert) for row in rows]

889 

890 @classmethod 

891 def _get_matched_objects_on_criteria(cls, update_options, states): 

892 mapper = update_options._subject_mapper 

893 eval_condition = update_options._eval_condition 

894 

895 raw_data = [ 

896 (state.obj(), state, state.dict) 

897 for state in states 

898 if state.mapper.isa(mapper) and not state.expired 

899 ] 

900 

901 identity_token = update_options._identity_token 

902 if identity_token is not None: 

903 raw_data = [ 

904 (obj, state, dict_) 

905 for obj, state, dict_ in raw_data 

906 if state.identity_token == identity_token 

907 ] 

908 

909 result = [] 

910 for obj, state, dict_ in raw_data: 

911 evaled_condition = eval_condition(obj) 

912 

913 # caution: don't use "in ()" or == here, _EXPIRE_OBJECT 

914 # evaluates as True for all comparisons 

915 if ( 

916 evaled_condition is True 

917 or evaled_condition is evaluator._EXPIRED_OBJECT 

918 ): 

919 result.append( 

920 ( 

921 obj, 

922 state, 

923 dict_, 

924 evaled_condition is evaluator._EXPIRED_OBJECT, 

925 ) 

926 ) 

927 return result 

928 

929 @classmethod 

930 def _eval_condition_from_statement(cls, update_options, statement): 

931 mapper = update_options._subject_mapper 

932 target_cls = mapper.class_ 

933 

934 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls) 

935 crit = () 

936 if statement._where_criteria: 

937 crit += statement._where_criteria 

938 

939 global_attributes = {} 

940 for opt in statement._with_options: 

941 if opt._is_criteria_option: 

942 opt.get_global_criteria(global_attributes) 

943 

944 if global_attributes: 

945 crit += cls._adjust_for_extra_criteria(global_attributes, mapper) 

946 

947 if crit: 

948 eval_condition = evaluator_compiler.process(*crit) 

949 else: 

950 # workaround for mypy https://github.com/python/mypy/issues/14027 

951 def _eval_condition(obj): 

952 return True 

953 

954 eval_condition = _eval_condition 

955 

956 return eval_condition 

957 

958 @classmethod 

959 def _do_pre_synchronize_auto( 

960 cls, 

961 session, 

962 statement, 

963 params, 

964 execution_options, 

965 bind_arguments, 

966 update_options, 

967 ): 

968 """setup auto sync strategy 

969 

970 

971 "auto" checks if we can use "evaluate" first, then falls back 

972 to "fetch" 

973 

974 evaluate is vastly more efficient for the common case 

975 where session is empty, only has a few objects, and the UPDATE 

976 statement can potentially match thousands/millions of rows. 

977 

978 OTOH more complex criteria that fails to work with "evaluate" 

979 we would hope usually correlates with fewer net rows. 

980 

981 """ 

982 

983 try: 

984 eval_condition = cls._eval_condition_from_statement( 

985 update_options, statement 

986 ) 

987 

988 except evaluator.UnevaluatableError: 

989 pass 

990 else: 

991 return update_options + { 

992 "_eval_condition": eval_condition, 

993 "_synchronize_session": "evaluate", 

994 } 

995 

996 update_options += {"_synchronize_session": "fetch"} 

997 return cls._do_pre_synchronize_fetch( 

998 session, 

999 statement, 

1000 params, 

1001 execution_options, 

1002 bind_arguments, 

1003 update_options, 

1004 ) 

1005 

1006 @classmethod 

1007 def _do_pre_synchronize_evaluate( 

1008 cls, 

1009 session, 

1010 statement, 

1011 params, 

1012 execution_options, 

1013 bind_arguments, 

1014 update_options, 

1015 ): 

1016 try: 

1017 eval_condition = cls._eval_condition_from_statement( 

1018 update_options, statement 

1019 ) 

1020 

1021 except evaluator.UnevaluatableError as err: 

1022 raise sa_exc.InvalidRequestError( 

1023 'Could not evaluate current criteria in Python: "%s". ' 

1024 "Specify 'fetch' or False for the " 

1025 "synchronize_session execution option." % err 

1026 ) from err 

1027 

1028 return update_options + { 

1029 "_eval_condition": eval_condition, 

1030 } 

1031 

1032 @classmethod 

1033 def _get_resolved_values(cls, mapper, statement): 

1034 if statement._multi_values: 

1035 return [] 

1036 elif statement._ordered_values: 

1037 return list(statement._ordered_values) 

1038 elif statement._values: 

1039 return list(statement._values.items()) 

1040 else: 

1041 return [] 

1042 

@classmethod
def _resolved_keys_as_propnames(cls, mapper, resolved_values):
    """Translate column-keyed (key, value) pairs into attribute-keyed pairs.

    Each key is looked up in ``mapper._columntoproperty``, and the mapped
    property's ``key`` replaces it.  Columns the mapper does not map
    (``orm_exc.UnmappedColumnError``) are silently dropped.

    :raises sa_exc.InvalidRequestError: when a key is not a
        ``ColumnElement``, or when ``mapper`` is falsy.
    """
    propname_pairs = []
    for col_key, value in resolved_values:
        # anything other than a column element (or any key at all when
        # there is no mapper) cannot be mapped back to an attribute
        if not (mapper and isinstance(col_key, expression.ColumnElement)):
            raise sa_exc.InvalidRequestError(
                "Attribute name not found, can't be "
                "synchronized back to objects: %r" % col_key
            )
        try:
            prop = mapper._columntoproperty[col_key]
        except orm_exc.UnmappedColumnError:
            continue
        propname_pairs.append((prop.key, value))
    return propname_pairs

1060 

@classmethod
def _do_pre_synchronize_fetch(
    cls,
    session,
    statement,
    params,
    execution_options,
    bind_arguments,
    update_options,
):
    """Prepare synchronize_session='fetch' by locating the matched rows.

    Builds a SELECT of the subject mapper's primary key columns plus its
    ``select_identity_token``.  The SELECT carries over the DML statement's
    options and WHERE criteria.  It is run through ``session.execute()``
    with a one-off ``do_orm_execute`` hook, ``skip_for_returning``.

    For each bind, that hook asks ``cls.can_use_returning()`` whether the
    DML statement itself can use RETURNING.  When it can, the hook returns
    ``_result.null_result()`` in place of running the SELECT.

    Returns ``update_options`` extended with:

    - ``_matched_rows``: the rows fetched from the SELECT, or from the null
      result when RETURNING will be used (presumably empty — confirm).
    - ``_can_use_returning``: the per-bind decision, or ``None`` if the hook
      was never invoked.

    :raises sa_exc.InvalidRequestError: when binds disagree about RETURNING
        support, or when multiple parameter sets are given and the first
        bind cannot use RETURNING with them.
    """
    mapper = update_options._subject_mapper

    select_stmt = (
        select(*(mapper.primary_key + (mapper.select_identity_token,)))
        .select_from(mapper)
        .options(*statement._with_options)
    )
    select_stmt._where_criteria = statement._where_criteria

    # conditionally run the SELECT statement for pre-fetch, testing the
    # "bind" for if we can use RETURNING or not using the do_orm_execute
    # event. If RETURNING is available, the do_orm_execute event
    # will cancel the SELECT from being actually run.
    #
    # The way this is organized may seem strange: why not just call
    # can_use_returning() before invoking the statement and get the
    # answer directly, instead of going through the whole execute phase
    # via an event? Answer: because we are integrating with extensions such
    # as the horizontal sharding extension that "multiplexes" an individual
    # statement run through multiple engines, and it uses
    # do_orm_execute() to do that.

    # None until the first bind is inspected; afterwards the RETURNING
    # decision that every subsequent bind must agree with.
    can_use_returning = None

    def skip_for_returning(orm_context: ORMExecuteState) -> Any:
        """Decide RETURNING support for this bind; cancel the SELECT if so."""
        bind = orm_context.session.get_bind(**orm_context.bind_arguments)
        nonlocal can_use_returning

        per_bind_result = cls.can_use_returning(
            bind.dialect,
            mapper,
            is_update_from=update_options._is_update_from,
            is_delete_using=update_options._is_delete_using,
            is_executemany=orm_context.is_executemany,
        )

        if can_use_returning is not None:
            # a previous bind already decided; all binds must agree
            if can_use_returning != per_bind_result:
                raise sa_exc.InvalidRequestError(
                    "For synchronize_session='fetch', can't mix multiple "
                    "backends where some support RETURNING and others "
                    "don't"
                )
        elif orm_context.is_executemany and not per_bind_result:
            raise sa_exc.InvalidRequestError(
                "For synchronize_session='fetch', can't use multiple "
                "parameter sets in ORM mode, which this backend does not "
                "support with RETURNING"
            )
        else:
            can_use_returning = per_bind_result

        # a non-None return value replaces executing the SELECT
        if per_bind_result:
            return _result.null_result()
        else:
            return None

    result = session.execute(
        select_stmt,
        params,
        execution_options=execution_options,
        bind_arguments=bind_arguments,
        _add_event=skip_for_returning,
    )
    matched_rows = result.fetchall()

    return update_options + {
        "_matched_rows": matched_rows,
        "_can_use_returning": can_use_returning,
    }

1141 

1142 

@CompileState.plugin_for("orm", "insert")
class BulkORMInsert(ORMDMLState, InsertDMLState):
    """Compile state and execution hooks for ORM-enabled INSERT statements.

    Registered through ``CompileState.plugin_for("orm", "insert")``.
    Supports the "raw", "orm" and "bulk" strategies.  "auto" is resolved
    to "orm" or "bulk" in ``orm_pre_session_exec()``.
    """

    class default_insert_options(Options):
        # strategy requested via the "dml_strategy" execution option
        _dml_strategy: DMLStrategyArgument = "auto"
        _render_nulls: bool = False
        _return_defaults: bool = False
        # set from the statement's plugin_subject in orm_pre_session_exec()
        _subject_mapper: Optional[Mapper[Any]] = None
        _autoflush: bool = True
        _populate_existing: bool = False

    select_statement: Optional[FromStatement] = None

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_pre_event,
    ):
        """Resolve ORM insert options before the statement is executed.

        Steps, in order:

        - Merge the ``dml_strategy``, ``autoflush``, ``populate_existing``
          and ``render_nulls`` execution options into
          ``default_insert_options``.
        - Record the statement, and the plugin subject's mapper if any, in
          ``bind_arguments`` (mutated in place).
        - Resolve ``dml_strategy="auto"``: to "orm" when no parameters are
          given, otherwise to "bulk".
        - For any strategy other than "raw", merge in the ORM load
          execution options.
        - Autoflush the session, unless this is the pre-event pass or
          autoflush is disabled.

        :return: a tuple of the statement annotated with the chosen
            ``dml_strategy``, and the execution options extended with
            ``_sa_orm_insert_options``.
        :raises sa_exc.InvalidRequestError: when the "bulk" strategy is
            requested without parameters.
        """
        (
            insert_options,
            execution_options,
        ) = BulkORMInsert.default_insert_options.from_execution_options(
            "_sa_orm_insert_options",
            {"dml_strategy", "autoflush", "populate_existing", "render_nulls"},
            execution_options,
            statement._execution_options,
        )
        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            if plugin_subject:
                bind_arguments["mapper"] = plugin_subject.mapper
                insert_options += {"_subject_mapper": plugin_subject.mapper}

        if not params:
            if insert_options._dml_strategy == "auto":
                insert_options += {"_dml_strategy": "orm"}
            elif insert_options._dml_strategy == "bulk":
                raise sa_exc.InvalidRequestError(
                    'Can\'t use "bulk" ORM insert strategy without '
                    "passing separate parameters"
                )
        else:
            if insert_options._dml_strategy == "auto":
                insert_options += {"_dml_strategy": "bulk"}

        if insert_options._dml_strategy != "raw":
            # for ORM object loading, like ORMContext, we have to disable
            # result set adapt_to_context, because we will be generating a
            # new statement with specific columns that's cached inside of
            # an ORMFromStatementCompileState, which we will re-use for
            # each result.
            if not execution_options:
                execution_options = context._orm_load_exec_options
            else:
                execution_options = execution_options.union(
                    context._orm_load_exec_options
                )

        if not is_pre_event and insert_options._autoflush:
            session._autoflush()

        statement = statement._annotate(
            {"dml_strategy": insert_options._dml_strategy}
        )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_insert_options": insert_options}
            ),
        )

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Insert,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        """Execute an ORM-enabled INSERT according to its resolved strategy.

        The strategies behave as follows:

        - "raw": executes on ``conn`` and returns that result directly.
        - "bulk": sends the parameter set(s) through ``_bulk_insert()``.
        - "orm": executes on ``conn``.

        After a "bulk" or "orm" execution, a statement without RETURNING
        returns its result unchanged.  Otherwise the result is passed to
        ``_return_orm_returning()``.  When ``populate_existing`` is
        requested, ``_populate_existing`` is set in the load options first.

        :raises sa_exc.ArgumentError: for an unknown strategy.
        :raises sa_exc.InvalidRequestError: for a bulk INSERT with a
            "post values" clause (typically an upsert) against a mapper
            spanning multiple persistence tables.
        """
        insert_options = execution_options.get(
            "_sa_orm_insert_options", cls.default_insert_options
        )

        if insert_options._dml_strategy not in (
            "raw",
            "bulk",
            "orm",
            "auto",
        ):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM insert strategy "
                "are 'raw', 'orm', 'bulk', 'auto'"
            )

        result: _result.Result[Unpack[TupleAny]]

        if insert_options._dml_strategy == "raw":
            result = conn.execute(
                statement, params or {}, execution_options=execution_options
            )
            return result

        if insert_options._dml_strategy == "bulk":
            mapper = insert_options._subject_mapper
            # check before the first attribute access on the mapper below
            assert mapper is not None

            if (
                statement._post_values_clause is not None
                and mapper._multiple_persistence_tables
            ):
                raise sa_exc.InvalidRequestError(
                    "bulk INSERT with a 'post values' clause "
                    "(typically upsert) not supported for multi-table "
                    f"mapper {mapper}"
                )

            assert session._transaction is not None
            result = _bulk_insert(
                mapper,
                cast(
                    "Iterable[Dict[str, Any]]",
                    # a single parameter dict is treated as one-element list
                    [params] if isinstance(params, dict) else params,
                ),
                session._transaction,
                isstates=False,
                return_defaults=insert_options._return_defaults,
                render_nulls=insert_options._render_nulls,
                use_orm_insert_stmt=statement,
                execution_options=execution_options,
            )
        elif insert_options._dml_strategy == "orm":
            result = conn.execute(
                statement, params or {}, execution_options=execution_options
            )
        else:
            raise AssertionError()

        if not bool(statement._returning):
            return result

        if insert_options._populate_existing:
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )
            load_options += {"_populate_existing": True}
            execution_options = execution_options.union(
                {"_sa_orm_load_options": load_options}
            )

        return cls._return_orm_returning(
            session,
            statement,
            params,
            execution_options,
            bind_arguments,
            result,
        )

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
        """Create the compile state for an ORM-enabled INSERT.

        ``compiler.stack`` is non-empty when this INSERT is nested; a nested
        INSERT gets only the base compile state.  At top level (or with no
        compiler), the statement's ``dml_strategy`` annotation is used:

        - "bulk" runs ``_setup_for_bulk_insert()``.
        - "orm" runs ``_setup_for_orm_insert()``.
        - Anything else, including the default "raw", leaves the base
          state as-is.
        """
        self = cast(
            BulkORMInsert,
            super().create_for_statement(statement, compiler, **kw),
        )

        if compiler is not None:
            toplevel = not compiler.stack
        else:
            toplevel = True
        if not toplevel:
            return self

        mapper = statement._propagate_attrs["plugin_subject"]
        dml_strategy = statement._annotations.get("dml_strategy", "raw")
        if dml_strategy == "bulk":
            self._setup_for_bulk_insert(compiler)
        elif dml_strategy == "orm":
            self._setup_for_orm_insert(compiler, mapper)

        return self

    @classmethod
    def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
        """Re-key ``resolved_value_dict`` by mapper column key.

        A key is replaced by the column's ``key`` when ``mapper.c.get(key)``
        finds a column.  Keys with no matching column pass through
        unchanged.
        """
        return {
            col.key if col is not None else k: v
            for col, k, v in (
                (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
            )
        }

    def _setup_for_orm_insert(self, compiler, mapper):
        """Apply ORM RETURNING handling to the INSERT for the "orm" strategy."""
        statement = orm_level_statement = cast(dml.Insert, self.statement)

        statement = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            statement,
            dml_mapper=mapper,
            use_supplemental_cols=False,
        )
        self.statement = statement

    def _setup_for_bulk_insert(self, compiler):
        """establish an INSERT statement within the context of
        bulk insert.

        This method will be within the "conn.execute()" call that is invoked
        by persistence._emit_insert_statement().

        It retargets a clone of the statement at the table named by the
        ``_emit_insert_table`` annotation.  Dict parameters belonging to
        other tables are dropped, and RETURNING is set up with supplemental
        columns against ``_emit_insert_mapper``.

        :raises sa_exc.CompileError: when RETURNING * is requested.
        """
        statement = orm_level_statement = cast(dml.Insert, self.statement)
        an = statement._annotations

        emit_insert_table, emit_insert_mapper = (
            an["_emit_insert_table"],
            an["_emit_insert_mapper"],
        )

        statement = statement._clone()

        statement.table = emit_insert_table
        if self._dict_parameters:
            # keep only parameters for columns of the table being inserted
            self._dict_parameters = {
                col: val
                for col, val in self._dict_parameters.items()
                if col.table is emit_insert_table
            }

        statement = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            statement,
            dml_mapper=emit_insert_mapper,
            use_supplemental_cols=True,
        )

        if (
            self.from_statement_ctx is not None
            and self.from_statement_ctx.compile_options._is_star
        ):
            raise sa_exc.CompileError(
                "Can't use RETURNING * with bulk ORM INSERT. "
                "Please use a different INSERT form, such as INSERT..VALUES "
                "or INSERT with a Core Connection"
            )

        self.statement = statement

1401 

1402 

@CompileState.plugin_for("orm", "update")
class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
    """Compile state and execution hooks for ORM-enabled UPDATE statements.

    Registered through ``CompileState.plugin_for("orm", "update")``.
    """

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        """Create the compile state for an ORM-enabled UPDATE.

        Dispatch uses the ``dml_strategy`` annotation (default
        "unspecified"):

        - A top-level "bulk" runs ``_setup_for_bulk_update()``.
        - "core_only", or "unspecified" against a table without a
          "parententity" annotation, initializes a plain
          ``UpdateDMLState``.
        - A nested statement, "orm", or any remaining "unspecified" runs
          ``_setup_for_orm_update()``.

        NOTE(review): any other strategy at top level leaves the state
        uninitialized; presumably unreachable — confirm.
        """
        self = cls.__new__(cls)

        dml_strategy = statement._annotations.get(
            "dml_strategy", "unspecified"
        )

        toplevel = not compiler.stack

        if toplevel and dml_strategy == "bulk":
            self._setup_for_bulk_update(statement, compiler)
        elif (
            dml_strategy == "core_only"
            or dml_strategy == "unspecified"
            and "parententity" not in statement.table._annotations
        ):
            UpdateDMLState.__init__(self, statement, compiler, **kw)
        elif not toplevel or dml_strategy in ("orm", "unspecified"):
            self._setup_for_orm_update(statement, compiler)

        return self

    def _setup_for_orm_update(self, statement, compiler, **kw):
        """Set up an ORM UPDATE: resolved SET values, extra criteria, RETURNING.

        The SET values are resolved onto a clone of the statement, and any
        extra criteria from ``_adjust_for_extra_criteria()`` are added.
        The Core ``UpdateDMLState`` is then initialized.

        At top level with synchronize_session='fetch' and usable RETURNING,
        the primary key is requested via ``return_defaults()``.  At top
        level, ORM RETURNING handling is applied as well.
        """
        orm_level_statement = statement

        toplevel = not compiler.stack

        ext_info = statement.table._annotations["parententity"]

        self.mapper = mapper = ext_info.mapper

        self._resolved_values = self._get_resolved_values(mapper, statement)

        self._init_global_attributes(
            statement,
            compiler,
            toplevel=toplevel,
            process_criteria_for_toplevel=toplevel,
        )

        if statement._values:
            self._resolved_values = dict(self._resolved_values)

        new_stmt = statement._clone()

        # note if the statement has _multi_values, these
        # are passed through to the new statement, which will then raise
        # InvalidRequestError because UPDATE doesn't support multi_values
        # right now.
        if statement._ordered_values:
            new_stmt._ordered_values = self._resolved_values
        elif statement._values:
            new_stmt._values = self._resolved_values

        new_crit = self._adjust_for_extra_criteria(
            self.global_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # if we are against a lambda statement we might not be the
        # topmost object that received per-execute annotations

        # do this first as we need to determine if there is
        # UPDATE..FROM

        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)

        use_supplemental_cols = False

        if not toplevel:
            synchronize_session = None
        else:
            synchronize_session = compiler._annotations.get(
                "synchronize_session", None
            )
        can_use_returning = compiler._annotations.get(
            "can_use_returning", None
        )
        if can_use_returning is not False:
            # even though pre_exec has determined basic
            # can_use_returning for the dialect, if we are to use
            # RETURNING we need to run can_use_returning() at this level
            # unconditionally, because whether this is a multi-table
            # UPDATE (UPDATE..FROM) is only known after
            # UpdateDMLState.__init__() above has run
            can_use_returning = (
                synchronize_session == "fetch"
                and self.can_use_returning(
                    compiler.dialect, mapper, is_multitable=self.is_multitable
                )
            )

        if synchronize_session == "fetch" and can_use_returning:
            use_supplemental_cols = True

            # NOTE: we might want to RETURNING the actual columns to be
            # synchronized also. however this is complicated and difficult
            # to align against the behavior of "evaluate". Additionally,
            # in a large number (if not the majority) of cases, we have the
            # "evaluate" answer, usually a fixed value, in memory already and
            # there's no need to re-fetch the same value
            # over and over again. so perhaps if it could be RETURNING just
            # the elements that were based on a SQL expression and not
            # a constant. For now it doesn't quite seem worth it
            new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)

        if toplevel:
            new_stmt = self._setup_orm_returning(
                compiler,
                orm_level_statement,
                new_stmt,
                dml_mapper=mapper,
                use_supplemental_cols=use_supplemental_cols,
            )

        self.statement = new_stmt

    def _setup_for_bulk_update(self, statement, compiler, **kw):
        """establish an UPDATE statement within the context of
        bulk update.

        This method will be within the "conn.execute()" call that is invoked
        by persistence._emit_update_statement().

        It retargets a clone of the statement at the table named by the
        ``_emit_update_table`` annotation.  Dict parameters belonging to
        other tables are dropped.

        :raises sa_exc.InvalidRequestError: when ``ordered_values()`` was
            used.
        """
        statement = cast(dml.Update, statement)
        an = statement._annotations

        emit_update_table, _ = (
            an["_emit_update_table"],
            an["_emit_update_mapper"],
        )

        statement = statement._clone()
        statement.table = emit_update_table

        UpdateDMLState.__init__(self, statement, compiler, **kw)

        if self._ordered_values:
            raise sa_exc.InvalidRequestError(
                "bulk ORM UPDATE does not support ordered_values() for "
                "custom UPDATE statements with bulk parameter sets. Use a "
                "non-bulk UPDATE statement or use values()."
            )

        if self._dict_parameters:
            # keep only parameters for columns of the table being updated
            self._dict_parameters = {
                col: val
                for col, val in self._dict_parameters.items()
                if col.table is emit_update_table
            }
        self.statement = statement

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Update,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        """Execute an ORM-enabled UPDATE according to its strategy.

        The "bulk" strategy sends the parameter set(s) through
        ``_bulk_update()``.  The rowcount check is enabled only when the
        statement has no WHERE criteria.  All other valid strategies defer
        to the superclass.

        :raises sa_exc.ArgumentError: for an unknown strategy.
        :raises sa_exc.InvalidRequestError: for a bulk UPDATE that has WHERE
            criteria while synchronize_session is 'evaluate'.
        """
        update_options = execution_options.get(
            "_sa_orm_update_options", cls.default_update_options
        )

        if update_options._dml_strategy not in (
            "orm",
            "auto",
            "bulk",
            "core_only",
        ):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM UPDATE strategy "
                "are 'orm', 'auto', 'bulk', 'core_only'"
            )

        result: _result.Result[Unpack[TupleAny]]

        if update_options._dml_strategy == "bulk":
            enable_check_rowcount = not statement._where_criteria

            assert update_options._synchronize_session != "fetch"

            if (
                statement._where_criteria
                and update_options._synchronize_session == "evaluate"
            ):
                raise sa_exc.InvalidRequestError(
                    "bulk synchronize of persistent objects not supported "
                    "when using bulk update with additional WHERE "
                    "criteria right now. add synchronize_session=None "
                    "execution option to bypass synchronize of persistent "
                    "objects."
                )
            mapper = update_options._subject_mapper
            assert mapper is not None
            assert session._transaction is not None
            result = _bulk_update(
                mapper,
                cast(
                    "Iterable[Dict[str, Any]]",
                    # a single parameter dict is treated as one-element list
                    [params] if isinstance(params, dict) else params,
                ),
                session._transaction,
                isstates=False,
                update_changed_only=False,
                use_orm_update_stmt=statement,
                enable_check_rowcount=enable_check_rowcount,
            )
            return cls.orm_setup_cursor_result(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                result,
            )
        else:
            return super().orm_execute_statement(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                conn,
            )

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        """Return whether UPDATE..RETURNING can be used for ``mapper``.

        Requires ``dialect.update_returning`` and the mapper's local table
        to allow implicit returning.  With ``is_executemany`` the answer is
        ``dialect.update_executemany_returning``.  With ``is_update_from``
        it is ``dialect.update_returning_multifrom``.  ``is_delete_using``
        is accepted but not used for UPDATE.

        :raises sa_exc.CompileError: when a multi-table UPDATE was detected
            at compile time without the ``is_update_from`` hint, on a
            dialect lacking multi-FROM RETURNING.
        """
        # normal answer for "should we use RETURNING" at all.
        normal_answer = (
            dialect.update_returning and mapper.local_table.implicit_returning
        )
        if not normal_answer:
            return False

        if is_executemany:
            return dialect.update_executemany_returning

        # these workarounds are currently hypothetical for UPDATE,
        # unlike DELETE where they impact MariaDB
        if is_update_from:
            return dialect.update_returning_multifrom

        elif is_multitable and not dialect.update_returning_multifrom:
            raise sa_exc.CompileError(
                f'Dialect "{dialect.name}" does not support RETURNING '
                "with UPDATE..FROM; for synchronize_session='fetch', "
                "please add the additional execution option "
                "'is_update_from=True' to the statement to indicate that "
                "a separate SELECT should be used for this backend."
            )

        return True

    @classmethod
    def _do_post_synchronize_bulk_evaluate(
        cls, session, params, result, update_options
    ):
        """Copy bulk UPDATE parameter values onto objects in the identity map.

        For each parameter dict, the object is looked up by primary key and
        the update's identity token.  Unmodified attributes that are present
        in the object's dict receive the new value and are committed.
        Attributes that had pending modifications are expired instead.
        Parameter dicts whose object is not in the identity map are skipped.
        """
        if not params:
            return

        mapper = update_options._subject_mapper
        pk_keys = [prop.key for prop in mapper._identity_key_props]

        identity_map = session.identity_map

        for param in params:
            identity_key = mapper.identity_key_from_primary_key(
                (param[key] for key in pk_keys),
                update_options._identity_token,
            )
            state = identity_map.fast_get_state(identity_key)
            if not state:
                continue

            evaluated_keys = set(param).difference(pk_keys)

            dict_ = state.dict
            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = param[key]

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            to_expire = evaluated_keys.intersection(dict_).difference(
                to_evaluate
            )
            if to_expire:
                state._expire_attributes(dict_, to_expire)

    @classmethod
    def _do_post_synchronize_evaluate(
        cls, session, statement, result, update_options
    ):
        """Apply SET values to identity-map objects matching the criteria.

        Matching is done in Python by ``_get_matched_objects_on_criteria()``
        over all states in the identity map.
        """
        matched_objects = cls._get_matched_objects_on_criteria(
            update_options,
            session.identity_map.all_states(),
        )

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
        )

    @classmethod
    def _do_post_synchronize_fetch(
        cls, session, statement, result, update_options
    ):
        """Apply SET values to the objects whose primary keys were matched.

        Primary keys come from the statement's RETURNING rows when present,
        otherwise from the pre-fetch SELECT (``_matched_rows``).  When
        ``update_options._identity_token`` is set, only rows with that
        token are used.  Objects not present in the identity map are
        skipped.
        """
        target_mapper = update_options._subject_mapper

        returned_defaults_rows = result.returned_defaults_rows
        if returned_defaults_rows:
            pk_rows = cls._interpret_returning_rows(
                target_mapper, returned_defaults_rows
            )

            # append the identity token so rows match the pre-fetch shape
            matched_rows = [
                tuple(row) + (update_options._identity_token,)
                for row in pk_rows
            ]
        else:
            matched_rows = update_options._matched_rows

        objs = [
            session.identity_map[identity_key]
            for identity_key in [
                target_mapper.identity_key_from_primary_key(
                    list(primary_key),
                    identity_token=identity_token,
                )
                for primary_key, identity_token in [
                    (row[0:-1], row[-1]) for row in matched_rows
                ]
                if update_options._identity_token is None
                or identity_token == update_options._identity_token
            ]
            if identity_key in session.identity_map
        ]

        if not objs:
            return

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            [
                (
                    obj,
                    attributes.instance_state(obj),
                    attributes.instance_dict(obj),
                )
                for obj in objs
            ],
        )

    @classmethod
    def _apply_update_set_values_to_objects(
        cls, session, update_options, statement, matched_objects
    ):
        """apply values to objects derived from an update statement, e.g.
        UPDATE..SET <values>

        SET values that ``evaluator._EvaluatorCompiler`` can compile are
        evaluated per object and committed for unmodified, loaded
        attributes.  Other SET attributes present in an object's dict are
        expired.  All touched states are then passed to
        ``session._register_altered()``.
        """
        mapper = update_options._subject_mapper
        target_cls = mapper.class_
        evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
        resolved_values = cls._get_resolved_values(mapper, statement)
        resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
            mapper, resolved_values
        )
        value_evaluators = {}
        for key, value in resolved_keys_as_propnames:
            try:
                _evaluator = evaluator_compiler.process(
                    coercions.expect(roles.ExpressionElementRole, value)
                )
            except evaluator.UnevaluatableError:
                # not evaluable in Python; handled by expiry below
                pass
            else:
                value_evaluators[key] = _evaluator

        evaluated_keys = list(value_evaluators.keys())
        attrib = {k for k, v in resolved_keys_as_propnames}

        states = set()
        for obj, state, dict_ in matched_objects:
            to_evaluate = state.unmodified.intersection(evaluated_keys)

            for key in to_evaluate:
                if key in dict_:
                    # only run eval for attributes that are present.
                    dict_[key] = value_evaluators[key](obj)

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            to_expire = attrib.intersection(dict_).difference(to_evaluate)
            if to_expire:
                state._expire_attributes(dict_, to_expire)

            states.add(state)
        session._register_altered(states)

1834 

1835 

@CompileState.plugin_for("orm", "delete")
class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
    """Compile state and execution hooks for ORM-enabled DELETE statements.

    Registered through ``CompileState.plugin_for("orm", "delete")``.  The
    "bulk" strategy is rejected for DELETE.
    """

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        """Create the compile state for an ORM-enabled DELETE.

        "core_only", or "unspecified" against a table without a
        "parententity" annotation, yields a plain ``DeleteDMLState``.

        Otherwise:

        - Extra criteria from ``_adjust_for_extra_criteria()`` are added to
          a clone of the statement.
        - The Core state is initialized.
        - At top level with synchronize_session='fetch' and usable
          RETURNING, the primary key is requested via
          ``return_defaults()``.
        - At top level, ORM RETURNING handling is applied.

        Statement order matters here: ``self.is_multitable`` is read only
        after ``DeleteDMLState.__init__()`` has run.
        """
        self = cls.__new__(cls)

        dml_strategy = statement._annotations.get(
            "dml_strategy", "unspecified"
        )

        if (
            dml_strategy == "core_only"
            or dml_strategy == "unspecified"
            and "parententity" not in statement.table._annotations
        ):
            DeleteDMLState.__init__(self, statement, compiler, **kw)
            return self

        toplevel = not compiler.stack

        orm_level_statement = statement

        ext_info = statement.table._annotations["parententity"]
        self.mapper = mapper = ext_info.mapper

        self._init_global_attributes(
            statement,
            compiler,
            toplevel=toplevel,
            process_criteria_for_toplevel=toplevel,
        )

        new_stmt = statement._clone()

        new_crit = cls._adjust_for_extra_criteria(
            self.global_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # do this first as we need to determine if there is
        # DELETE..FROM
        DeleteDMLState.__init__(self, new_stmt, compiler, **kw)

        use_supplemental_cols = False

        if not toplevel:
            synchronize_session = None
        else:
            synchronize_session = compiler._annotations.get(
                "synchronize_session", None
            )
        # an explicit False from the compiler annotations is kept as-is;
        # otherwise the decision is recomputed below
        can_use_returning = compiler._annotations.get(
            "can_use_returning", None
        )
        if can_use_returning is not False:
            # even though pre_exec has determined basic
            # can_use_returning for the dialect, if we are to use
            # RETURNING we need to run can_use_returning() at this level
            # unconditionally because is_delete_using was not known
            # at the pre_exec level
            can_use_returning = (
                synchronize_session == "fetch"
                and self.can_use_returning(
                    compiler.dialect,
                    mapper,
                    is_multitable=self.is_multitable,
                    is_delete_using=compiler._annotations.get(
                        "is_delete_using", False
                    ),
                )
            )

        if can_use_returning:
            use_supplemental_cols = True

            new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)

        if toplevel:
            new_stmt = self._setup_orm_returning(
                compiler,
                orm_level_statement,
                new_stmt,
                dml_mapper=mapper,
                use_supplemental_cols=use_supplemental_cols,
            )

        self.statement = new_stmt

        return self

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Delete,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        """Validate the DELETE strategy, then defer to the superclass.

        :raises sa_exc.InvalidRequestError: for the "bulk" strategy.
        :raises sa_exc.ArgumentError: for any strategy other than "orm",
            "auto" or "core_only".
        """
        update_options = execution_options.get(
            "_sa_orm_update_options", cls.default_update_options
        )

        if update_options._dml_strategy == "bulk":
            raise sa_exc.InvalidRequestError(
                "Bulk ORM DELETE not supported right now. "
                "Statement may be invoked at the "
                "Core level using "
                "session.connection().execute(stmt, parameters)"
            )

        if update_options._dml_strategy not in ("orm", "auto", "core_only"):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM DELETE strategy are 'orm', 'auto', "
                "'core_only'"
            )

        return super().orm_execute_statement(
            session, statement, params, execution_options, bind_arguments, conn
        )

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        """Return whether DELETE..RETURNING can be used for ``mapper``.

        Requires ``dialect.delete_returning`` and the mapper's local table
        to allow implicit returning.  With ``is_delete_using`` the answer is
        ``dialect.delete_returning_multifrom``.  ``is_update_from`` and
        ``is_executemany`` are accepted but not consulted here.

        :raises sa_exc.CompileError: when a DELETE..USING was detected at
            compile time without the ``is_delete_using`` hint, on a dialect
            lacking multi-FROM RETURNING.
        """
        # normal answer for "should we use RETURNING" at all.
        normal_answer = (
            dialect.delete_returning and mapper.local_table.implicit_returning
        )
        if not normal_answer:
            return False

        # now get into special workarounds because MariaDB supports
        # DELETE...RETURNING but not DELETE...USING...RETURNING.
        if is_delete_using:
            # is_delete_using hint was passed. use
            # additional dialect feature (True for PG, False for MariaDB)
            return dialect.delete_returning_multifrom

        elif is_multitable and not dialect.delete_returning_multifrom:
            # is_delete_using hint was not passed, but we determined
            # at compile time that this is in fact a DELETE..USING.
            # it's too late to continue since we did not pre-SELECT.
            # raise that we need that hint up front.

            raise sa_exc.CompileError(
                f'Dialect "{dialect.name}" does not support RETURNING '
                "with DELETE..USING; for synchronize_session='fetch', "
                "please add the additional execution option "
                "'is_delete_using=True' to the statement to indicate that "
                "a separate SELECT should be used for this backend."
            )

        return True

    @classmethod
    def _do_post_synchronize_evaluate(
        cls, session, statement, result, update_options
    ):
        """Reflect the DELETE on identity-map objects matched in Python.

        Matched objects flagged as partially expired are expired rather
        than marked deleted.  The rest are passed together to
        ``session._remove_newly_deleted()``.
        """
        matched_objects = cls._get_matched_objects_on_criteria(
            update_options,
            session.identity_map.all_states(),
        )

        to_delete = []

        for _, state, dict_, is_partially_expired in matched_objects:
            if is_partially_expired:
                state._expire(dict_, session.identity_map._modified)
            else:
                to_delete.append(state)

        if to_delete:
            session._remove_newly_deleted(to_delete)

    @classmethod
    def _do_post_synchronize_fetch(
        cls, session, statement, result, update_options
    ):
        """Mark objects whose primary keys were matched as deleted.

        Primary keys come from the statement's RETURNING rows when present,
        otherwise from the pre-fetch SELECT (``_matched_rows``).  Each
        matched object present in the identity map is passed to
        ``session._remove_newly_deleted()``.
        """
        target_mapper = update_options._subject_mapper

        returned_defaults_rows = result.returned_defaults_rows

        if returned_defaults_rows:
            pk_rows = cls._interpret_returning_rows(
                target_mapper, returned_defaults_rows
            )

            # append the identity token so rows match the pre-fetch shape
            matched_rows = [
                tuple(row) + (update_options._identity_token,)
                for row in pk_rows
            ]
        else:
            matched_rows = update_options._matched_rows

        for row in matched_rows:
            # each row is (*primary_key_values, identity_token)
            primary_key = row[0:-1]
            identity_token = row[-1]

            # TODO: inline this and call remove_newly_deleted
            # once
            identity_key = target_mapper.identity_key_from_primary_key(
                list(primary_key),
                identity_token=identity_token,
            )
            if identity_key in session.identity_map:
                session._remove_newly_deleted(
                    [
                        attributes.instance_state(
                            session.identity_map[identity_key]
                        )
                    ]
                )

2057 )