1from collections import OrderedDict
2from functools import partial
3from inspect import isclass
4from operator import attrgetter
5
6import sqlalchemy as sa
7from sqlalchemy.engine.interfaces import Dialect
8from sqlalchemy.ext.hybrid import hybrid_property
9from sqlalchemy.orm import ColumnProperty, mapperlib, RelationshipProperty
10from sqlalchemy.orm.attributes import InstrumentedAttribute
11from sqlalchemy.orm.exc import UnmappedInstanceError
12
13try:
14 from sqlalchemy.orm.context import _ColumnEntity, _MapperEntity
15except ImportError: # SQLAlchemy <1.4
16 from sqlalchemy.orm.query import _ColumnEntity, _MapperEntity
17
18from sqlalchemy.orm.session import object_session
19from sqlalchemy.orm.util import AliasedInsp
20
21from ..utils import is_sequence
22
23
def get_class_by_table(base, table, data=None):
    """
    Return declarative class associated with given table. If no class is found
    this function returns `None`. If multiple classes were found (polymorphic
    cases) additional `data` parameter can be given to hint which class
    to return.

    ::

        class User(Base):
            __tablename__ = 'entity'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)


        get_class_by_table(Base, User.__table__)  # User class


    This function also supports models using single table inheritance.
    Additional data parameter should be provided in this case.

    ::

        class Entity(Base):
            __tablename__ = 'entity'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)
            type = sa.Column(sa.String)
            __mapper_args__ = {
                'polymorphic_on': type,
                'polymorphic_identity': 'entity'
            }

        class User(Entity):
            __mapper_args__ = {
                'polymorphic_identity': 'user'
            }


        # Entity class
        get_class_by_table(Base, Entity.__table__, {'type': 'entity'})

        # User class
        get_class_by_table(Base, Entity.__table__, {'type': 'user'})


    :param base: Declarative model base
    :param table: SQLAlchemy Table object
    :param data: Data row to determine the class in polymorphic scenarios
    :return: Declarative class or None.
    :raises ValueError:
        If several classes map the table and `data` is empty, or `data`
        matches none of their polymorphic identities.
    """
    found_classes = {
        c
        for c in _get_class_registry(base).values()
        if hasattr(c, '__table__') and c.__table__ is table
    }
    if len(found_classes) > 1:
        if not data:
            raise ValueError(
                "Multiple declarative classes found for table '{}'. "
                'Please provide data parameter for this function to be able '
                'to determine polymorphic scenarios.'.format(table.name)
            )
        for cls in found_classes:
            mapper = sa.inspect(cls)
            # A class mapped to the same table without a discriminator column
            # cannot be selected via `data`; skip it rather than crashing on
            # `None.name`.
            if mapper.polymorphic_on is None:
                continue
            polymorphic_on = mapper.polymorphic_on.name
            if polymorphic_on in data:
                if data[polymorphic_on] == mapper.polymorphic_identity:
                    return cls
        raise ValueError(
            "Multiple declarative classes found for table '{}'. Given "
            'data row does not match any polymorphic identity of the '
            'found classes.'.format(table.name)
        )
    elif found_classes:
        return found_classes.pop()
    return None
102
103
def get_type(expr):
    """
    Return the type associated with a Column, InstrumentedAttribute,
    ColumnProperty, RelationshipProperty or similar SQLAlchemy construct.

    Anything exposing a ``type`` attribute returns it directly. An
    InstrumentedAttribute is unwrapped to its property first. Column
    properties yield the type of their first column, and relationship
    properties yield the mapped class of the related mapper.

    :param expr:
        SQLAlchemy Column, InstrumentedAttribute, ColumnProperty or other
        similar SA construct.
    :raises TypeError: if no type can be determined for ``expr``.

    ::

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)


        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
            author = sa.orm.relationship(User)


        get_type(User.__table__.c.name)  # sa.String()
        get_type(User.name)  # sa.String()
        get_type(User.name.property)  # sa.String()

        get_type(Article.author)  # User


    .. versionadded: 0.30.9
    """
    if hasattr(expr, 'type'):
        return expr.type

    # Attributes wrap the underlying property; inspect that instead.
    prop = expr.property if isinstance(expr, InstrumentedAttribute) else expr

    if isinstance(prop, ColumnProperty):
        return prop.columns[0].type
    if isinstance(prop, RelationshipProperty):
        return prop.mapper.class_
    raise TypeError("Couldn't inspect type.")
