1from collections import defaultdict
2from itertools import groupby
3
4import sqlalchemy as sa
5from sqlalchemy.exc import NoInspectionAvailable
6from sqlalchemy.orm import object_session
7from sqlalchemy.schema import ForeignKeyConstraint, MetaData, Table
8
9from ..query_chain import QueryChain
10from .database import has_index
11from .orm import _get_class_registry, get_column_key, get_mapper, get_tables
12
13
def get_foreign_key_values(fk, obj):
    """
    Map the local columns of ``fk``'s constraint to the values that ``obj``
    holds for the columns they reference.

    The n-th column of ``fk.constraint`` is paired with the n-th element of
    ``fk.constraint.elements``. The value is read from ``obj`` through the
    attribute named after the referenced column's key. When ``obj`` has no
    such attribute, it is read through the mapped property for that column.

    :param fk: a ForeignKey whose constraint references ``obj``'s table
    :param obj: mapped instance supplying the referenced values
    :return: dict of local Column -> value read from ``obj``
    """
    mapper = get_mapper(obj)
    local_columns = fk.constraint.columns.values()
    values = {}
    for position, element in enumerate(fk.constraint.elements):
        local_column = local_columns[position]
        referenced = element.column
        if hasattr(obj, referenced.key):
            attr_name = referenced.key
        else:
            # The attribute name differs from the column key, so ask the
            # mapper which property maps this column.
            attr_name = mapper.get_property_by_column(referenced).key
        values[local_column] = getattr(obj, attr_name)
    return values
22
23
def group_foreign_keys(foreign_keys):
    """
    Return a groupby iterator that groups given foreign keys by table.

    Groups are ordered by table name, then by schema. Each distinct table
    yields exactly one group, even when tables in different schemas share
    a name. As with :func:`itertools.groupby`, consume each group before
    advancing to the next one.

    :param foreign_keys: a sequence of foreign keys

    ::

        foreign_keys = get_referencing_foreign_keys(User)

        for table, fks in group_foreign_keys(foreign_keys):
            # do something
            pass


    .. seealso:: :func:`get_referencing_foreign_keys`

    .. versionadded:: 0.26.1
    """
    def table_sort_key(fk):
        table = fk.constraint.table
        # Sorting by name alone lets same-named tables from different schemas
        # interleave, which would make groupby split one table into several
        # groups. Adding the schema keeps each table's keys contiguous.
        return (table.name, table.schema or '')

    foreign_keys = sorted(foreign_keys, key=table_sort_key)
    return groupby(foreign_keys, lambda fk: fk.constraint.table)
46
47
def get_referencing_foreign_keys(mixed):
    """
    Return the set of foreign keys that point at the given Table object or
    declarative class.

    :param mixed:
        SA Table object or SA declarative class

    ::

        get_referencing_foreign_keys(User)  # set([ForeignKey('user.id')])

        get_referencing_foreign_keys(User.__table__)


    Inheritance is taken into account. For a declarative class, every table
    returned by :func:`get_tables` for that class is a target. A foreign key
    is included when it references any of those tables.

    Suppose TextItem, Article and BlogPost use joined table inheritance,
    with Article and BlogPost inheriting from TextItem.

    ::

        # Collects foreign keys that reference either the article table
        # or the textitem table.
        get_referencing_foreign_keys(Article)

    Only tables of ``mixed.metadata`` other than the target tables are
    scanned. Foreign keys defined on a target table itself (for example
    self-referential ones) are therefore not included.

    .. seealso:: :func:`get_tables`
    """
    if isinstance(mixed, sa.Table):
        targets = [mixed]
    else:
        targets = get_tables(mixed)

    def points_at_targets(fk):
        return any(fk.references(target) for target in targets)

    return {
        fk
        for candidate in mixed.metadata.tables.values()
        if candidate not in targets
        for constraint in candidate.constraints
        if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint)
        for fk in constraint.elements
        if points_at_targets(fk)
    }
