Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy_utils/functions/orm.py: 21%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

252 statements  

1from collections import OrderedDict 

2from functools import partial 

3from inspect import isclass 

4from operator import attrgetter 

5 

6import sqlalchemy as sa 

7from sqlalchemy.engine.interfaces import Dialect 

8from sqlalchemy.ext.hybrid import hybrid_property 

9from sqlalchemy.orm import ColumnProperty, mapperlib, RelationshipProperty 

10from sqlalchemy.orm.attributes import InstrumentedAttribute 

11from sqlalchemy.orm.exc import UnmappedInstanceError 

12 

13try: 

14 from sqlalchemy.orm.context import _ColumnEntity, _MapperEntity 

15except ImportError: # SQLAlchemy <1.4 

16 from sqlalchemy.orm.query import _ColumnEntity, _MapperEntity 

17 

18from sqlalchemy.orm.session import object_session 

19from sqlalchemy.orm.util import AliasedInsp 

20 

21from ..utils import is_sequence 

22 

23 

def get_class_by_table(base, table, data=None):
    """
    Return declarative class associated with given table. If no class is found
    this function returns `None`. If multiple classes were found (polymorphic
    cases) additional `data` parameter can be given to hint which class
    to return.

    ::

        class User(Base):
            __tablename__ = 'entity'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)


        get_class_by_table(Base, User.__table__)  # User class


    This function also supports models using single table inheritance.
    An additional data parameter should be provided in these cases.

    ::

        class Entity(Base):
            __tablename__ = 'entity'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)
            type = sa.Column(sa.String)
            __mapper_args__ = {
                'polymorphic_on': type,
                'polymorphic_identity': 'entity'
            }

        class User(Entity):
            __mapper_args__ = {
                'polymorphic_identity': 'user'
            }


        # Entity class
        get_class_by_table(Base, Entity.__table__, {'type': 'entity'})

        # User class
        get_class_by_table(Base, Entity.__table__, {'type': 'user'})


    :param base: Declarative model base
    :param table: SQLAlchemy Table object
    :param data:
        Data row (mapping of column name to value) used to determine the
        class in polymorphic scenarios
    :raises ValueError:
        if several classes are mapped to ``table`` and ``data`` is missing
        or matches none of their polymorphic identities
    :return: Declarative class or None.
    """
    found_classes = {
        c
        for c in _get_class_registry(base).values()
        if hasattr(c, '__table__') and c.__table__ is table
    }
    if len(found_classes) > 1:
        if not data:
            raise ValueError(
                "Multiple declarative classes found for table '{}'. "
                'Please provide data parameter for this function to be able '
                'to determine polymorphic scenarios.'.format(table.name)
            )
        for cls in found_classes:
            mapper = sa.inspect(cls)
            # A class mapped to this table without a discriminator column
            # cannot be selected through ``data``; skip it rather than
            # failing with AttributeError on ``None.name``.
            if mapper.polymorphic_on is None:
                continue
            polymorphic_on = mapper.polymorphic_on.name
            if (
                polymorphic_on in data
                and data[polymorphic_on] == mapper.polymorphic_identity
            ):
                return cls
        raise ValueError(
            "Multiple declarative classes found for table '{}'. Given "
            'data row does not match any polymorphic identity of the '
            'found classes.'.format(table.name)
        )
    elif found_classes:
        return found_classes.pop()
    return None

102 

103 