150
151
def cast_if(expression, type_):
    """
    Produce a CAST expression but only if given expression is not of given type
    already.

    Assume we have a model with two fields id (Integer) and name (String).

    ::

        import sqlalchemy as sa
        from sqlalchemy_utils import cast_if


        cast_if(User.id, sa.Integer)    # "user".id
        cast_if(User.name, sa.String)   # "user".name
        cast_if(User.id, sa.String)     # CAST("user".id AS TEXT)


    This function supports scalar values as well.

    ::

        cast_if(1, sa.Integer)          # 1
        cast_if('text', sa.String)      # 'text'
        cast_if(1, sa.String)           # CAST(1 AS TEXT)


    :param expression:
        A SQL expression, such as a ColumnElement expression or a Python string
        which will be coerced into a bound literal value.
    :param type_:
        A TypeEngine class or instance indicating the type to which the CAST
        should apply. Both forms are accepted; the comparison is done at the
        class level (e.g. ``sa.String(50)`` matches any ``sa.String`` column).

    .. versionadded: 0.30.14
    """
    try:
        expr_type = get_type(expression)
    except TypeError:
        # Not a SQL construct: compare the raw Python value against the
        # Python type that the SQL type maps to.
        expr_type = expression
        type_instance = type_() if isclass(type_) else type_
        check_type = type_instance.python_type
    else:
        # isinstance() needs a class; accept TypeEngine instances as well.
        check_type = type_ if isclass(type_) else type(type_)

    return (
        sa.cast(expression, type_)
        if not isinstance(expr_type, check_type)
        else expression
    )
201
202
def get_column_key(model, column):
    """
    Return the mapper attribute key under which ``column`` is mapped on
    ``model``.

    The mapper is first asked directly for the property owning the column.
    If that fails, the mapper's columns are scanned for one with the same
    name on the same table.

    :param model: SQLAlchemy declarative model object
    :param column: the Column to look up
    :raises sqlalchemy.orm.exc.UnmappedColumnError:
        if no mapped column matches.

    ::

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column('_name', sa.String)


        get_column_key(User, User.__table__.c._name)  # 'name'

    .. versionadded: 0.26.5

    .. versionchanged: 0.27.11
        Throws UnmappedColumnError instead of ValueError when no property was
        found for given column. This is consistent with how SQLAlchemy works.
    """
    model_mapper = sa.inspect(model)
    try:
        owning_property = model_mapper.get_property_by_column(column)
        return owning_property.key
    except sa.orm.exc.UnmappedColumnError:
        # Fallback: match on (name, table) identity.
        for attr_key, mapped_column in model_mapper.columns.items():
            same_name = mapped_column.name == column.name
            if same_name and mapped_column.table is column.table:
                return attr_key
        raise sa.orm.exc.UnmappedColumnError(
            f'No column {column} is configured on mapper {model_mapper}...'
        )
