Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/sqlalchemy/orm/bulk_persistence.py: 22%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

734 statements  

1# orm/bulk_persistence.py 

2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7# mypy: ignore-errors 

8 

9 

10"""additional ORM persistence classes related to "bulk" operations, 

11specifically outside of the flush() process. 

12 

13""" 

14 

15from __future__ import annotations 

16 

17from typing import Any 

18from typing import cast 

19from typing import Dict 

20from typing import Iterable 

21from typing import Optional 

22from typing import overload 

23from typing import TYPE_CHECKING 

24from typing import TypeVar 

25from typing import Union 

26 

27from . import attributes 

28from . import context 

29from . import evaluator 

30from . import exc as orm_exc 

31from . import loading 

32from . import persistence 

33from .base import NO_VALUE 

34from .context import AbstractORMCompileState 

35from .context import FromStatement 

36from .context import ORMFromStatementCompileState 

37from .context import QueryContext 

38from .. import exc as sa_exc 

39from .. import util 

40from ..engine import Dialect 

41from ..engine import result as _result 

42from ..sql import coercions 

43from ..sql import dml 

44from ..sql import expression 

45from ..sql import roles 

46from ..sql import select 

47from ..sql import sqltypes 

48from ..sql.base import _entity_namespace_key 

49from ..sql.base import CompileState 

50from ..sql.base import Options 

51from ..sql.dml import DeleteDMLState 

52from ..sql.dml import InsertDMLState 

53from ..sql.dml import UpdateDMLState 

54from ..util import EMPTY_DICT 

55from ..util.typing import Literal 

56from ..util.typing import TupleAny 

57from ..util.typing import Unpack 

58 

59if TYPE_CHECKING: 

60 from ._typing import DMLStrategyArgument 

61 from ._typing import OrmExecuteOptionsParameter 

62 from ._typing import SynchronizeSessionArgument 

63 from .mapper import Mapper 

64 from .session import _BindArguments 

65 from .session import ORMExecuteState 

66 from .session import Session 

67 from .session import SessionTransaction 

68 from .state import InstanceState 

69 from ..engine import Connection 

70 from ..engine import cursor 

71 from ..engine.interfaces import _CoreAnyExecuteParams 

72 

73_O = TypeVar("_O", bound=object) 

74 

75 

76@overload 

77def _bulk_insert( 

78 mapper: Mapper[_O], 

79 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], 

80 session_transaction: SessionTransaction, 

81 *, 

82 isstates: bool, 

83 return_defaults: bool, 

84 render_nulls: bool, 

85 use_orm_insert_stmt: Literal[None] = ..., 

86 execution_options: Optional[OrmExecuteOptionsParameter] = ..., 

87) -> None: ... 

88 

89 

90@overload 

91def _bulk_insert( 

92 mapper: Mapper[_O], 

93 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], 

94 session_transaction: SessionTransaction, 

95 *, 

96 isstates: bool, 

97 return_defaults: bool, 

98 render_nulls: bool, 

99 use_orm_insert_stmt: Optional[dml.Insert] = ..., 

100 execution_options: Optional[OrmExecuteOptionsParameter] = ..., 

101) -> cursor.CursorResult[Any]: ... 

102 

103 

104def _bulk_insert( 

105 mapper: Mapper[_O], 

106 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], 

107 session_transaction: SessionTransaction, 

108 *, 

109 isstates: bool, 

110 return_defaults: bool, 

111 render_nulls: bool, 

112 use_orm_insert_stmt: Optional[dml.Insert] = None, 

113 execution_options: Optional[OrmExecuteOptionsParameter] = None, 

114) -> Optional[cursor.CursorResult[Any]]: 

115 base_mapper = mapper.base_mapper 

116 

117 if session_transaction.session.connection_callable: 

118 raise NotImplementedError( 

119 "connection_callable / per-instance sharding " 

120 "not supported in bulk_insert()" 

121 ) 

122 

123 if isstates: 

124 if TYPE_CHECKING: 

125 mappings = cast(Iterable[InstanceState[_O]], mappings) 

126 

127 if return_defaults: 

128 # list of states allows us to attach .key for return_defaults case 

129 states = [(state, state.dict) for state in mappings] 

130 mappings = [dict_ for (state, dict_) in states] 

131 else: 

132 mappings = [state.dict for state in mappings] 

133 else: 

134 if TYPE_CHECKING: 

135 mappings = cast(Iterable[Dict[str, Any]], mappings) 

136 

137 if return_defaults: 

138 # use dictionaries given, so that newly populated defaults 

139 # can be delivered back to the caller (see #11661). This is **not** 

140 # compatible with other use cases such as a session-executed 

141 # insert() construct, as this will confuse the case of 

142 # insert-per-subclass for joined inheritance cases (see 

143 # test_bulk_statements.py::BulkDMLReturningJoinedInhTest). 

144 # 

145 # So in this conditional, we have **only** called 

146 # session.bulk_insert_mappings() which does not have this 

147 # requirement 

148 mappings = list(mappings) 

149 else: 

150 # for all other cases we need to establish a local dictionary 

151 # so that the incoming dictionaries aren't mutated 

152 mappings = [dict(m) for m in mappings] 

153 _expand_composites(mapper, mappings) 

154 

155 connection = session_transaction.connection(base_mapper) 

156 

157 return_result: Optional[cursor.CursorResult[Any]] = None 

158 

159 mappers_to_run = [ 

160 (table, mp) 

161 for table, mp in base_mapper._sorted_tables.items() 

162 if table in mapper._pks_by_table 

163 ] 

164 

165 if return_defaults: 

166 # not used by new-style bulk inserts, only used for legacy 

167 bookkeeping = True 

168 elif len(mappers_to_run) > 1: 

169 # if we have more than one table, mapper to run where we will be 

170 # either horizontally splicing, or copying values between tables, 

171 # we need the "bookkeeping" / deterministic returning order 

172 bookkeeping = True 

173 else: 

174 bookkeeping = False 

175 

176 for table, super_mapper in mappers_to_run: 

177 # find bindparams in the statement. For bulk, we don't really know if 

178 # a key in the params applies to a different table since we are 

179 # potentially inserting for multiple tables here; looking at the 

180 # bindparam() is a lot more direct. in most cases this will 

181 # use _generate_cache_key() which is memoized, although in practice 

182 # the ultimate statement that's executed is probably not the same 

183 # object so that memoization might not matter much. 

184 extra_bp_names = ( 

185 [ 

186 b.key 

187 for b in use_orm_insert_stmt._get_embedded_bindparams() 

188 if b.key in mappings[0] 

189 ] 

190 if use_orm_insert_stmt is not None 

191 else () 

192 ) 

193 

194 records = ( 

195 ( 

196 None, 

197 state_dict, 

198 params, 

199 mapper, 

200 connection, 

201 value_params, 

202 has_all_pks, 

203 has_all_defaults, 

204 ) 

205 for ( 

206 state, 

207 state_dict, 

208 params, 

209 mp, 

210 conn, 

211 value_params, 

212 has_all_pks, 

213 has_all_defaults, 

214 ) in persistence._collect_insert_commands( 

215 table, 

216 ((None, mapping, mapper, connection) for mapping in mappings), 

217 bulk=True, 

218 return_defaults=bookkeeping, 

219 render_nulls=render_nulls, 

220 include_bulk_keys=extra_bp_names, 

221 ) 

222 ) 

223 

224 result = persistence._emit_insert_statements( 

225 base_mapper, 

226 None, 

227 super_mapper, 

228 table, 

229 records, 

230 bookkeeping=bookkeeping, 

231 use_orm_insert_stmt=use_orm_insert_stmt, 

232 execution_options=execution_options, 

233 ) 

234 if use_orm_insert_stmt is not None: 

235 if not use_orm_insert_stmt._returning or return_result is None: 

236 return_result = result 

237 elif result.returns_rows: 

238 assert bookkeeping 

239 return_result = return_result.splice_horizontally(result) 

240 

241 if return_defaults and isstates: 

242 identity_cls = mapper._identity_class 

243 identity_props = [p.key for p in mapper._identity_key_props] 

244 for state, dict_ in states: 

245 state.key = ( 

246 identity_cls, 

247 tuple([dict_[key] for key in identity_props]), 

248 None, 

249 ) 

250 

251 if use_orm_insert_stmt is not None: 

252 assert return_result is not None 

253 return return_result 

254 

255 

256@overload 

257def _bulk_update( 

258 mapper: Mapper[Any], 

259 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], 

260 session_transaction: SessionTransaction, 

261 *, 

262 isstates: bool, 

263 update_changed_only: bool, 

264 use_orm_update_stmt: Literal[None] = ..., 

265 enable_check_rowcount: bool = True, 

266) -> None: ... 

267 

268 

269@overload 