def get_type(expr):
    """
    Return the type associated with given Column, InstrumentedAttribute,
    ColumnProperty, RelationshipProperty or a similar SQLAlchemy construct.

    Anything exposing a ``type`` attribute returns it directly. An
    InstrumentedAttribute is first unwrapped to its property. A column
    property yields the type of its first column, and a relationship
    property yields the class of the related mapper.

    :param expr:
        SQLAlchemy Column, InstrumentedAttribute, ColumnProperty or other
        similar SA construct.
    :raises TypeError: if no type can be determined for ``expr``.

    ::

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)


        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
            author = sa.orm.relationship(User)


        get_type(User.__table__.c.name)  # sa.String()
        get_type(User.name)  # sa.String()
        get_type(User.name.property)  # sa.String()

        get_type(Article.author)  # User


    .. versionadded: 0.30.9
    """
    if hasattr(expr, 'type'):
        return expr.type

    prop = expr.property if isinstance(expr, InstrumentedAttribute) else expr

    if isinstance(prop, ColumnProperty):
        first_column = prop.columns[0]
        return first_column.type
    if isinstance(prop, RelationshipProperty):
        return prop.mapper.class_
    raise TypeError("Couldn't inspect type.")

150 

151 

def cast_if(expression, type_):
    """
    Produce a CAST expression but only if given expression is not of given type
    already.

    Assume we have a model with two fields id (Integer) and name (String).

    ::

        import sqlalchemy as sa
        from sqlalchemy_utils import cast_if


        cast_if(User.id, sa.Integer)  # "user".id
        cast_if(User.name, sa.String)  # "user".name
        cast_if(User.id, sa.String)  # CAST("user".id AS TEXT)


    This function supports scalar values as well.

    ::

        cast_if(1, sa.Integer)  # 1
        cast_if('text', sa.String)  # 'text'
        cast_if(1, sa.String)  # CAST(1 AS TEXT)


    :param expression:
        A SQL expression, such as a ColumnElement expression or a Python string
        which will be coerced into a bound literal value.
    :param type_:
        A TypeEngine class or instance indicating the type to which the CAST
        should apply. For the "already of this type" check an instance is
        compared by its class (e.g. ``sa.String(50)`` matches any String).

    .. versionadded: 0.30.14
    """
    # Accept both ``sa.String`` and ``sa.String(50)``: isinstance() needs a
    # class, and only a class can be instantiated to reach ``python_type``.
    type_class = type_ if isclass(type_) else type(type_)
    try:
        expr_type = get_type(expression)
    except TypeError:
        # Plain Python value: compare it with the Python type that the
        # SQL type maps to.
        expr_type = expression
        type_instance = type_() if isclass(type_) else type_
        check_type = type_instance.python_type
    else:
        check_type = type_class

    return (
        sa.cast(expression, type_)
        if not isinstance(expr_type, check_type)
        else expression
    )

201 

202 

def get_column_key(model, column):
    """
    Return the attribute key under which ``column`` is mapped on ``model``.

    The mapper is first asked for the property mapped to ``column``. If that
    fails, the mapper's columns are searched for one with the same name
    belonging to the same table object.

    :param model: SQLAlchemy declarative model object

    ::

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column('_name', sa.String)


        get_column_key(User, User.__table__.c._name)  # 'name'

    .. versionadded: 0.26.5

    .. versionchanged: 0.27.11
        Throws UnmappedColumnError instead of ValueError when no property was
        found for given column. This is consistent with how SQLAlchemy works.
    """
    mapper = sa.inspect(model)
    try:
        return mapper.get_property_by_column(column).key
    except sa.orm.exc.UnmappedColumnError:
        # Fall back to matching mapped columns by name and owning table.
        for attr_key, mapped in mapper.columns.items():
            same_name = mapped.name == column.name
            if same_name and mapped.table is column.table:
                return attr_key
        raise sa.orm.exc.UnmappedColumnError(
            f'No column {column} is configured on mapper {mapper}...'
        )

235 

236 