235
236
def get_mapper(mixed):
    """
    Return related SQLAlchemy Mapper for given SQLAlchemy object.

    :param mixed: SQLAlchemy Table / Alias / Mapper / declarative model object.
        Query entities (``_MapperEntity`` / ``_ColumnEntity``), Columns and
        InstrumentedAttributes are also accepted and resolved to the
        object they wrap before lookup.

    ::

        from sqlalchemy_utils import get_mapper


        get_mapper(User)

        get_mapper(User())

        get_mapper(User.__table__)

        get_mapper(User.__mapper__)

        get_mapper(sa.orm.aliased(User))

        get_mapper(sa.orm.aliased(User.__table__))


    Raises:
        ValueError: if multiple mappers (or no mapper) were found for a
            given Table

    .. versionadded: 0.26.1
    """
    # Unwrap query entities / columns. Note that `mixed` is reassigned
    # here and in several checks below, so the order of the checks matters.
    if isinstance(mixed, _MapperEntity):
        mixed = mixed.expr
    elif isinstance(mixed, sa.Column):
        mixed = mixed.table
    elif isinstance(mixed, _ColumnEntity):
        mixed = mixed.expr

    if isinstance(mixed, sa.orm.Mapper):
        return mixed
    if isinstance(mixed, sa.orm.util.AliasedClass):
        return sa.inspect(mixed).mapper
    # An Alias of a Table falls through to the Table lookup below.
    if isinstance(mixed, sa.sql.selectable.Alias):
        mixed = mixed.element
    if isinstance(mixed, AliasedInsp):
        return mixed.mapper
    # An attribute resolves via its owning class (handled at the end).
    if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        mixed = mixed.class_
    if isinstance(mixed, sa.Table):
        # Collect every configured mapper; the registry API differs
        # between SQLAlchemy versions.
        if hasattr(mapperlib, '_all_registries'):
            all_mappers = set()
            for mapper_registry in mapperlib._all_registries():
                all_mappers.update(mapper_registry.mappers)
        else:  # SQLAlchemy <1.4
            all_mappers = mapperlib._mapper_registry
        mappers = [mapper for mapper in all_mappers if mixed in mapper.tables]
        if len(mappers) > 1:
            raise ValueError("Multiple mappers found for table '%s'." % mixed.name)
        elif not mappers:
            raise ValueError("Could not get mapper for table '%s'." % mixed.name)
        else:
            return mappers[0]
    # Instances are resolved through their class.
    if not isclass(mixed):
        mixed = type(mixed)
    return sa.inspect(mixed)
300
301
def get_bind(obj):
    """
    Return the bind for given SQLAlchemy Session / Engine / Connection /
    declarative model object.

    If ``obj`` has a ``bind`` attribute, that attribute is returned.
    Otherwise ``obj`` is treated as a mapped instance and the bind of the
    session it belongs to is used. Objects that are not mapped instances
    are used as-is. Whatever is resolved must have an ``execute`` method.

    :param obj: SQLAlchemy Session / Engine / Connection / declarative
        model object
    :raises TypeError:
        if nothing with an ``execute`` method could be resolved. This
        includes a mapped instance that is not attached to any session.

    ::

        from sqlalchemy_utils import get_bind


        get_bind(session)  # Connection object

        get_bind(user)

    """
    if hasattr(obj, 'bind'):
        conn = obj.bind
    else:
        try:
            session = object_session(obj)
        except UnmappedInstanceError:
            conn = obj
        else:
            # object_session() returns None when the instance is not attached
            # to a session; fall through to the TypeError below instead of
            # failing with AttributeError on `None.bind`.
            conn = session.bind if session is not None else None

    if not hasattr(conn, 'execute'):
        raise TypeError(
            'This method accepts only Session, Engine, Connection and '
            'declarative model objects.'
        )
    return conn
333
334
def get_primary_keys(mixed):
    """
    Return an OrderedDict mapping keys to primary key columns for given
    Table object, declarative class or declarative class instance.

    The result is the subset of :func:`get_columns` whose columns have
    ``primary_key`` set, in the same order.

    :param mixed:
        SA Table object, SA declarative class or SA declarative class instance

    ::

        get_primary_keys(User)

        get_primary_keys(User())

        get_primary_keys(User.__table__)

        get_primary_keys(User.__mapper__)

        get_primary_keys(sa.orm.aliased(User))

        get_primary_keys(sa.orm.aliased(User.__table__))


    .. versionchanged: 0.25.3
        Made the function return an ordered dictionary instead of generator.
        This change was made to support primary key aliases.

        Renamed this function to 'get_primary_keys', formerly 'primary_keys'

    .. seealso:: :func:`get_columns`
    """
    pk_columns = OrderedDict()
    for name, col in get_columns(mixed).items():
        if col.primary_key:
            pk_columns[name] = col
    return pk_columns