270def _bulk_update( 

271 mapper: Mapper[Any], 

272 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], 

273 session_transaction: SessionTransaction, 

274 *, 

275 isstates: bool, 

276 update_changed_only: bool, 

277 use_orm_update_stmt: Optional[dml.Update] = ..., 

278 enable_check_rowcount: bool = True, 

279) -> _result.Result[Unpack[TupleAny]]: ... 

280 

281 

282def _bulk_update( 

283 mapper: Mapper[Any], 

284 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], 

285 session_transaction: SessionTransaction, 

286 *, 

287 isstates: bool, 

288 update_changed_only: bool, 

289 use_orm_update_stmt: Optional[dml.Update] = None, 

290 enable_check_rowcount: bool = True, 

291) -> Optional[_result.Result[Unpack[TupleAny]]]: 

292 base_mapper = mapper.base_mapper 

293 

294 search_keys = mapper._primary_key_propkeys 

295 if mapper._version_id_prop: 

296 search_keys = {mapper._version_id_prop.key}.union(search_keys) 

297 

298 def _changed_dict(mapper, state): 

299 return { 

300 k: v 

301 for k, v in state.dict.items() 

302 if k in state.committed_state or k in search_keys 

303 } 

304 

305 if isstates: 

306 if update_changed_only: 

307 mappings = [_changed_dict(mapper, state) for state in mappings] 

308 else: 

309 mappings = [state.dict for state in mappings] 

310 else: 

311 mappings = [dict(m) for m in mappings] 

312 _expand_composites(mapper, mappings) 

313 

314 if session_transaction.session.connection_callable: 

315 raise NotImplementedError( 

316 "connection_callable / per-instance sharding " 

317 "not supported in bulk_update()" 

318 ) 

319 

320 connection = session_transaction.connection(base_mapper) 

321 

322 # find bindparams in the statement. see _bulk_insert for similar 

323 # notes for the insert case 

324 extra_bp_names = ( 

325 [ 

326 b.key 

327 for b in use_orm_update_stmt._get_embedded_bindparams() 

328 if b.key in mappings[0] 

329 ] 

330 if use_orm_update_stmt is not None 

331 else () 

332 ) 

333 

334 for table, super_mapper in base_mapper._sorted_tables.items(): 

335 if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: 

336 continue 

337 

338 records = persistence._collect_update_commands( 

339 None, 

340 table, 

341 ( 

342 ( 

343 None, 

344 mapping, 

345 mapper, 

346 connection, 

347 ( 

348 mapping[mapper._version_id_prop.key] 

349 if mapper._version_id_prop 

350 else None 

351 ), 

352 ) 

353 for mapping in mappings 

354 ), 

355 bulk=True, 

356 use_orm_update_stmt=use_orm_update_stmt, 

357 include_bulk_keys=extra_bp_names, 

358 ) 

359 persistence._emit_update_statements( 

360 base_mapper, 

361 None, 

362 super_mapper, 

363 table, 

364 records, 

365 bookkeeping=False, 

366 use_orm_update_stmt=use_orm_update_stmt, 

367 enable_check_rowcount=enable_check_rowcount, 

368 ) 

369 

370 if use_orm_update_stmt is not None: 

371 return _result.null_result() 

372 

373 

374def _expand_composites(mapper, mappings): 

375 composite_attrs = mapper.composites 

376 if not composite_attrs: 

377 return 

378 

379 composite_keys = set(composite_attrs.keys()) 

380 populators = { 

381 key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn() 

382 for key in composite_keys 

383 } 

384 for mapping in mappings: 

385 for key in composite_keys.intersection(mapping): 

386 populators[key](mapping) 

387 

388 

389class ORMDMLState(AbstractORMCompileState): 

390 is_dml_returning = True 

391 from_statement_ctx: Optional[ORMFromStatementCompileState] = None 

392 

393 @classmethod 

394 def _get_orm_crud_kv_pairs( 

395 cls, mapper, statement, kv_iterator, needs_to_be_cacheable 

396 ): 

397 core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs 

398 

399 for k, v in kv_iterator: 

400 k = coercions.expect(roles.DMLColumnRole, k) 

401 

402 if isinstance(k, str): 

403 desc = _entity_namespace_key(mapper, k, default=NO_VALUE) 

404 if desc is NO_VALUE: 

405 yield ( 

406 coercions.expect(roles.DMLColumnRole, k), 

407 ( 

408 coercions.expect( 

409 roles.ExpressionElementRole, 

410 v, 

411 type_=sqltypes.NullType(), 

412 is_crud=True, 

413 ) 

414 if needs_to_be_cacheable 

415 else v 

416 ), 

417 ) 

418 else: 

419 yield from core_get_crud_kv_pairs( 

420 statement, 

421 desc._bulk_update_tuples(v), 

422 needs_to_be_cacheable, 

423 ) 

424 elif "entity_namespace" in k._annotations: 

425 k_anno = k._annotations 

426 attr = _entity_namespace_key( 

427 k_anno["entity_namespace"], k_anno["proxy_key"] 

428 ) 

429 yield from core_get_crud_kv_pairs( 

430 statement, 

431 attr._bulk_update_tuples(v), 

432 needs_to_be_cacheable, 

433 ) 

434 else: 

435 yield ( 

436 k, 

437 ( 

438 v 

439 if not needs_to_be_cacheable 

440 else coercions.expect( 

441 roles.ExpressionElementRole, 

442 v, 

443 type_=sqltypes.NullType(), 

444 is_crud=True, 

445 ) 

446 ), 

447 ) 

448 

449 @classmethod 

450 def _get_multi_crud_kv_pairs(cls, statement, kv_iterator): 

451 plugin_subject = statement._propagate_attrs["plugin_subject"] 

452 

453 if not plugin_subject or not plugin_subject.mapper: 

454 return UpdateDMLState._get_multi_crud_kv_pairs( 

455 statement, kv_iterator 

456 ) 

457 

458 return [ 

459 dict( 

460 cls._get_orm_crud_kv_pairs( 

461 plugin_subject.mapper, statement, value_dict.items(), False 

462 ) 

463 ) 

464 for value_dict in kv_iterator 

465 ] 

466 

467 @classmethod 

468 def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable): 

469 assert ( 

470 needs_to_be_cacheable 

471 ), "no test coverage for needs_to_be_cacheable=False" 

472 

473 plugin_subject = statement._propagate_attrs["plugin_subject"] 

474 

475 if not plugin_subject or not plugin_subject.mapper: 

476 return UpdateDMLState._get_crud_kv_pairs( 

477 statement, kv_iterator, needs_to_be_cacheable 

478 ) 

479 

480 return list( 

481 cls._get_orm_crud_kv_pairs( 

482 plugin_subject.mapper, 

483 statement, 

484 kv_iterator, 

485 needs_to_be_cacheable, 

486 ) 

487 ) 

488 

489 @classmethod 

490 def get_entity_description(cls, statement): 

491 ext_info = statement.table._annotations["parententity"] 

492 mapper = ext_info.mapper 

493 if ext_info.is_aliased_class: 

494 _label_name = ext_info.name 

495 else: 

496 _label_name = mapper.class_.__name__ 

497 

498 return { 

499 "name": _label_name, 

500 "type": mapper.class_, 

501 "expr": ext_info.entity, 

502 "entity": ext_info.entity, 

503 "table": mapper.local_table, 

504 } 

505 

506 @classmethod 

507 def get_returning_column_descriptions(cls, statement): 

508 def _ent_for_col(c): 

509 return c._annotations.get("parententity", None) 

510 

511 def _attr_for_col(c, ent): 

512 if ent is None: 

513 return c 

514 proxy_key = c._annotations.get("proxy_key", None) 

515 if not proxy_key: 

516 return c 

517 else: 

518 return getattr(ent.entity, proxy_key, c) 

519 

520 return [ 

521 { 

522 "name": c.key, 

523 "type": c.type, 

524 "expr": _attr_for_col(c, ent), 

525 "aliased": ent.is_aliased_class, 

526 "entity": ent.entity, 

527 } 

528 for c, ent in [ 

529 (c, _ent_for_col(c)) for c in statement._all_selected_columns 

530 ] 

531 ] 

532 

533 def _setup_orm_returning( 

534 self, 

535 compiler, 

536 orm_level_statement, 

537 dml_level_statement, 

538 dml_mapper, 

539 *, 

540 use_supplemental_cols=True, 

541 ): 

542 """establish ORM column handlers for an INSERT, UPDATE, or DELETE 

543 which uses explicit returning(). 

544 

545 called within compilation level create_for_statement. 

546 

547 The _return_orm_returning() method then receives the Result 

548 after the statement was executed, and applies ORM loading to the 

549 state that we first established here. 

550 

551 """ 

552 

553 if orm_level_statement._returning: 