def get_mapper(mixed):
    """
    Return the SQLAlchemy Mapper related to given SQLAlchemy object.

    :param mixed: SQLAlchemy Table / Alias / Mapper / declarative model object

    ::

        from sqlalchemy_utils import get_mapper


        get_mapper(User)

        get_mapper(User())

        get_mapper(User.__table__)

        get_mapper(User.__mapper__)

        get_mapper(sa.orm.aliased(User))

        get_mapper(sa.orm.aliased(User.__table__))


    Raises:
        ValueError: if multiple mappers were found for given argument, or
        if given a table no mapper could be found for

    .. versionadded: 0.26.1
    """
    # Unwrap query entities and columns to the object they refer to.
    if isinstance(mixed, (_MapperEntity, _ColumnEntity)):
        mixed = mixed.expr
    elif isinstance(mixed, sa.Column):
        mixed = mixed.table

    if isinstance(mixed, sa.orm.Mapper):
        return mixed
    if isinstance(mixed, sa.orm.util.AliasedClass):
        return sa.inspect(mixed).mapper
    if isinstance(mixed, sa.sql.selectable.Alias):
        mixed = mixed.element
    if isinstance(mixed, AliasedInsp):
        return mixed.mapper
    if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        mixed = mixed.class_

    if not isinstance(mixed, sa.Table):
        target = mixed if isclass(mixed) else type(mixed)
        return sa.inspect(target)

    # A Table: scan every known mapper for the ones that map it.
    if hasattr(mapperlib, '_all_registries'):
        candidates = set()
        for registry in mapperlib._all_registries():
            candidates.update(registry.mappers)
    else:  # SQLAlchemy <1.4
        candidates = mapperlib._mapper_registry

    matching = [m for m in candidates if mixed in m.tables]
    if len(matching) > 1:
        raise ValueError("Multiple mappers found for table '%s'." % mixed.name)
    if not matching:
        raise ValueError("Could not get mapper for table '%s'." % mixed.name)
    return matching[0]

300 

301 

def get_bind(obj):
    """
    Return the bind for given SQLAlchemy Session / Engine / Connection /
    declarative model object.

    Objects exposing a ``bind`` attribute (such as a session) return that
    bind. Mapped instances return the bind of the session they belong to.
    Any other object is returned as-is, provided it has an ``execute``
    method.

    :param obj:
        SQLAlchemy Session / Engine / Connection / declarative model object

    ::

        from sqlalchemy_utils import get_bind


        get_bind(session)  # Connection object

        get_bind(user)

    :raises TypeError:
        if the resolved bind has no ``execute`` method, or if a model
        object is not attached to any session.
    """
    if hasattr(obj, 'bind'):
        conn = obj.bind
    else:
        try:
            session = object_session(obj)
        except UnmappedInstanceError:
            # Not a mapped instance: assume it is an Engine / Connection.
            conn = obj
        else:
            if session is None:
                # Without a session there is nothing to take the bind from;
                # previously this surfaced as AttributeError on ``None.bind``.
                raise TypeError(
                    'Given declarative model object is not attached to a '
                    'session, so its bind cannot be determined.'
                )
            conn = session.bind

    if not hasattr(conn, 'execute'):
        raise TypeError(
            'This method accepts only Session, Engine, Connection and '
            'declarative model objects.'
        )
    return conn

333 

334 

def get_primary_keys(mixed):
    """
    Return an OrderedDict of the primary key columns of given Table object,
    declarative class or declarative class instance.

    Keys and columns come from :func:`get_columns`, filtered to those whose
    ``primary_key`` flag is set, in the same order.

    :param mixed:
        SA Table object, SA declarative class or SA declarative class instance

    ::

        get_primary_keys(User)

        get_primary_keys(User())

        get_primary_keys(User.__table__)

        get_primary_keys(User.__mapper__)

        get_primary_keys(sa.orm.aliased(User))

        get_primary_keys(sa.orm.aliased(User.__table__))


    .. versionchanged: 0.25.3
        Made the function return an ordered dictionary instead of generator.
        This change was made to support primary key aliases.

        Renamed this function to 'get_primary_keys', formerly 'primary_keys'

    .. seealso:: :func:`get_columns`
    """
    all_columns = get_columns(mixed)
    pk_pairs = (
        (name, col) for name, col in all_columns.items() if col.primary_key
    )
    return OrderedDict(pk_pairs)