373
374
def get_tables(mixed):
    """
    Return a list of tables associated with given SQLAlchemy object.

    - A Table returns ``[table]``; a Column returns ``[column.table]``.
    - An InstrumentedAttribute returns the tables of its parent mapper.
    - Anything else is resolved to a mapper via :func:`get_mapper`. If that
      mapper has polymorphic mappers, the tables of all of them are
      concatenated in order; otherwise the mapper's own tables are returned.

    Let's say we have three classes which use joined table inheritance
    TextItem, Article and BlogPost. Article and BlogPost inherit TextItem.

    ::

        get_tables(Article)

        get_tables(Article())

        get_tables(Article.__mapper__)


    .. versionadded: 0.26.0

    :param mixed:
        SQLAlchemy Mapper, Declarative class, Column, InstrumentedAttribute or
        a SA Alias object wrapping any of these objects.
    """
    if isinstance(mixed, sa.Table):
        return [mixed]
    if isinstance(mixed, sa.Column):
        return [mixed.table]
    if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        return mixed.parent.tables
    if isinstance(mixed, _ColumnEntity):
        mixed = mixed.expr

    resolved_mapper = get_mapper(mixed)
    sub_mappers = get_polymorphic_mappers(resolved_mapper)
    if not sub_mappers:
        return resolved_mapper.tables
    # Flatten each sub-mapper's table list, preserving order.
    return [tbl for sub in sub_mappers for tbl in sub.tables]
423
424
def get_columns(mixed):
    """
    Return a collection of all Column objects for given SQLAlchemy
    object.

    The type of the collection depends on the input:

    - selectables: ``selected_columns`` (``.c`` on SQLAlchemy <1.4)
    - aliased classes, mappers, instrumented attributes and column
      properties: their ``columns`` collection
    - a single Column: a one-element list
    - anything else: the mapper columns of the class (or of the
      instance's class)

    ::

        get_columns(User)

        get_columns(User())

        get_columns(User.__table__)

        get_columns(User.__mapper__)

        get_columns(sa.orm.aliased(User))

        get_columns(sa.orm.aliased(User.__table__))


    :param mixed:
        SA Table object, SA Mapper, SA declarative class, SA declarative class
        instance or an alias of any of these objects
    """
    if isinstance(mixed, sa.sql.selectable.Selectable):
        try:
            return mixed.selected_columns
        except AttributeError:  # SQLAlchemy <1.4
            return mixed.c
    if isinstance(mixed, sa.orm.util.AliasedClass):
        return sa.inspect(mixed).mapper.columns
    if isinstance(mixed, (sa.orm.Mapper, ColumnProperty)):
        return mixed.columns
    if isinstance(mixed, InstrumentedAttribute):
        return mixed.property.columns
    if isinstance(mixed, sa.Column):
        return [mixed]
    model_class = mixed if isclass(mixed) else mixed.__class__
    return sa.inspect(model_class).columns
470
471
def table_name(obj):
    """
    Return table name of given target, declarative class or the
    table name where the declarative attribute is bound to.

    If ``obj`` has a ``class_`` attribute (e.g. an attribute or mapper),
    that class is inspected instead. ``__tablename__`` is preferred, then
    ``__table__.name``. Returns ``None`` when neither is available.
    """
    target = getattr(obj, 'class_', obj)

    for lookup in (attrgetter('__tablename__'), attrgetter('__table__.name')):
        try:
            return lookup(target)
        except AttributeError:
            continue
    return None
488
489
def getattrs(obj, attrs):
    """Lazily yield ``getattr(obj, name)`` for each ``name`` in ``attrs``."""
    return map(lambda name: getattr(obj, name), attrs)