554 fs = FromStatement( 

555 orm_level_statement._returning, 

556 dml_level_statement, 

557 _adapt_on_names=False, 

558 ) 

559 fs = fs.execution_options(**orm_level_statement._execution_options) 

560 fs = fs.options(*orm_level_statement._with_options) 

561 self.select_statement = fs 

562 self.from_statement_ctx = fsc = ( 

563 ORMFromStatementCompileState.create_for_statement(fs, compiler) 

564 ) 

565 fsc.setup_dml_returning_compile_state(dml_mapper) 

566 

567 dml_level_statement = dml_level_statement._generate() 

568 dml_level_statement._returning = () 

569 

570 cols_to_return = [c for c in fsc.primary_columns if c is not None] 

571 

572 # since we are splicing result sets together, make sure there 

573 # are columns of some kind returned in each result set 

574 if not cols_to_return: 

575 cols_to_return.extend(dml_mapper.primary_key) 

576 

577 if use_supplemental_cols: 

578 dml_level_statement = dml_level_statement.return_defaults( 

579 # this is a little weird looking, but by passing 

580 # primary key as the main list of cols, this tells 

581 # return_defaults to omit server-default cols (and 

582 # actually all cols, due to some weird thing we should 

583 # clean up in crud.py). 

584 # Since we have cols_to_return, just return what we asked 

585 # for (plus primary key, which ORM persistence needs since 

586 # we likely set bookkeeping=True here, which is another 

587 # whole thing...). We dont want to clutter the 

588 # statement up with lots of other cols the user didn't 

589 # ask for. see #9685 

590 *dml_mapper.primary_key, 

591 supplemental_cols=cols_to_return, 

592 ) 

593 else: 

594 dml_level_statement = dml_level_statement.returning( 

595 *cols_to_return 

596 ) 

597 

598 return dml_level_statement 

599 

600 @classmethod 

601 def _return_orm_returning( 

602 cls, 

603 session, 

604 statement, 

605 params, 

606 execution_options, 

607 bind_arguments, 

608 result, 

609 ): 

610 execution_context = result.context 

611 compile_state = execution_context.compiled.compile_state 

612 

613 if ( 

614 compile_state.from_statement_ctx 

615 and not compile_state.from_statement_ctx.compile_options._is_star 

616 ): 

617 load_options = execution_options.get( 

618 "_sa_orm_load_options", QueryContext.default_load_options 

619 ) 

620 

621 querycontext = QueryContext( 

622 compile_state.from_statement_ctx, 

623 compile_state.select_statement, 

624 params, 

625 session, 

626 load_options, 

627 execution_options, 

628 bind_arguments, 

629 ) 

630 return loading.instances(result, querycontext) 

631 else: 

632 return result 

633 

634 

635class BulkUDCompileState(ORMDMLState): 

636 class default_update_options(Options): 

637 _dml_strategy: DMLStrategyArgument = "auto" 

638 _synchronize_session: SynchronizeSessionArgument = "auto" 

639 _can_use_returning: bool = False 

640 _is_delete_using: bool = False 

641 _is_update_from: bool = False 

642 _autoflush: bool = True 

643 _subject_mapper: Optional[Mapper[Any]] = None 

644 _resolved_values = EMPTY_DICT 

645 _eval_condition = None 

646 _matched_rows = None 

647 _identity_token = None 

648 

649 @classmethod 

650 def can_use_returning( 

651 cls, 

652 dialect: Dialect, 

653 mapper: Mapper[Any], 

654 *, 

655 is_multitable: bool = False, 

656 is_update_from: bool = False, 

657 is_delete_using: bool = False, 

658 is_executemany: bool = False, 

659 ) -> bool: 

660 raise NotImplementedError() 

661 

662 @classmethod 

663 def orm_pre_session_exec( 

664 cls, 

665 session, 

666 statement, 

667 params, 

668 execution_options, 

669 bind_arguments, 

670 is_pre_event, 

671 ): 

672 ( 

673 update_options, 

674 execution_options, 

675 ) = BulkUDCompileState.default_update_options.from_execution_options( 

676 "_sa_orm_update_options", 

677 { 

678 "synchronize_session", 

679 "autoflush", 

680 "identity_token", 

681 "is_delete_using", 

682 "is_update_from", 

683 "dml_strategy", 

684 }, 

685 execution_options, 

686 statement._execution_options, 

687 ) 

688 bind_arguments["clause"] = statement 

689 try: 

690 plugin_subject = statement._propagate_attrs["plugin_subject"] 

691 except KeyError: 

692 assert False, "statement had 'orm' plugin but no plugin_subject" 

693 else: 

694 if plugin_subject: 

695 bind_arguments["mapper"] = plugin_subject.mapper 

696 update_options += {"_subject_mapper": plugin_subject.mapper} 

697 

698 if "parententity" not in statement.table._annotations: 

699 update_options += {"_dml_strategy": "core_only"} 

700 elif not isinstance(params, list): 

701 if update_options._dml_strategy == "auto": 

702 update_options += {"_dml_strategy": "orm"} 

703 elif update_options._dml_strategy == "bulk": 

704 raise sa_exc.InvalidRequestError( 

705 'Can\'t use "bulk" ORM insert strategy without ' 

706 "passing separate parameters" 

707 ) 

708 else: 

709 if update_options._dml_strategy == "auto": 

710 update_options += {"_dml_strategy": "bulk"} 

711 

712 sync = update_options._synchronize_session 

713 if sync is not None: 

714 if sync not in ("auto", "evaluate", "fetch", False): 

715 raise sa_exc.ArgumentError( 

716 "Valid strategies for session synchronization " 

717 "are 'auto', 'evaluate', 'fetch', False" 

718 ) 

719 if update_options._dml_strategy == "bulk" and sync == "fetch": 

720 raise sa_exc.InvalidRequestError( 

721 "The 'fetch' synchronization strategy is not available " 

722 "for 'bulk' ORM updates (i.e. multiple parameter sets)" 

723 ) 

724 

725 if not is_pre_event: 

726 if update_options._autoflush: 

727 session._autoflush() 

728 

729 if update_options._dml_strategy == "orm": 

730 if update_options._synchronize_session == "auto": 

731 update_options = cls._do_pre_synchronize_auto( 

732 session, 

733 statement, 

734 params, 

735 execution_options, 

736 bind_arguments, 

737 update_options, 

738 ) 

739 elif update_options._synchronize_session == "evaluate": 

740 update_options = cls._do_pre_synchronize_evaluate( 

741 session, 

742 statement, 

743 params, 

744 execution_options, 

745 bind_arguments, 

746 update_options, 

747 ) 

748 elif update_options._synchronize_session == "fetch": 

749 update_options = cls._do_pre_synchronize_fetch( 

750 session, 

751 statement, 

752 params, 

753 execution_options, 

754 bind_arguments, 

755 update_options, 

756 ) 

757 elif update_options._dml_strategy == "bulk": 

758 if update_options._synchronize_session == "auto": 

759 update_options += {"_synchronize_session": "evaluate"} 

760 

761 # indicators from the "pre exec" step that are then 

762 # added to the DML statement, which will also be part of the cache 

763 # key. The compile level create_for_statement() method will then 

764 # consume these at compiler time. 

765 statement = statement._annotate( 

766 { 

767 "synchronize_session": update_options._synchronize_session, 

768 "is_delete_using": update_options._is_delete_using, 

769 "is_update_from": update_options._is_update_from, 

770 "dml_strategy": update_options._dml_strategy, 

771 "can_use_returning": update_options._can_use_returning, 

772 } 

773 ) 

774 

775 return ( 

776 statement, 

777 util.immutabledict(execution_options).union( 

778 {"_sa_orm_update_options": update_options} 

779 ), 

780 ) 

781 

782 @classmethod 

783 def orm_setup_cursor_result( 

784 cls, 

785 session, 

786 statement, 

787 params, 

788 execution_options, 

789 bind_arguments, 

790 result, 

791 ): 

792 # this stage of the execution is called after the 

793 # do_orm_execute event hook. meaning for an extension like 

794 # horizontal sharding, this step happens *within* the horizontal 

795 # sharding event handler which calls session.execute() re-entrantly 

796 # and will occur for each backend individually. 

797 # the sharding extension then returns its own merged result from the 

798 # individual ones we return here. 

799 

800 update_options = execution_options["_sa_orm_update_options"] 

801 if update_options._dml_strategy == "orm": 

802 if update_options._synchronize_session == "evaluate": 

803 cls._do_post_synchronize_evaluate( 

804 session, statement, result, update_options 

805 ) 

806 elif update_options._synchronize_session == "fetch": 

807 cls._do_post_synchronize_fetch( 

808 session, statement, result, update_options 

809 ) 

810 elif update_options._dml_strategy == "bulk": 

811 if update_options._synchronize_session == "evaluate": 