93
94
def merge_references(from_, to, foreign_keys=None):
    """
    Merge the references of an entity into another entity.

    Consider the following models::

        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String(255))

            def __repr__(self):
                return 'User(name=%r)' % self.name

        class BlogPost(self.Base):
            __tablename__ = 'blog_post'
            id = sa.Column(sa.Integer, primary_key=True)
            title = sa.Column(sa.String(255))
            author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))

            author = sa.orm.relationship(User)


    Now let's add some data::

        john = self.User(name='John')
        jack = self.User(name='Jack')
        post = self.BlogPost(title='Some title', author=john)
        post2 = self.BlogPost(title='Other title', author=jack)
        self.session.add_all([
            john,
            jack,
            post,
            post2
        ])
        self.session.commit()


    Merging all of John's references into Jack is then as easy as::

        merge_references(john, jack)
        self.session.commit()

        post.author     # User(name='Jack')
        post2.author    # User(name='Jack')


    For every foreign key, one SQL UPDATE is issued on the referencing table
    through the session of ``from_``. It rewrites rows whose foreign key
    columns hold ``from_``'s referenced values so they hold ``to``'s values.

    :param from_: an entity to merge into another entity
    :param to: an entity to merge another entity into
    :param foreign_keys: A sequence of foreign keys. By default this is None
        indicating all referencing foreign keys should be used.
    :raises TypeError: if ``from_`` and ``to`` have different
        ``__tablename__`` values

    .. seealso:: :func:`dependent_objects`

    .. versionadded:: 0.26.1

    .. versionchanged:: 0.40.0

        Removed possibility for old-style synchronize_session merging. Only
        SQL based merging supported for now.
    """
    if from_.__tablename__ != to.__tablename__:
        raise TypeError('The tables of given arguments do not match.')

    session = object_session(from_)
    # Respect an explicit selection of foreign keys; only fall back to every
    # referencing foreign key when the caller did not pass one.
    if foreign_keys is None:
        foreign_keys = get_referencing_foreign_keys(from_)

    for fk in foreign_keys:
        table = fk.constraint.table
        old_values = get_foreign_key_values(fk, from_)
        new_values = get_foreign_key_values(fk, to)
        criteria = (
            getattr(table.c, column.key) == value
            for column, value in old_values.items()
        )
        query = (
            table.update()
            .where(sa.and_(*criteria))
            .values({column.key: value for column, value in new_values.items()})
        )
        session.execute(query)
177
178
def dependent_objects(obj, foreign_keys=None):
    """
    Build a :class:`~sqlalchemy_utils.query_chain.QueryChain` that walks
    every object depending on the given SQLAlchemy object.

    Suppose a User object is referenced by several articles and several
    orders. All of those dependent objects can be fetched with::

        from sqlalchemy_utils import dependent_objects


        dependent_objects(user)


    When many dependent objects are expected, limit the results::


        dependent_objects(user).limit(5)


    A typical use is to look for RESTRICT dependents before deleting a parent
    object, and to inform the user if any exist. Without such a check,
    deleting the parent would raise an IntegrityError.

    The following example deletes the user only if no foreign-key-restricted
    dependent objects exist::


        from sqlalchemy_utils import get_referencing_foreign_keys


        user = session.query(User).get(some_user_id)


        deps = list(
            dependent_objects(
                user,
                (
                    fk for fk in get_referencing_foreign_keys(User)
                    # On most databases RESTRICT is the default mode hence we
                    # check for None values also
                    if fk.ondelete == 'RESTRICT' or fk.ondelete is None
                )
            ).limit(5)
        )

        if deps:
            # Do something to inform the user
            pass
        else:
            session.delete(user)


    One query is added per (referencing table, mapped class) pair. A class is
    skipped when its parent mapper already maps that table, so the parent's
    query covers it. Queries are issued through ``object_session(obj)``,
    so ``obj`` is expected to be attached to a session.

    :param obj: SQLAlchemy declarative model object
    :param foreign_keys:
        A sequence of foreign keys to use for searching the dependent_objects
        for given object. By default this is None, indicating that all foreign
        keys referencing the object will be used.

    .. note::
        This function does not support exotic mappers that use multiple tables

    .. seealso:: :func:`get_referencing_foreign_keys`
    .. seealso:: :func:`merge_references`

    .. versionadded:: 0.26.0
    """
    if foreign_keys is None:
        foreign_keys = get_referencing_foreign_keys(obj)

    session = object_session(obj)

    chain = QueryChain([])
    registry = _get_class_registry(obj.__class__)

    for table, grouped_keys in group_foreign_keys(foreign_keys):
        # Materialize: the groupby sub-iterator is reused for every class.
        grouped_keys = list(grouped_keys)
        for candidate in registry.values():
            try:
                candidate_mapper = sa.inspect(candidate)
            except NoInspectionAvailable:
                # Registry entries that are not mapped classes.
                continue
            base_mapper = candidate_mapper.inherits
            if table not in candidate_mapper.tables:
                continue
            if base_mapper and table in base_mapper.tables:
                # The parent class already queries this table.
                continue
            criteria = _get_criteria(grouped_keys, candidate, obj)
            chain.queries.append(
                session.query(candidate).filter(sa.or_(*criteria))
            )
    return chain