492
493
def quote(mixed, ident):
    """
    Conditionally quote an identifier.

    The dialect's identifier preparer decides whether quoting is needed.
    ::


        from sqlalchemy_utils import quote


        engine = create_engine('sqlite:///:memory:')

        quote(engine, 'order')
        # '"order"'

        quote(engine, 'some_other_identifier')
        # 'some_other_identifier'


    :param mixed: SQLAlchemy Session / Connection / Engine / Dialect object.
        Non-dialect objects are resolved to a dialect through
        :func:`get_bind`.
    :param ident: identifier to conditionally quote
    """
    dialect = mixed if isinstance(mixed, Dialect) else get_bind(mixed).dialect
    identifier_preparer = dialect.preparer(dialect)
    return identifier_preparer.quote(ident)
520
521
522def _get_query_compile_state(query):
523 if hasattr(query, '_compile_state'):
524 return query._compile_state()
525 else: # SQLAlchemy <1.4
526 return query
527
528
def get_polymorphic_mappers(mixed):
    """
    Return the polymorphic mappers for a mapper or aliased inspection.

    For an ``AliasedInsp`` this is ``with_polymorphic_mappers``; otherwise
    ``mixed`` is treated as a Mapper and the values of its
    ``polymorphic_map`` are returned.
    """
    if not isinstance(mixed, AliasedInsp):
        return mixed.polymorphic_map.values()
    return mixed.with_polymorphic_mappers
534
535
def get_descriptor(entity, attr):
    """
    Return the SQL-level descriptor for attribute ``attr`` of ``entity``.

    ``entity`` is passed to ``sa.inspect`` and its descriptors are gathered
    with :func:`get_all_descriptors` (including polymorphic sub-mappers).
    For column properties this returns a column of the aliased selectable
    (for aliased classes) or the class attribute of the property's parent
    class. For other descriptors (synonyms, relationships, hybrids) it
    returns the attribute from the aliased entity or the mapped class.

    Returns ``None`` when no matching descriptor/attribute is found.
    """
    mapper = sa.inspect(entity)

    for key, descriptor in get_all_descriptors(mapper).items():
        if attr == key:
            # Only some descriptors (e.g. column/relationship attributes)
            # expose `.property`.
            prop = descriptor.property if hasattr(descriptor, 'property') else None
            if isinstance(prop, ColumnProperty):
                if isinstance(entity, sa.orm.util.AliasedClass):
                    # Return the aliased column so expressions target the
                    # alias rather than the base table.
                    for c in mapper.selectable.c:
                        if c.key == attr:
                            return c
                else:
                    # If the property belongs to a class that uses
                    # polymorphic inheritance we have to take into account
                    # situations where the attribute exists in child class
                    # but not in parent class.
                    return getattr(prop.parent.class_, attr)
            else:
                # Handle synonyms, relationship properties and hybrid
                # properties

                if isinstance(entity, sa.orm.util.AliasedClass):
                    return getattr(entity, attr)
                try:
                    return getattr(mapper.class_, attr)
                except AttributeError:
                    pass
563
564
def get_all_descriptors(expr):
    """
    Return a mapping of attribute names to descriptors for ``expr``.

    For a selectable (e.g. a Table or alias) its ``.c`` column collection is
    returned. Otherwise the ORM descriptors of the mapper resolved by
    :func:`get_mapper` are returned as a dict. Descriptors that exist only
    on polymorphic sub-mappers are merged in; entries from the base mapper
    take precedence on key clashes. If ``expr`` cannot be inspected, the
    plain ``all_orm_descriptors`` of :func:`get_mapper` is returned.

    :param expr: selectable, mapped class, mapper, aliased class/inspection
        or any other object accepted by :func:`get_mapper`.
    """
    if isinstance(expr, sa.sql.selectable.Selectable):
        return expr.c
    try:
        # sa.inspect() is what raises NoInspectionAvailable; it previously sat
        # outside this try, which made the get_mapper() fallback unreachable.
        insp = sa.inspect(expr)
        polymorphic_mappers = get_polymorphic_mappers(insp)
    except sa.exc.NoInspectionAvailable:
        return get_mapper(expr).all_orm_descriptors

    attrs = dict(get_mapper(expr).all_orm_descriptors)
    for submapper in polymorphic_mappers:
        for key, descriptor in submapper.all_orm_descriptors.items():
            # Keep the base mapper's descriptor when both define the key.
            attrs.setdefault(key, descriptor)
    return attrs