812 cls._do_post_synchronize_bulk_evaluate( 

813 session, params, result, update_options 

814 ) 

815 return result 

816 

817 return cls._return_orm_returning( 

818 session, 

819 statement, 

820 params, 

821 execution_options, 

822 bind_arguments, 

823 result, 

824 ) 

825 

826 @classmethod 

827 def _adjust_for_extra_criteria(cls, global_attributes, ext_info): 

828 """Apply extra criteria filtering. 

829 

830 For all distinct single-table-inheritance mappers represented in the 

831 table being updated or deleted, produce additional WHERE criteria such 

832 that only the appropriate subtypes are selected from the total results. 

833 

834 Additionally, add WHERE criteria originating from LoaderCriteriaOptions 

835 collected from the statement. 

836 

837 """ 

838 

839 return_crit = () 

840 

841 adapter = ext_info._adapter if ext_info.is_aliased_class else None 

842 

843 if ( 

844 "additional_entity_criteria", 

845 ext_info.mapper, 

846 ) in global_attributes: 

847 return_crit += tuple( 

848 ae._resolve_where_criteria(ext_info) 

849 for ae in global_attributes[ 

850 ("additional_entity_criteria", ext_info.mapper) 

851 ] 

852 if ae.include_aliases or ae.entity is ext_info 

853 ) 

854 

855 if ext_info.mapper._single_table_criterion is not None: 

856 return_crit += (ext_info.mapper._single_table_criterion,) 

857 

858 if adapter: 

859 return_crit = tuple(adapter.traverse(crit) for crit in return_crit) 

860 

861 return return_crit 

862 

863 @classmethod 

864 def _interpret_returning_rows(cls, mapper, rows): 

865 """translate from local inherited table columns to base mapper 

866 primary key columns. 

867 

868 Joined inheritance mappers always establish the primary key in terms of 

869 the base table. When we UPDATE a sub-table, we can only get 

870 RETURNING for the sub-table's columns. 

871 

872 Here, we create a lookup from the local sub table's primary key 

873 columns to the base table PK columns so that we can get identity 

874 key values from RETURNING that's against the joined inheritance 

875 sub-table. 

876 

877 the complexity here is to support more than one level deep of 

878 inheritance, where we have to link columns to each other across 

879 the inheritance hierarchy. 

880 

881 """ 

882 

883 if mapper.local_table is not mapper.base_mapper.local_table: 

884 return rows 

885 

886 # this starts as a mapping of 

887 # local_pk_col: local_pk_col. 

888 # we will then iteratively rewrite the "value" of the dict with 

889 # each successive superclass column 

890 local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key} 

891 

892 for mp in mapper.iterate_to_root(): 

893 if mp.inherits is None: 

894 break 

895 elif mp.local_table is mp.inherits.local_table: 

896 continue 

897 

898 t_to_e = dict(mp._table_to_equated[mp.inherits.local_table]) 

899 col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]} 

900 for pk, super_ in local_pk_to_base_pk.items(): 

901 local_pk_to_base_pk[pk] = col_to_col[super_] 

902 

903 lookup = { 

904 local_pk_to_base_pk[lpk]: idx 

905 for idx, lpk in enumerate(mapper.local_table.primary_key) 

906 } 

907 primary_key_convert = [ 

908 lookup[bpk] for bpk in mapper.base_mapper.primary_key 

909 ] 

910 return [tuple(row[idx] for idx in primary_key_convert) for row in rows] 

911 

912 @classmethod 

913 def _get_matched_objects_on_criteria(cls, update_options, states): 

914 mapper = update_options._subject_mapper 

915 eval_condition = update_options._eval_condition 

916 

917 raw_data = [ 

918 (state.obj(), state, state.dict) 

919 for state in states 

920 if state.mapper.isa(mapper) and not state.expired 

921 ] 

922 

923 identity_token = update_options._identity_token 

924 if identity_token is not None: 

925 raw_data = [ 

926 (obj, state, dict_) 

927 for obj, state, dict_ in raw_data 

928 if state.identity_token == identity_token 

929 ] 

930 

931 result = [] 

932 for obj, state, dict_ in raw_data: 

933 evaled_condition = eval_condition(obj) 

934 

935 # caution: don't use "in ()" or == here, _EXPIRE_OBJECT 

936 # evaluates as True for all comparisons 

937 if ( 

938 evaled_condition is True 

939 or evaled_condition is evaluator._EXPIRED_OBJECT 

940 ): 

941 result.append( 

942 ( 

943 obj, 

944 state, 

945 dict_, 

946 evaled_condition is evaluator._EXPIRED_OBJECT, 

947 ) 

948 ) 

949 return result 

950 

951 @classmethod 

952 def _eval_condition_from_statement(cls, update_options, statement): 

953 mapper = update_options._subject_mapper 

954 target_cls = mapper.class_ 

955 

956 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls) 

957 crit = () 

958 if statement._where_criteria: 

959 crit += statement._where_criteria 

960 

961 global_attributes = {} 

962 for opt in statement._with_options: 

963 if opt._is_criteria_option: 

964 opt.get_global_criteria(global_attributes) 

965 

966 if global_attributes: 

967 crit += cls._adjust_for_extra_criteria(global_attributes, mapper) 

968 

969 if crit: 

970 eval_condition = evaluator_compiler.process(*crit) 

971 else: 

972 # workaround for mypy https://github.com/python/mypy/issues/14027 

973 def _eval_condition(obj): 

974 return True 

975 

976 eval_condition = _eval_condition 

977 

978 return eval_condition 

979 

980 @classmethod 

981 def _do_pre_synchronize_auto( 

982 cls, 

983 session, 

984 statement, 

985 params, 

986 execution_options, 

987 bind_arguments, 

988 update_options, 

989 ): 

990 """setup auto sync strategy 

991 

992 

993 "auto" checks if we can use "evaluate" first, then falls back 

994 to "fetch" 

995 

996 evaluate is vastly more efficient for the common case 

997 where session is empty, only has a few objects, and the UPDATE 

998 statement can potentially match thousands/millions of rows. 

999 

1000 OTOH more complex criteria that fails to work with "evaluate" 

1001 we would hope usually correlates with fewer net rows. 

1002 

1003 """ 

1004 

1005 try: 

1006 eval_condition = cls._eval_condition_from_statement( 

1007 update_options, statement 

1008 ) 

1009 

1010 except evaluator.UnevaluatableError: 

1011 pass 

1012 else: 

1013 return update_options + { 

1014 "_eval_condition": eval_condition, 

1015 "_synchronize_session": "evaluate", 

1016 } 

1017 

1018 update_options += {"_synchronize_session": "fetch"} 

1019 return cls._do_pre_synchronize_fetch( 

1020 session, 

1021 statement, 

1022 params, 

1023 execution_options, 

1024 bind_arguments, 

1025 update_options, 

1026 ) 

1027 

1028 @classmethod 

1029 def _do_pre_synchronize_evaluate( 

1030 cls, 

1031 session, 

1032 statement, 

1033 params, 

1034 execution_options, 

1035 bind_arguments, 

1036 update_options, 

1037 ): 

1038 try: 

1039 eval_condition = cls._eval_condition_from_statement( 

1040 update_options, statement 

1041 ) 

1042 

1043 except evaluator.UnevaluatableError as err: 

1044 raise sa_exc.InvalidRequestError( 

1045 'Could not evaluate current criteria in Python: "%s". ' 

1046 "Specify 'fetch' or False for the " 

1047 "synchronize_session execution option." % err 

1048 ) from err 

1049 

1050 return update_options + { 

1051 "_eval_condition": eval_condition, 

1052 } 

1053 

1054 @classmethod 

1055 def _get_resolved_values(cls, mapper, statement): 

1056 if statement._multi_values: 

1057 return [] 

1058 elif statement._ordered_values: 

1059 return list(statement._ordered_values) 

1060 elif statement._values: 

1061 return list(statement._values.items()) 

1062 else: 

1063 return [] 

1064 

1065 @classmethod 

1066 def _resolved_keys_as_propnames(cls, mapper, resolved_values): 

1067 values = [] 

1068 for k, v in resolved_values: 

1069 if mapper and isinstance(k, expression.ColumnElement): 

1070 try: 

1071 attr = mapper._columntoproperty[k] 

1072 except orm_exc.UnmappedColumnError: 

1073 pass 

1074 else: 

1075 values.append((attr.key, v)) 

1076 else: 

1077 raise sa_exc.InvalidRequestError( 

1078 "Attribute name not found, can't be " 

1079 "synchronized back to objects: %r" % k 

1080 ) 

1081 return values 

1082 

1083 @classmethod 

1084 def _do_pre_synchronize_fetch( 

1085 cls, 

1086 session, 

1087 statement, 

1088 params, 

1089 execution_options, 

1090 bind_arguments, 

1091 update_options, 

1092 ): 