273
274
275def _get_criteria(keys, class_, obj):
276 criteria = []
277 visited_constraints = []
278 for key in keys:
279 if key.constraint in visited_constraints:
280 continue
281 visited_constraints.append(key.constraint)
282
283 subcriteria = []
284 for index, column in enumerate(key.constraint.columns):
285 foreign_column = key.constraint.elements[index].column
286 subcriteria.append(
287 getattr(class_, get_column_key(class_, column))
288 == getattr(
289 obj,
290 sa.inspect(type(obj)).get_property_by_column(foreign_column).key,
291 )
292 )
293 criteria.append(sa.and_(*subcriteria))
294 return criteria
295
296
def non_indexed_foreign_keys(metadata, engine=None):
    """
    Finds all non indexed foreign keys from all tables of given MetaData.

    Very useful for optimizing postgresql database and finding out which
    foreign keys need indexes.

    Every table of ``metadata`` is reflected from the database into a fresh
    MetaData, using the table's name and schema. Reflected foreign key
    constraints for which ``has_index`` is false are collected.

    :param metadata: MetaData object to inspect tables from
    :param engine:
        Engine (or other connectable) used for reflection. It is used when
        ``metadata`` has no ``bind`` attribute, or its ``bind`` is None.
    :return:
        dict mapping each table's key in ``metadata.tables`` (the plain table
        name for tables without a schema) to a list of its non-indexed
        ForeignKeyConstraint objects. Tables without such constraints are
        omitted.
    :raises ValueError: if neither ``metadata.bind`` nor ``engine`` is given
    """
    bind = getattr(metadata, 'bind', None)
    if bind is None and engine is None:
        raise ValueError(
            'Either pass a metadata object with bind or '
            'pass engine as a second parameter'
        )
    connectable = bind or engine

    reflected_metadata = MetaData()
    constraints = defaultdict(list)

    for table_key, source_table in metadata.tables.items():
        # Reflect by name + schema. The dictionary key is "schema.name" for
        # schema-qualified tables and is not a valid table name on its own.
        table = Table(
            source_table.name,
            reflected_metadata,
            schema=source_table.schema,
            autoload_with=connectable,
        )

        for constraint in table.constraints:
            if not isinstance(constraint, ForeignKeyConstraint):
                continue

            if not has_index(constraint):
                constraints[table_key].append(constraint)

    return dict(constraints)
328
329
def get_fk_constraint_for_columns(table, *columns):
    """
    Return the foreign key constraint of ``table`` whose local columns are
    exactly ``columns``, in the same order.

    :param table: SA Table to search
    :param columns: the local columns of the wanted constraint, in order
    :return: the matching ForeignKeyConstraint, or None when there is none
    """
    wanted = list(columns)
    for constraint in table.constraints:
        # table.constraints also contains primary key, unique and check
        # constraints. Without this filter, a PrimaryKeyConstraint over the
        # same columns (e.g. a joined-inheritance child's ``id``) could be
        # returned, depending on set iteration order.
        if not isinstance(constraint, ForeignKeyConstraint):
            continue
        if list(constraint.columns.values()) == wanted:
            return constraint
    return None