580
581
def get_hybrid_properties(model):
    """
    Returns a dictionary of hybrid property keys and hybrid properties for
    given SQLAlchemy declarative model / mapper.

    The model is resolved with :func:`get_mapper`, and every entry of its
    ``all_orm_descriptors`` that is a ``hybrid_property`` is kept.


    Consider the following model

    ::


        from sqlalchemy.ext.hybrid import hybrid_property


        class Category(Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            @hybrid_property
            def lowercase_name(self):
                return self.name.lower()

            @lowercase_name.expression
            def lowercase_name(cls):
                return sa.func.lower(cls.name)


    You can now easily get a list of all hybrid property names

    ::


        from sqlalchemy_utils import get_hybrid_properties


        get_hybrid_properties(Category).keys()  # ['lowercase_name']


    This function also supports aliased classes

    ::


        get_hybrid_properties(
            sa.orm.aliased(Category)
        ).keys()  # ['lowercase_name']


    .. versionchanged: 0.26.7
        This function now returns a dictionary instead of generator

    .. versionchanged: 0.30.15
        Added support for aliased classes

    :param model: SQLAlchemy declarative model or mapper
    """
    descriptors = get_mapper(model).all_orm_descriptors
    hybrids = {}
    for name, descriptor in descriptors.items():
        if isinstance(descriptor, hybrid_property):
            hybrids[name] = descriptor
    return hybrids
644
645
def get_declarative_base(model):
    """
    Returns the declarative base for given model class.

    Walks up the first base class that has a ``metadata`` attribute until
    reaching a class none of whose bases have one; that class is returned.

    :param model: SQLAlchemy declarative model
    """
    for base_class in model.__bases__:
        if hasattr(base_class, 'metadata'):
            return get_declarative_base(base_class)
    return model
659
660
def getdotattr(obj_or_class, dot_path, condition=None):
    """
    Allow dot-notated strings to be passed to `getattr`.

    Each path segment is applied in turn:

    - to every element of a sequence, flattening sequence results;
    - to the related class of an InstrumentedAttribute;
    - directly to any other object.

    Traversal stops with ``None`` as soon as the current value is ``None``.

    ::

        getdotattr(SubSection, 'section.document')

        getdotattr(subsection, 'section.document')


    :param obj_or_class: Any object or class
    :param dot_path: Attribute path with dot mark as separator
    :param condition: Optional predicate applied after every segment.
        Sequences are filtered with it; a single value failing it makes the
        function return ``None``.
    """
    current = obj_or_class

    for segment in str(dot_path).split('.'):
        fetch = attrgetter(segment)

        if is_sequence(current):
            collected = []
            for item in current:
                fetched = fetch(item)
                if is_sequence(fetched):
                    collected.extend(fetched)
                else:
                    collected.append(fetched)
            current = collected
        elif isinstance(current, InstrumentedAttribute):
            current = fetch(current.property.mapper.class_)
        elif current is None:
            return None
        else:
            current = fetch(current)

        if condition is None:
            continue
        if is_sequence(current):
            current = [value for value in current if condition(value)]
        elif not condition(current):
            return None

    return current
703
704
def is_deleted(obj):
    """
    Return whether ``obj`` is marked as deleted in its session.

    Objects not attached to any session (``object_session`` returns
    ``None``) are reported as not deleted instead of raising AttributeError.

    :param obj: SQLAlchemy declarative model object
    """
    session = sa.orm.object_session(obj)
    return session is not None and obj in session.deleted