1093 mapper = update_options._subject_mapper 

1094 

1095 select_stmt = ( 

1096 select(*(mapper.primary_key + (mapper.select_identity_token,))) 

1097 .select_from(mapper) 

1098 .options(*statement._with_options) 

1099 ) 

1100 select_stmt._where_criteria = statement._where_criteria 

1101 

1102 # conditionally run the SELECT statement for pre-fetch, testing the 

1103 # "bind" for if we can use RETURNING or not using the do_orm_execute 

1104 # event. If RETURNING is available, the do_orm_execute event 

1105 # will cancel the SELECT from being actually run. 

1106 # 

1107 # The way this is organized seems strange, why don't we just 

1108 # call can_use_returning() before invoking the statement and get 

1109 # answer?, why does this go through the whole execute phase using an 

1110 # event? Answer: because we are integrating with extensions such 

1111 # as the horizontal sharding extention that "multiplexes" an individual 

1112 # statement run through multiple engines, and it uses 

1113 # do_orm_execute() to do that. 

1114 

1115 can_use_returning = None 

1116 

1117 def skip_for_returning(orm_context: ORMExecuteState) -> Any: 

1118 bind = orm_context.session.get_bind(**orm_context.bind_arguments) 

1119 nonlocal can_use_returning 

1120 

1121 per_bind_result = cls.can_use_returning( 

1122 bind.dialect, 

1123 mapper, 

1124 is_update_from=update_options._is_update_from, 

1125 is_delete_using=update_options._is_delete_using, 

1126 is_executemany=orm_context.is_executemany, 

1127 ) 

1128 

1129 if can_use_returning is not None: 

1130 if can_use_returning != per_bind_result: 

1131 raise sa_exc.InvalidRequestError( 

1132 "For synchronize_session='fetch', can't mix multiple " 

1133 "backends where some support RETURNING and others " 

1134 "don't" 

1135 ) 

1136 elif orm_context.is_executemany and not per_bind_result: 

1137 raise sa_exc.InvalidRequestError( 

1138 "For synchronize_session='fetch', can't use multiple " 

1139 "parameter sets in ORM mode, which this backend does not " 

1140 "support with RETURNING" 

1141 ) 

1142 else: 

1143 can_use_returning = per_bind_result 

1144 

1145 if per_bind_result: 

1146 return _result.null_result() 

1147 else: 

1148 return None 

1149 

1150 result = session.execute( 

1151 select_stmt, 

1152 params, 

1153 execution_options=execution_options, 

1154 bind_arguments=bind_arguments, 

1155 _add_event=skip_for_returning, 

1156 ) 

1157 matched_rows = result.fetchall() 

1158 

1159 return update_options + { 

1160 "_matched_rows": matched_rows, 

1161 "_can_use_returning": can_use_returning, 

1162 } 

1163 

1164 

1165@CompileState.plugin_for("orm", "insert") 

1166class BulkORMInsert(ORMDMLState, InsertDMLState): 

1167 class default_insert_options(Options): 

1168 _dml_strategy: DMLStrategyArgument = "auto" 

1169 _render_nulls: bool = False 

1170 _return_defaults: bool = False 

1171 _subject_mapper: Optional[Mapper[Any]] = None 

1172 _autoflush: bool = True 

1173 _populate_existing: bool = False 

1174 

1175 select_statement: Optional[FromStatement] = None 

1176 

1177 @classmethod 

1178 def orm_pre_session_exec( 

1179 cls, 

1180 session, 

1181 statement, 

1182 params, 

1183 execution_options, 

1184 bind_arguments, 

1185 is_pre_event, 

1186 ): 

1187 ( 

1188 insert_options, 

1189 execution_options, 

1190 ) = BulkORMInsert.default_insert_options.from_execution_options( 

1191 "_sa_orm_insert_options", 

1192 {"dml_strategy", "autoflush", "populate_existing", "render_nulls"}, 

1193 execution_options, 

1194 statement._execution_options, 

1195 ) 

1196 bind_arguments["clause"] = statement 

1197 try: 

1198 plugin_subject = statement._propagate_attrs["plugin_subject"] 

1199 except KeyError: 

1200 assert False, "statement had 'orm' plugin but no plugin_subject" 

1201 else: 

1202 if plugin_subject: 

1203 bind_arguments["mapper"] = plugin_subject.mapper 

1204 insert_options += {"_subject_mapper": plugin_subject.mapper} 

1205 

1206 if not params: 

1207 if insert_options._dml_strategy == "auto": 

1208 insert_options += {"_dml_strategy": "orm"} 

1209 elif insert_options._dml_strategy == "bulk": 

1210 raise sa_exc.InvalidRequestError( 

1211 'Can\'t use "bulk" ORM insert strategy without ' 

1212 "passing separate parameters" 

1213 ) 

1214 else: 

1215 if insert_options._dml_strategy == "auto": 

1216 insert_options += {"_dml_strategy": "bulk"} 

1217 

1218 if insert_options._dml_strategy != "raw": 

1219 # for ORM object loading, like ORMContext, we have to disable 

1220 # result set adapt_to_context, because we will be generating a 

1221 # new statement with specific columns that's cached inside of 

1222 # an ORMFromStatementCompileState, which we will re-use for 

1223 # each result. 

1224 if not execution_options: 

1225 execution_options = context._orm_load_exec_options 

1226 else: 

1227 execution_options = execution_options.union( 

1228 context._orm_load_exec_options 

1229 ) 

1230 

1231 if not is_pre_event and insert_options._autoflush: 

1232 session._autoflush() 

1233 

1234 statement = statement._annotate( 

1235 {"dml_strategy": insert_options._dml_strategy} 

1236 ) 

1237 

1238 return ( 

1239 statement, 

1240 util.immutabledict(execution_options).union( 

1241 {"_sa_orm_insert_options": insert_options} 

1242 ), 

1243 ) 

1244 

1245 @classmethod 

1246 def orm_execute_statement( 

1247 cls, 

1248 session: Session, 

1249 statement: dml.Insert, 

1250 params: _CoreAnyExecuteParams, 

1251 execution_options: OrmExecuteOptionsParameter, 

1252 bind_arguments: _BindArguments, 

1253 conn: Connection, 

1254 ) -> _result.Result: 

1255 insert_options = execution_options.get( 

1256 "_sa_orm_insert_options", cls.default_insert_options 

1257 ) 

1258 

1259 if insert_options._dml_strategy not in ( 

1260 "raw", 

1261 "bulk", 

1262 "orm", 

1263 "auto", 

1264 ): 

1265 raise sa_exc.ArgumentError( 

1266 "Valid strategies for ORM insert strategy " 

1267 "are 'raw', 'orm', 'bulk', 'auto" 

1268 ) 

1269 

1270 result: _result.Result[Unpack[TupleAny]] 

1271 

1272 if insert_options._dml_strategy == "raw": 

1273 result = conn.execute( 

1274 statement, params or {}, execution_options=execution_options 

1275 ) 

1276 return result 

1277 

1278 if insert_options._dml_strategy == "bulk": 

1279 mapper = insert_options._subject_mapper 

1280 

1281 if ( 

1282 statement._post_values_clause is not None 

1283 and mapper._multiple_persistence_tables 

1284 ): 

1285 raise sa_exc.InvalidRequestError( 

1286 "bulk INSERT with a 'post values' clause " 

1287 "(typically upsert) not supported for multi-table " 

1288 f"mapper {mapper}" 

1289 ) 

1290 

1291 assert mapper is not None 

1292 assert session._transaction is not None 

1293 result = _bulk_insert( 

1294 mapper, 

1295 cast( 

1296 "Iterable[Dict[str, Any]]", 

1297 [params] if isinstance(params, dict) else params, 

1298 ), 

1299 session._transaction, 

1300 isstates=False, 

1301 return_defaults=insert_options._return_defaults, 

1302 render_nulls=insert_options._render_nulls, 

1303 use_orm_insert_stmt=statement, 

1304 execution_options=execution_options, 

1305 ) 

1306 elif insert_options._dml_strategy == "orm": 

1307 result = conn.execute( 

1308 statement, params or {}, execution_options=execution_options 

1309 ) 

1310 else: 

1311 raise AssertionError() 

1312 

1313 if not bool(statement._returning): 

1314 return result 

1315 

1316 if insert_options._populate_existing: 

1317 load_options = execution_options.get( 

1318 "_sa_orm_load_options", QueryContext.default_load_options 

1319 ) 

1320 load_options += {"_populate_existing": True} 

1321 execution_options = execution_options.union( 

1322 {"_sa_orm_load_options": load_options} 

1323 ) 

1324 

1325 return cls._return_orm_returning( 

1326 session, 

1327 statement, 

1328 params, 

1329 execution_options, 

1330 bind_arguments, 

1331 result, 

1332 ) 