373 

374 

def get_tables(mixed):
    """
    Return a list of tables associated with given SQLAlchemy object.

    Let's say we have three classes which use joined table inheritance
    TextItem, Article and BlogPost. Article and BlogPost inherit TextItem.

    ::

        get_tables(Article)

        get_tables(Article())

        get_tables(Article.__mapper__)


    When the resolved mapper has a non-empty polymorphic map, the tables of
    every mapper in that map are returned, each table listed only once.

    ::

        get_tables(TextItem)  # text_item, article and blog_post tables


    .. versionadded: 0.26.0

    :param mixed:
        SQLAlchemy Mapper, Declarative class, Column, InstrumentedAttribute or
        a SA Alias object wrapping any of these objects.
    """
    if isinstance(mixed, sa.Table):
        return [mixed]
    elif isinstance(mixed, sa.Column):
        return [mixed.table]
    elif isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        return mixed.parent.tables
    elif isinstance(mixed, _ColumnEntity):
        mixed = mixed.expr

    mapper = get_mapper(mixed)

    polymorphic_mappers = get_polymorphic_mappers(mapper)
    if polymorphic_mappers:
        # Several mappers of one hierarchy may list the same table (e.g. a
        # shared base table); keep first-seen order and drop repeats
        # instead of concatenating the lists with duplicates.
        tables = list(
            dict.fromkeys(
                table for m in polymorphic_mappers for table in m.tables
            )
        )
    else:
        tables = mapper.tables
    return tables

423 

424 

def get_columns(mixed):
    """
    Return a collection of all Column objects for given SQLAlchemy object.

    The kind of collection returned depends on the input: selectables give
    their (selected) column collection, mappers and column properties give
    their ``columns``, a single Column is wrapped in a list, and declarative
    classes or instances give the columns of their mapper.

    ::

        get_columns(User)

        get_columns(User())

        get_columns(User.__table__)

        get_columns(User.__mapper__)

        get_columns(sa.orm.aliased(User))

        get_columns(sa.orm.aliased(User.__table__))


    :param mixed:
        SA Table object, SA Mapper, SA declarative class, SA declarative class
        instance or an alias of any of these objects
    """
    if isinstance(mixed, sa.sql.selectable.Selectable):
        try:
            return mixed.selected_columns
        except AttributeError:  # SQLAlchemy <1.4
            return mixed.c
    if isinstance(mixed, sa.orm.util.AliasedClass):
        return sa.inspect(mixed).mapper.columns
    if isinstance(mixed, InstrumentedAttribute):
        return mixed.property.columns
    if isinstance(mixed, (sa.orm.Mapper, ColumnProperty)):
        return mixed.columns
    if isinstance(mixed, sa.Column):
        return [mixed]
    model_class = mixed if isclass(mixed) else mixed.__class__
    return sa.inspect(model_class).columns

470 

471 

def table_name(obj):
    """
    Return the table name of given target, declarative class or the
    table name where the declarative attribute is bound to.

    If ``obj`` has a ``class_`` attribute, that class is used instead. The
    name is taken from ``__tablename__`` when present, otherwise from
    ``__table__.name``; ``None`` is returned if neither is available.
    """
    target = getattr(obj, 'class_', obj)
    lookups = (
        lambda: target.__tablename__,
        lambda: target.__table__.name,
    )
    for lookup in lookups:
        try:
            return lookup()
        except AttributeError:
            continue
    return None

488 

489 

def getattrs(obj, attrs):
    """
    Return a lazy ``map`` iterator yielding ``getattr(obj, name)`` for each
    name in ``attrs``; attributes are read only as the iterator is consumed.
    """
    return map(lambda name: getattr(obj, name), attrs)

492 

493 