707
708
def has_changes(obj, attrs=None, exclude=None):
    """
    Simple shortcut function for checking if given attributes of given
    declarative model object have changed during the session. Without
    parameters this checks if given object has any modifications. Additionally
    exclude parameter can be given to check if given object has any changes
    in any attributes other than the ones given in exclude.


    ::


        from sqlalchemy_utils import has_changes


        user = User()

        has_changes(user, 'name')  # False

        user.name = 'someone'

        has_changes(user, 'name')  # True

        has_changes(user)  # True


    You can check multiple attributes as well.
    ::


        has_changes(user, ['age'])  # True

        has_changes(user, ['name', 'age'])  # True


    This function also supports excluding certain attributes.

    ::

        has_changes(user, exclude=['name'])  # False

        has_changes(user, exclude=['age'])  # True

        has_changes(user, exclude='age')  # same as exclude=['age']

    .. versionchanged: 0.26.6
        Added support for multiple attributes and exclude parameter.

    :param obj: SQLAlchemy declarative model object
    :param attrs: Name of an attribute, or an iterable of names
    :param exclude: Name of an attribute, or an iterable of names, to ignore
        when ``attrs`` is not given
    """
    if attrs:
        if isinstance(attrs, str):
            return sa.inspect(obj).attrs.get(attrs).history.has_changes()
        return any(has_changes(obj, attr) for attr in attrs)

    if exclude is None:
        exclude = []
    elif isinstance(exclude, str):
        # Treat a single name like `attrs` does; otherwise `key not in 'name'`
        # would do a substring test and wrongly exclude keys such as 'na'.
        exclude = [exclude]
    return any(
        attr.history.has_changes()
        for key, attr in sa.inspect(obj).attrs.items()
        if key not in exclude
    )
772
773
def is_loaded(obj, prop):
    """
    Return whether or not given property of given object has been loaded.

    ::

        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)
            content = sa.orm.deferred(sa.Column(sa.String))


        article = session.query(Article).get(5)

        # name gets loaded since its not a deferred property
        assert is_loaded(article, 'name')

        # content has not yet been loaded since its a deferred property
        assert not is_loaded(article, 'content')

        # attributes can be passed directly as well
        assert not is_loaded(article, Article.content)


    .. versionadded: 0.27.8

    :param obj: SQLAlchemy declarative model object
    :param prop: Name of the property or InstrumentedAttribute
    """
    # `unloaded` contains attribute *names*; an InstrumentedAttribute would
    # never be a member and would always look "loaded".
    if isinstance(prop, InstrumentedAttribute):
        prop = prop.key
    return prop not in sa.inspect(obj).unloaded
802
803
def identity(obj_or_class):
    """
    Return the identity of given sqlalchemy declarative model class or instance
    as a tuple.

    The tuple holds the value of each primary key attribute, in the order
    given by :func:`get_primary_keys`. Because the attributes are read
    directly, this also works for objects that have not been persisted yet.
    Unset keys then appear as ``None``. For classes the tuple holds the
    primary key attributes themselves.

    ::

        from sqlalchemy_utils import identity


        user = User(name='John Matrix')
        session.add(user)
        identity(user)  # (None,)

        session.flush()

        identity(user)  # (1,)


    You can also use identity for classes::


        identity(User)  # (User.id, )

    .. versionadded: 0.21.0

    :param obj_or_class: SQLAlchemy declarative model object or class
    """
    pk_keys = get_primary_keys(obj_or_class).keys()
    return tuple(map(partial(getattr, obj_or_class), pk_keys))
847
848
def naturally_equivalent(obj, obj2):
    """
    Returns whether or not two given SQLAlchemy declarative instances are
    naturally equivalent (all their non primary key properties are equivalent).

    Columns are taken from the mapper of ``obj``'s class and compared with
    ``==`` in mapper order, stopping at the first mismatch.


    ::

        from sqlalchemy_utils import naturally_equivalent


        user = User(name='someone')
        user2 = User(name='someone')

        user == user2  # False

        naturally_equivalent(user, user2)  # True


    :param obj: SQLAlchemy declarative model object
    :param obj2: SQLAlchemy declarative model object to compare with `obj`
    """
    model_mapper = sa.inspect(obj.__class__)
    return all(
        getattr(obj, key) == getattr(obj2, key)
        for key, mapped_column in model_mapper.columns.items()
        if not mapped_column.primary_key
    )
878
879
880def _get_class_registry(class_):
881 try:
882 return class_.registry._class_registry
883 except AttributeError: # SQLAlchemy <1.4
884 return class_._decl_class_registry