1333 

1334 @classmethod 

1335 def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert: 

1336 self = cast( 

1337 BulkORMInsert, 

1338 super().create_for_statement(statement, compiler, **kw), 

1339 ) 

1340 

1341 if compiler is not None: 

1342 toplevel = not compiler.stack 

1343 else: 

1344 toplevel = True 

1345 if not toplevel: 

1346 return self 

1347 

1348 mapper = statement._propagate_attrs["plugin_subject"] 

1349 dml_strategy = statement._annotations.get("dml_strategy", "raw") 

1350 if dml_strategy == "bulk": 

1351 self._setup_for_bulk_insert(compiler) 

1352 elif dml_strategy == "orm": 

1353 self._setup_for_orm_insert(compiler, mapper) 

1354 

1355 return self 

1356 

1357 @classmethod 

1358 def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict): 

1359 return { 

1360 col.key if col is not None else k: v 

1361 for col, k, v in ( 

1362 (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items() 

1363 ) 

1364 } 

1365 

1366 def _setup_for_orm_insert(self, compiler, mapper): 

1367 statement = orm_level_statement = cast(dml.Insert, self.statement) 

1368 

1369 statement = self._setup_orm_returning( 

1370 compiler, 

1371 orm_level_statement, 

1372 statement, 

1373 dml_mapper=mapper, 

1374 use_supplemental_cols=False, 

1375 ) 

1376 self.statement = statement 

1377 

1378 def _setup_for_bulk_insert(self, compiler): 

1379 """establish an INSERT statement within the context of 

1380 bulk insert. 

1381 

1382 This method will be within the "conn.execute()" call that is invoked 

1383 by persistence._emit_insert_statement(). 

1384 

1385 """ 

1386 statement = orm_level_statement = cast(dml.Insert, self.statement) 

1387 an = statement._annotations 

1388 

1389 emit_insert_table, emit_insert_mapper = ( 

1390 an["_emit_insert_table"], 

1391 an["_emit_insert_mapper"], 

1392 ) 

1393 

1394 statement = statement._clone() 

1395 

1396 statement.table = emit_insert_table 

1397 if self._dict_parameters: 

1398 self._dict_parameters = { 

1399 col: val 

1400 for col, val in self._dict_parameters.items() 

1401 if col.table is emit_insert_table 

1402 } 

1403 

1404 statement = self._setup_orm_returning( 

1405 compiler, 

1406 orm_level_statement, 

1407 statement, 

1408 dml_mapper=emit_insert_mapper, 

1409 use_supplemental_cols=True, 

1410 ) 

1411 

1412 if ( 

1413 self.from_statement_ctx is not None 

1414 and self.from_statement_ctx.compile_options._is_star 

1415 ): 

1416 raise sa_exc.CompileError( 

1417 "Can't use RETURNING * with bulk ORM INSERT. " 

1418 "Please use a different INSERT form, such as INSERT..VALUES " 

1419 "or INSERT with a Core Connection" 

1420 ) 

1421 

1422 self.statement = statement 

1423 

1424 

1425@CompileState.plugin_for("orm", "update") 

1426class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): 

1427 @classmethod 

1428 def create_for_statement(cls, statement, compiler, **kw): 

1429 self = cls.__new__(cls) 

1430 

1431 dml_strategy = statement._annotations.get( 

1432 "dml_strategy", "unspecified" 

1433 ) 

1434 

1435 toplevel = not compiler.stack 

1436 

1437 if toplevel and dml_strategy == "bulk": 

1438 self._setup_for_bulk_update(statement, compiler) 

1439 elif ( 

1440 dml_strategy == "core_only" 

1441 or dml_strategy == "unspecified" 

1442 and "parententity" not in statement.table._annotations 

1443 ): 

1444 UpdateDMLState.__init__(self, statement, compiler, **kw) 

1445 elif not toplevel or dml_strategy in ("orm", "unspecified"): 

1446 self._setup_for_orm_update(statement, compiler) 

1447 

1448 return self 

1449 

1450 def _setup_for_orm_update(self, statement, compiler, **kw): 

1451 orm_level_statement = statement 

1452 

1453 toplevel = not compiler.stack 

1454 

1455 ext_info = statement.table._annotations["parententity"] 

1456 

1457 self.mapper = mapper = ext_info.mapper 

1458 

1459 self._resolved_values = self._get_resolved_values(mapper, statement) 

1460 

1461 self._init_global_attributes( 

1462 statement, 

1463 compiler, 

1464 toplevel=toplevel, 

1465 process_criteria_for_toplevel=toplevel, 

1466 ) 

1467 

1468 if statement._values: 

1469 self._resolved_values = dict(self._resolved_values) 

1470 

1471 new_stmt = statement._clone() 

1472 

1473 if new_stmt.table._annotations["parententity"] is mapper: 

1474 new_stmt.table = mapper.local_table 

1475 

1476 # note if the statement has _multi_values, these 

1477 # are passed through to the new statement, which will then raise 

1478 # InvalidRequestError because UPDATE doesn't support multi_values 

1479 # right now. 

1480 if statement._ordered_values: 

1481 new_stmt._ordered_values = self._resolved_values 

1482 elif statement._values: 

1483 new_stmt._values = self._resolved_values 

1484 

1485 new_crit = self._adjust_for_extra_criteria( 

1486 self.global_attributes, mapper 

1487 ) 

1488 if new_crit: 

1489 new_stmt = new_stmt.where(*new_crit) 

1490 

1491 # if we are against a lambda statement we might not be the 

1492 # topmost object that received per-execute annotations 

1493 

1494 # do this first as we need to determine if there is 

1495 # UPDATE..FROM 

1496 

1497 UpdateDMLState.__init__(self, new_stmt, compiler, **kw) 

1498 

1499 use_supplemental_cols = False 

1500 

1501 if not toplevel: 

1502 synchronize_session = None 

1503 else: 

1504 synchronize_session = compiler._annotations.get( 

1505 "synchronize_session", None 

1506 ) 

1507 can_use_returning = compiler._annotations.get( 

1508 "can_use_returning", None 

1509 ) 

1510 if can_use_returning is not False: 

1511 # even though pre_exec has determined basic 

1512 # can_use_returning for the dialect, if we are to use 

1513 # RETURNING we need to run can_use_returning() at this level 

1514 # unconditionally because is_delete_using was not known 

1515 # at the pre_exec level 

1516 can_use_returning = ( 

1517 synchronize_session == "fetch" 

1518 and self.can_use_returning( 

1519 compiler.dialect, mapper, is_multitable=self.is_multitable 

1520 ) 

1521 ) 

1522 

1523 if synchronize_session == "fetch" and can_use_returning: 

1524 use_supplemental_cols = True 

1525 

1526 # NOTE: we might want to RETURNING the actual columns to be 

1527 # synchronized also. however this is complicated and difficult 

1528 # to align against the behavior of "evaluate". Additionally, 

1529 # in a large number (if not the majority) of cases, we have the 

1530 # "evaluate" answer, usually a fixed value, in memory already and 

1531 # there's no need to re-fetch the same value 

1532 # over and over again. so perhaps if it could be RETURNING just 

1533 # the elements that were based on a SQL expression and not 

1534 # a constant. For now it doesn't quite seem worth it 

1535 new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key) 

1536 

1537 if toplevel: 

1538 new_stmt = self._setup_orm_returning( 

1539 compiler, 

1540 orm_level_statement, 

1541 new_stmt, 

1542 dml_mapper=mapper, 

1543 use_supplemental_cols=use_supplemental_cols, 

1544 ) 

1545 

1546 self.statement = new_stmt 

1547 

1548 def _setup_for_bulk_update(self, statement, compiler, **kw): 

1549 """establish an UPDATE statement within the context of 

1550 bulk insert. 

1551 

1552 This method will be within the "conn.execute()" call that is invoked 

1553 by persistence._emit_update_statement(). 

1554 

1555 """ 

1556 statement = cast(dml.Update, statement) 

1557 an = statement._annotations 

1558 

1559 emit_update_table, _ = ( 

1560 an["_emit_update_table"], 

1561 an["_emit_update_mapper"], 

1562 ) 

1563 

1564 statement = statement._clone() 

1565 statement.table = emit_update_table 

1566 

1567 UpdateDMLState.__init__(self, statement, compiler, **kw) 

1568 

1569 if self._ordered_values: 

1570 raise sa_exc.InvalidRequestError( 

1571 "bulk ORM UPDATE does not support ordered_values() for " 

1572 "custom UPDATE statements with bulk parameter sets. Use a " 

1573 "non-bulk UPDATE statement or use values()." 

1574 ) 

1575 

1576 if self._dict_parameters: 

1577 self._dict_parameters = { 

1578 col: val 

1579 for col, val in self._dict_parameters.items() 

1580 if col.table is emit_update_table 

1581 } 