def quote(mixed, ident):
    """
    Conditionally quote an identifier using the identifier preparer of the
    dialect belonging to ``mixed``.

    ::


        from sqlalchemy_utils import quote


        engine = create_engine('sqlite:///:memory:')

        quote(engine, 'order')
        # '"order"'

        quote(engine, 'some_other_identifier')
        # 'some_other_identifier'


    :param mixed: SQLAlchemy Session / Connection / Engine / Dialect object.
    :param ident: identifier to conditionally quote
    """
    dialect = mixed if isinstance(mixed, Dialect) else get_bind(mixed).dialect
    preparer = dialect.preparer(dialect)
    return preparer.quote(ident)

520 

521 

522def _get_query_compile_state(query): 

523 if hasattr(query, '_compile_state'): 

524 return query._compile_state() 

525 else: # SQLAlchemy <1.4 

526 return query 

527 

528 

def get_polymorphic_mappers(mixed):
    """
    Return the polymorphic mappers of given mapper or aliased entity.

    Aliased inspections (``AliasedInsp``) give their
    ``with_polymorphic_mappers``; any other object gives the values of its
    ``polymorphic_map``.
    """
    if not isinstance(mixed, AliasedInsp):
        return mixed.polymorphic_map.values()
    return mixed.with_polymorphic_mappers

534 

535 

def get_descriptor(entity, attr):
    """
    Return the descriptor named ``attr`` for given entity, or ``None``.

    Column properties resolve to the matching column of the aliased
    selectable for aliased entities. Otherwise they resolve to the attribute
    of the class owning the property, so attributes defined only on a
    polymorphic child class are found. Synonyms, relationships and hybrid
    properties resolve to the attribute on the aliased entity, or on the
    mapped class.

    :param entity: mapped class, aliased class or other inspectable entity
    :param attr: attribute key to look up
    """
    insp = sa.inspect(entity)
    is_alias = isinstance(entity, sa.orm.util.AliasedClass)

    for key, descriptor in get_all_descriptors(insp).items():
        if key != attr:
            continue

        prop = getattr(descriptor, 'property', None)

        if not isinstance(prop, ColumnProperty):
            # Synonyms, relationship properties and hybrid properties.
            if is_alias:
                return getattr(entity, attr)
            try:
                return getattr(insp.class_, attr)
            except AttributeError:
                continue

        if not is_alias:
            # The attribute may exist only on a polymorphic child class,
            # so look it up on the class that owns the property.
            return getattr(prop.parent.class_, attr)

        for column in insp.selectable.c:
            if column.key == attr:
                return column
    return None

563 

564 

def get_all_descriptors(expr):
    """
    Return a mapping of attribute key to ORM descriptor for ``expr``.

    Selectables return their column collection (``expr.c``). For mapped
    entities, the descriptors of the entity's mapper are merged with those
    of its polymorphic mappers. Keys already defined on the entity's own
    mapper take precedence. Objects that ``sa.inspect`` cannot inspect fall
    back to the descriptors of the mapper found by :func:`get_mapper`.

    :param expr:
        selectable, mapped class / mapper / aliased entity, or any object
        :func:`get_mapper` can resolve
    """
    if isinstance(expr, sa.sql.selectable.Selectable):
        return expr.c
    try:
        # ``sa.inspect`` is the call that raises NoInspectionAvailable, so it
        # must be inside the ``try`` for the fallback below to be reachable.
        insp = sa.inspect(expr)
        polymorphic_mappers = get_polymorphic_mappers(insp)
    except sa.exc.NoInspectionAvailable:
        return get_mapper(expr).all_orm_descriptors

    attrs = dict(get_mapper(expr).all_orm_descriptors)
    for submapper in polymorphic_mappers:
        for key, descriptor in submapper.all_orm_descriptors.items():
            attrs.setdefault(key, descriptor)
    return attrs

580 

581 

def get_hybrid_properties(model):
    """
    Return a dictionary of hybrid property keys and hybrid properties for
    given SQLAlchemy declarative model / mapper.


    Consider the following model

    ::


        from sqlalchemy.ext.hybrid import hybrid_property


        class Category(Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            @hybrid_property
            def lowercase_name(self):
                return self.name.lower()

            @lowercase_name.expression
            def lowercase_name(cls):
                return sa.func.lower(cls.name)


    You can now easily get a list of all hybrid property names

    ::


        from sqlalchemy_utils import get_hybrid_properties


        get_hybrid_properties(Category).keys()  # ['lowercase_name']


    This function also supports aliased classes

    ::


        get_hybrid_properties(
            sa.orm.aliased(Category)
        ).keys()  # ['lowercase_name']


    .. versionchanged: 0.26.7
        This function now returns a dictionary instead of generator

    .. versionchanged: 0.30.15
        Added support for aliased classes

    :param model: SQLAlchemy declarative model or mapper
    """
    hybrids = {}
    for name, descriptor in get_mapper(model).all_orm_descriptors.items():
        if isinstance(descriptor, hybrid_property):
            hybrids[name] = descriptor
    return hybrids

644 

645 

def get_declarative_base(model):
    """
    Return the declarative base for given model class.

    Starting from ``model``, repeatedly moves to the first direct base class
    that has a ``metadata`` attribute. The last class reached, whose bases
    have no ``metadata``, is returned.

    :param model: SQLAlchemy declarative model
    """
    current = model
    while True:
        parent = next(
            (base for base in current.__bases__ if hasattr(base, 'metadata')),
            None,
        )
        if parent is None:
            return current
        current = parent

659 

660 

def getdotattr(obj_or_class, dot_path, condition=None):
    """
    Allow dot-notated strings to be passed to `getattr`.

    Each path segment is applied in turn. Sequences are flattened: the
    segment is applied to every element, and sequence results are extended
    into one list. InstrumentedAttributes are followed to their related
    mapped class. A ``None`` value stops traversal and returns ``None``.
    When ``condition`` is given, sequence values are filtered by it after
    each step, and a scalar failing it returns ``None``.

    ::

        getdotattr(SubSection, 'section.document')

        getdotattr(subsection, 'section.document')


    :param obj_or_class: Any object or class
    :param dot_path: Attribute path with dot mark as separator
    :param condition: Optional predicate applied after each step
    """
    current = obj_or_class

    for segment in str(dot_path).split('.'):
        fetch = attrgetter(segment)

        if is_sequence(current):
            collected = []
            for item in current:
                result = fetch(item)
                if is_sequence(result):
                    collected.extend(result)
                else:
                    collected.append(result)
            current = collected
        elif isinstance(current, InstrumentedAttribute):
            current = fetch(current.property.mapper.class_)
        elif current is None:
            return None
        else:
            current = fetch(current)

        if condition is None:
            continue
        if is_sequence(current):
            current = [value for value in current if condition(value)]
        elif not condition(current):
            return None

    return current

703 

704 

def is_deleted(obj):
    """
    Return whether ``obj`` is in the ``deleted`` collection of its session.

    An object that is not attached to any session cannot be in a session's
    ``deleted`` collection, so ``False`` is returned for it. Previously this
    failed with AttributeError on ``None.deleted``.

    :param obj: SQLAlchemy declarative model object
    """
    session = sa.orm.object_session(obj)
    if session is None:
        return False
    return obj in session.deleted

707 

708 