1582 self.statement = statement 

1583 

1584 @classmethod 

1585 def orm_execute_statement( 

1586 cls, 

1587 session: Session, 

1588 statement: dml.Update, 

1589 params: _CoreAnyExecuteParams, 

1590 execution_options: OrmExecuteOptionsParameter, 

1591 bind_arguments: _BindArguments, 

1592 conn: Connection, 

1593 ) -> _result.Result: 

1594 update_options = execution_options.get( 

1595 "_sa_orm_update_options", cls.default_update_options 

1596 ) 

1597 

1598 if update_options._dml_strategy not in ( 

1599 "orm", 

1600 "auto", 

1601 "bulk", 

1602 "core_only", 

1603 ): 

1604 raise sa_exc.ArgumentError( 

1605 "Valid strategies for ORM UPDATE strategy " 

1606 "are 'orm', 'auto', 'bulk', 'core_only'" 

1607 ) 

1608 

1609 result: _result.Result[Unpack[TupleAny]] 

1610 

1611 if update_options._dml_strategy == "bulk": 

1612 enable_check_rowcount = not statement._where_criteria 

1613 

1614 assert update_options._synchronize_session != "fetch" 

1615 

1616 if ( 

1617 statement._where_criteria 

1618 and update_options._synchronize_session == "evaluate" 

1619 ): 

1620 raise sa_exc.InvalidRequestError( 

1621 "bulk synchronize of persistent objects not supported " 

1622 "when using bulk update with additional WHERE " 

1623 "criteria right now. add synchronize_session=None " 

1624 "execution option to bypass synchronize of persistent " 

1625 "objects." 

1626 ) 

1627 mapper = update_options._subject_mapper 

1628 assert mapper is not None 

1629 assert session._transaction is not None 

1630 result = _bulk_update( 

1631 mapper, 

1632 cast( 

1633 "Iterable[Dict[str, Any]]", 

1634 [params] if isinstance(params, dict) else params, 

1635 ), 

1636 session._transaction, 

1637 isstates=False, 

1638 update_changed_only=False, 

1639 use_orm_update_stmt=statement, 

1640 enable_check_rowcount=enable_check_rowcount, 

1641 ) 

1642 return cls.orm_setup_cursor_result( 

1643 session, 

1644 statement, 

1645 params, 

1646 execution_options, 

1647 bind_arguments, 

1648 result, 

1649 ) 

1650 else: 

1651 return super().orm_execute_statement( 

1652 session, 

1653 statement, 

1654 params, 

1655 execution_options, 

1656 bind_arguments, 

1657 conn, 

1658 ) 

1659 

1660 @classmethod 

1661 def can_use_returning( 

1662 cls, 

1663 dialect: Dialect, 

1664 mapper: Mapper[Any], 

1665 *, 

1666 is_multitable: bool = False, 

1667 is_update_from: bool = False, 

1668 is_delete_using: bool = False, 

1669 is_executemany: bool = False, 

1670 ) -> bool: 

1671 # normal answer for "should we use RETURNING" at all. 

1672 normal_answer = ( 

1673 dialect.update_returning and mapper.local_table.implicit_returning 

1674 ) 

1675 if not normal_answer: 

1676 return False 

1677 

1678 if is_executemany: 

1679 return dialect.update_executemany_returning 

1680 

1681 # these workarounds are currently hypothetical for UPDATE, 

1682 # unlike DELETE where they impact MariaDB 

1683 if is_update_from: 

1684 return dialect.update_returning_multifrom 

1685 

1686 elif is_multitable and not dialect.update_returning_multifrom: 

1687 raise sa_exc.CompileError( 

1688 f'Dialect "{dialect.name}" does not support RETURNING ' 

1689 "with UPDATE..FROM; for synchronize_session='fetch', " 

1690 "please add the additional execution option " 

1691 "'is_update_from=True' to the statement to indicate that " 

1692 "a separate SELECT should be used for this backend." 

1693 ) 

1694 

1695 return True 

1696 

1697 @classmethod 

1698 def _do_post_synchronize_bulk_evaluate( 

1699 cls, session, params, result, update_options 

1700 ): 

1701 if not params: 

1702 return 

1703 

1704 mapper = update_options._subject_mapper 

1705 pk_keys = [prop.key for prop in mapper._identity_key_props] 

1706 

1707 identity_map = session.identity_map 

1708 

1709 for param in params: 

1710 identity_key = mapper.identity_key_from_primary_key( 

1711 (param[key] for key in pk_keys), 

1712 update_options._identity_token, 

1713 ) 

1714 state = identity_map.fast_get_state(identity_key) 

1715 if not state: 

1716 continue 

1717 

1718 evaluated_keys = set(param).difference(pk_keys) 

1719 

1720 dict_ = state.dict 

1721 # only evaluate unmodified attributes 

1722 to_evaluate = state.unmodified.intersection(evaluated_keys) 

1723 for key in to_evaluate: 

1724 if key in dict_: 

1725 dict_[key] = param[key] 

1726 

1727 state.manager.dispatch.refresh(state, None, to_evaluate) 

1728 

1729 state._commit(dict_, list(to_evaluate)) 

1730 

1731 # attributes that were formerly modified instead get expired. 

1732 # this only gets hit if the session had pending changes 

1733 # and autoflush were set to False. 

1734 to_expire = evaluated_keys.intersection(dict_).difference( 

1735 to_evaluate 

1736 ) 

1737 if to_expire: 

1738 state._expire_attributes(dict_, to_expire) 

1739 

1740 @classmethod 

1741 def _do_post_synchronize_evaluate( 

1742 cls, session, statement, result, update_options 

1743 ): 

1744 matched_objects = cls._get_matched_objects_on_criteria( 

1745 update_options, 

1746 session.identity_map.all_states(), 

1747 ) 

1748 

1749 cls._apply_update_set_values_to_objects( 

1750 session, 

1751 update_options, 

1752 statement, 

1753 [(obj, state, dict_) for obj, state, dict_, _ in matched_objects], 

1754 ) 

1755 

1756 @classmethod 

1757 def _do_post_synchronize_fetch( 

1758 cls, session, statement, result, update_options 

1759 ): 

1760 target_mapper = update_options._subject_mapper 

1761 

1762 returned_defaults_rows = result.returned_defaults_rows 

1763 if returned_defaults_rows: 

1764 pk_rows = cls._interpret_returning_rows( 

1765 target_mapper, returned_defaults_rows 

1766 ) 

1767 

1768 matched_rows = [ 

1769 tuple(row) + (update_options._identity_token,) 

1770 for row in pk_rows 

1771 ] 

1772 else: 

1773 matched_rows = update_options._matched_rows 

1774 

1775 objs = [ 

1776 session.identity_map[identity_key] 

1777 for identity_key in [ 

1778 target_mapper.identity_key_from_primary_key( 

1779 list(primary_key), 

1780 identity_token=identity_token, 

1781 ) 

1782 for primary_key, identity_token in [ 

1783 (row[0:-1], row[-1]) for row in matched_rows 

1784 ] 

1785 if update_options._identity_token is None 

1786 or identity_token == update_options._identity_token 

1787 ] 

1788 if identity_key in session.identity_map 

1789 ] 

1790 

1791 if not objs: 

1792 return 

1793 

1794 cls._apply_update_set_values_to_objects( 

1795 session, 

1796 update_options, 

1797 statement, 

1798 [ 

1799 ( 

1800 obj, 

1801 attributes.instance_state(obj), 

1802 attributes.instance_dict(obj), 

1803 ) 

1804 for obj in objs 

1805 ], 

1806 ) 

1807 

1808 @classmethod 

1809 def _apply_update_set_values_to_objects( 

1810 cls, session, update_options, statement, matched_objects 

1811 ): 

1812 """apply values to objects derived from an update statement, e.g. 

1813 UPDATE..SET <values> 

1814 

1815 """ 

1816 mapper = update_options._subject_mapper 

1817 target_cls = mapper.class_ 

1818 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls) 

1819 resolved_values = cls._get_resolved_values(mapper, statement) 

1820 resolved_keys_as_propnames = cls._resolved_keys_as_propnames( 

1821 mapper, resolved_values 

1822 ) 

1823 value_evaluators = {} 

1824 for key, value in resolved_keys_as_propnames: 

1825 try: 

1826 _evaluator = evaluator_compiler.process( 

1827 coercions.expect(roles.ExpressionElementRole, value) 

1828 ) 

1829 except evaluator.UnevaluatableError: 

1830 pass 

1831 else: 

1832 value_evaluators[key] = _evaluator 

1833 

1834 evaluated_keys = list(value_evaluators.keys()) 

1835 attrib = {k for k, v in resolved_keys_as_propnames} 

1836 

1837 states = set() 

1838 for obj, state, dict_ in matched_objects: 