def has_changes(obj, attrs=None, exclude=None):
    """
    Simple shortcut function for checking if given attributes of given
    declarative model object have changed during the session. Without
    parameters this checks if given object has any modifications. Additionally
    the exclude parameter can be given to check if given object has any
    changes in any attributes other than the ones given in exclude.


    ::


        from sqlalchemy_utils import has_changes


        user = User()

        has_changes(user, 'name')  # False

        user.name = 'someone'

        has_changes(user, 'name')  # True

        has_changes(user)  # True


    You can check multiple attributes as well.
    ::


        has_changes(user, ['age'])  # True

        has_changes(user, ['name', 'age'])  # True


    This function also supports excluding certain attributes.

    ::

        has_changes(user, exclude=['name'])  # False

        has_changes(user, exclude=['age'])  # True

    .. versionchanged: 0.26.6
        Added support for multiple attributes and exclude parameter.

    :param obj: SQLAlchemy declarative model object
    :param attrs: Name, or names, of the attributes to check
    :param exclude:
        Names of the attributes to exclude (only used when ``attrs`` is empty)
    """
    if not attrs:
        skipped = [] if exclude is None else exclude
        return any(
            state_attr.history.has_changes()
            for name, state_attr in sa.inspect(obj).attrs.items()
            if name not in skipped
        )
    if isinstance(attrs, str):
        return sa.inspect(obj).attrs.get(attrs).history.has_changes()
    return any(has_changes(obj, name) for name in attrs)

772 

773 

def is_loaded(obj, prop):
    """
    Return whether or not given property of given object has been loaded.

    ::

        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)
            content = sa.orm.deferred(sa.Column(sa.String))


        article = session.query(Article).get(5)

        # name gets loaded since its not a deferred property
        assert is_loaded(article, 'name')

        # content has not yet been loaded since its a deferred property
        assert not is_loaded(article, 'content')

        # an InstrumentedAttribute works too
        assert not is_loaded(article, Article.content)


    .. versionadded: 0.27.8

    :param obj: SQLAlchemy declarative model object
    :param prop: Name of the property or InstrumentedAttribute
    """
    # ``InstanceState.unloaded`` contains attribute keys, so an attribute
    # object must be reduced to its key; otherwise it is never found and
    # the function would always report the property as loaded.
    if isinstance(prop, InstrumentedAttribute):
        prop = prop.key
    return prop not in sa.inspect(obj).unloaded

802 

803 

def identity(obj_or_class):
    """
    Return the identity of given sqlalchemy declarative model class or instance
    as a tuple of its primary key attribute values, in primary key order.

    Unlike ``inspect(obj).identity``, this always returns a tuple, even for an
    object that has not been persisted yet; its unset primary key values
    appear as ``None``. For classes the primary key attributes themselves are
    returned.

    ::

        from sqlalchemy import inspect
        from sqlalchemy_utils import identity


        user = User(name='John Matrix')
        session.add(user)
        identity(user)  # (None, )
        inspect(user).identity  # None

        session.flush()  # primary key value is now populated

        identity(user)  # (1, )

        session.commit()

        identity(user)  # (1, )
        inspect(user).identity  # (1, )


    You can also use identity for classes::


        identity(User)  # (User.id, )

    .. versionadded: 0.21.0

    :param obj_or_class: SQLAlchemy declarative model object or class
    """
    pk_columns = get_primary_keys(obj_or_class)
    values = [getattr(obj_or_class, key) for key in pk_columns]
    return tuple(values)

847 

848 

def naturally_equivalent(obj, obj2):
    """
    Return whether two given SQLAlchemy declarative instances are naturally
    equivalent: every mapped column of ``obj``'s class that is not part of
    the primary key compares equal between the two objects.


    ::

        from sqlalchemy_utils import naturally_equivalent


        user = User(name='someone')
        user2 = User(name='someone')

        user == user2  # False

        naturally_equivalent(user, user2)  # True


    :param obj: SQLAlchemy declarative model object
    :param obj2: SQLAlchemy declarative model object to compare with `obj`
    """
    mapped_columns = sa.inspect(obj.__class__).columns.items()
    return all(
        getattr(obj, key) == getattr(obj2, key)
        for key, col in mapped_columns
        if not col.primary_key
    )

878 

879 

880def _get_class_registry(class_): 

881 try: 

882 return class_.registry._class_registry 

883 except AttributeError: # SQLAlchemy <1.4 

884 return class_._decl_class_registry