1839 to_evaluate = state.unmodified.intersection(evaluated_keys) 

1840 

1841 for key in to_evaluate: 

1842 if key in dict_: 

1843 # only run eval for attributes that are present. 

1844 dict_[key] = value_evaluators[key](obj) 

1845 

1846 state.manager.dispatch.refresh(state, None, to_evaluate) 

1847 

1848 state._commit(dict_, list(to_evaluate)) 

1849 

1850 # attributes that were formerly modified instead get expired. 

1851 # this only gets hit if the session had pending changes 

1852 # and autoflush were set to False. 

1853 to_expire = attrib.intersection(dict_).difference(to_evaluate) 

1854 if to_expire: 

1855 state._expire_attributes(dict_, to_expire) 

1856 

1857 states.add(state) 

1858 session._register_altered(states) 

1859 

1860 

1861@CompileState.plugin_for("orm", "delete") 

1862class BulkORMDelete(BulkUDCompileState, DeleteDMLState): 

1863 @classmethod 

1864 def create_for_statement(cls, statement, compiler, **kw): 

1865 self = cls.__new__(cls) 

1866 

1867 dml_strategy = statement._annotations.get( 

1868 "dml_strategy", "unspecified" 

1869 ) 

1870 

1871 if ( 

1872 dml_strategy == "core_only" 

1873 or dml_strategy == "unspecified" 

1874 and "parententity" not in statement.table._annotations 

1875 ): 

1876 DeleteDMLState.__init__(self, statement, compiler, **kw) 

1877 return self 

1878 

1879 toplevel = not compiler.stack 

1880 

1881 orm_level_statement = statement 

1882 

1883 ext_info = statement.table._annotations["parententity"] 

1884 self.mapper = mapper = ext_info.mapper 

1885 

1886 self._init_global_attributes( 

1887 statement, 

1888 compiler, 

1889 toplevel=toplevel, 

1890 process_criteria_for_toplevel=toplevel, 

1891 ) 

1892 

1893 new_stmt = statement._clone() 

1894 

1895 if new_stmt.table._annotations["parententity"] is mapper: 

1896 new_stmt.table = mapper.local_table 

1897 

1898 new_crit = cls._adjust_for_extra_criteria( 

1899 self.global_attributes, mapper 

1900 ) 

1901 if new_crit: 

1902 new_stmt = new_stmt.where(*new_crit) 

1903 

1904 # do this first as we need to determine if there is 

1905 # DELETE..FROM 

1906 DeleteDMLState.__init__(self, new_stmt, compiler, **kw) 

1907 

1908 use_supplemental_cols = False 

1909 

1910 if not toplevel: 

1911 synchronize_session = None 

1912 else: 

1913 synchronize_session = compiler._annotations.get( 

1914 "synchronize_session", None 

1915 ) 

1916 can_use_returning = compiler._annotations.get( 

1917 "can_use_returning", None 

1918 ) 

1919 if can_use_returning is not False: 

1920 # even though pre_exec has determined basic 

1921 # can_use_returning for the dialect, if we are to use 

1922 # RETURNING we need to run can_use_returning() at this level 

1923 # unconditionally because is_delete_using was not known 

1924 # at the pre_exec level 

1925 can_use_returning = ( 

1926 synchronize_session == "fetch" 

1927 and self.can_use_returning( 

1928 compiler.dialect, 

1929 mapper, 

1930 is_multitable=self.is_multitable, 

1931 is_delete_using=compiler._annotations.get( 

1932 "is_delete_using", False 

1933 ), 

1934 ) 

1935 ) 

1936 

1937 if can_use_returning: 

1938 use_supplemental_cols = True 

1939 

1940 new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key) 

1941 

1942 if toplevel: 

1943 new_stmt = self._setup_orm_returning( 

1944 compiler, 

1945 orm_level_statement, 

1946 new_stmt, 

1947 dml_mapper=mapper, 

1948 use_supplemental_cols=use_supplemental_cols, 

1949 ) 

1950 

1951 self.statement = new_stmt 

1952 

1953 return self 

1954 

1955 @classmethod 

1956 def orm_execute_statement( 

1957 cls, 

1958 session: Session, 

1959 statement: dml.Delete, 

1960 params: _CoreAnyExecuteParams, 

1961 execution_options: OrmExecuteOptionsParameter, 

1962 bind_arguments: _BindArguments, 

1963 conn: Connection, 

1964 ) -> _result.Result: 

1965 update_options = execution_options.get( 

1966 "_sa_orm_update_options", cls.default_update_options 

1967 ) 

1968 

1969 if update_options._dml_strategy == "bulk": 

1970 raise sa_exc.InvalidRequestError( 

1971 "Bulk ORM DELETE not supported right now. " 

1972 "Statement may be invoked at the " 

1973 "Core level using " 

1974 "session.connection().execute(stmt, parameters)" 

1975 ) 

1976 

1977 if update_options._dml_strategy not in ("orm", "auto", "core_only"): 

1978 raise sa_exc.ArgumentError( 

1979 "Valid strategies for ORM DELETE strategy are 'orm', 'auto', " 

1980 "'core_only'" 

1981 ) 

1982 

1983 return super().orm_execute_statement( 

1984 session, statement, params, execution_options, bind_arguments, conn 

1985 ) 

1986 

1987 @classmethod 

1988 def can_use_returning( 

1989 cls, 

1990 dialect: Dialect, 

1991 mapper: Mapper[Any], 

1992 *, 

1993 is_multitable: bool = False, 

1994 is_update_from: bool = False, 

1995 is_delete_using: bool = False, 

1996 is_executemany: bool = False, 

1997 ) -> bool: 

1998 # normal answer for "should we use RETURNING" at all. 

1999 normal_answer = ( 

2000 dialect.delete_returning and mapper.local_table.implicit_returning 

2001 ) 

2002 if not normal_answer: 

2003 return False 

2004 

2005 # now get into special workarounds because MariaDB supports 

2006 # DELETE...RETURNING but not DELETE...USING...RETURNING. 

2007 if is_delete_using: 

2008 # is_delete_using hint was passed. use 

2009 # additional dialect feature (True for PG, False for MariaDB) 

2010 return dialect.delete_returning_multifrom 

2011 

2012 elif is_multitable and not dialect.delete_returning_multifrom: 

2013 # is_delete_using hint was not passed, but we determined 

2014 # at compile time that this is in fact a DELETE..USING. 

2015 # it's too late to continue since we did not pre-SELECT. 

2016 # raise that we need that hint up front. 

2017 

2018 raise sa_exc.CompileError( 

2019 f'Dialect "{dialect.name}" does not support RETURNING ' 

2020 "with DELETE..USING; for synchronize_session='fetch', " 

2021 "please add the additional execution option " 

2022 "'is_delete_using=True' to the statement to indicate that " 

2023 "a separate SELECT should be used for this backend." 

2024 ) 

2025 

2026 return True 

2027 

2028 @classmethod 

2029 def _do_post_synchronize_evaluate( 

2030 cls, session, statement, result, update_options 

2031 ): 

2032 matched_objects = cls._get_matched_objects_on_criteria( 

2033 update_options, 

2034 session.identity_map.all_states(), 

2035 ) 

2036 

2037 to_delete = [] 

2038 

2039 for _, state, dict_, is_partially_expired in matched_objects: 

2040 if is_partially_expired: 

2041 state._expire(dict_, session.identity_map._modified) 

2042 else: 

2043 to_delete.append(state) 

2044 

2045 if to_delete: 

2046 session._remove_newly_deleted(to_delete) 

2047 

2048 @classmethod 

2049 def _do_post_synchronize_fetch( 

2050 cls, session, statement, result, update_options 

2051 ): 

2052 target_mapper = update_options._subject_mapper 

2053 

2054 returned_defaults_rows = result.returned_defaults_rows 

2055 

2056 if returned_defaults_rows: 

2057 pk_rows = cls._interpret_returning_rows( 

2058 target_mapper, returned_defaults_rows 

2059 ) 

2060 

2061 matched_rows = [ 

2062 tuple(row) + (update_options._identity_token,) 

2063 for row in pk_rows 

2064 ] 

2065 else: 

2066 matched_rows = update_options._matched_rows 

2067 

2068 for row in matched_rows: 

2069 primary_key = row[0:-1] 

2070 identity_token = row[-1] 

2071 

2072 # TODO: inline this and call remove_newly_deleted 

2073 # once 

2074 identity_key = target_mapper.identity_key_from_primary_key( 

2075 list(primary_key), 

2076 identity_token=identity_token, 

2077 ) 

2078 if identity_key in session.identity_map: 

2079 session._remove_newly_deleted( 

2080 [ 

2081 attributes.instance_state( 

2082 session.identity_map[identity_key] 

2083 ) 

2084 ] 

2085 )