1"""
SQLAlchemy-Utils provides a way of automatically calculating aggregate values
of related models and saving them to the parent model.
4
5This solution is inspired by RoR counter cache,
6`counter_culture`_ and `stackoverflow reply by Michael Bayer`_.
7
8Why?
9----
10
You will often need to dynamically calculate some aggregate value for a given
model. Some simple examples include:
13
14- Number of products in a catalog
15- Average rating for movie
16- Latest forum post
17- Total price of orders for given customer
18
19Now all these aggregates can be elegantly implemented with SQLAlchemy
20column_property_ function. However when your data grows calculating these
21values on the fly might start to hurt the performance of your application. The
22more aggregates you are using the more performance penalty you get.
23
This module provides a way of calculating these values automatically and
efficiently at the time of modification rather than on the fly.
26
27
28Features
29--------
30
31* Automatically updates aggregate columns when aggregated values change
* Supports aggregate values through an arbitrary number of relation levels
33* Highly optimized: uses single query per transaction per aggregate column
34* Aggregated columns can be of any data type and use any selectable scalar
35 expression
36
37
38.. _column_property:
39 https://docs.sqlalchemy.org/en/latest/orm/mapped_sql_expr.html#using-column-property
40.. _counter_culture: https://github.com/magnusvk/counter_culture
41.. _stackoverflow reply by Michael Bayer:
42 https://stackoverflow.com/a/13765857/520932
43
44
45Simple aggregates
46-----------------
47
48::
49
50 from sqlalchemy_utils import aggregated
51
52
53 class Thread(Base):
54 __tablename__ = 'thread'
55 id = sa.Column(sa.Integer, primary_key=True)
56 name = sa.Column(sa.Unicode(255))
57
58 @aggregated('comments', sa.Column(sa.Integer))
59 def comment_count(self):
60 return sa.func.count('1')
61
62 comments = sa.orm.relationship(
63 'Comment',
64 backref='thread'
65 )
66
67
68 class Comment(Base):
69 __tablename__ = 'comment'
70 id = sa.Column(sa.Integer, primary_key=True)
71 content = sa.Column(sa.UnicodeText)
72 thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
73
74
75 thread = Thread(name='SQLAlchemy development')
    thread.comments.append(Comment(content='Going good!'))
    thread.comments.append(Comment(content='Great new features!'))
78
79 session.add(thread)
80 session.commit()
81
82 thread.comment_count # 2
83
84
85
86Custom aggregate expressions
87----------------------------
88
89Aggregate expression can be virtually any SQL expression not just a simple
90function taking one parameter. You can try things such as subqueries and
91different kinds of functions.
92
93In the following example we have a Catalog of products where each catalog
94knows the net worth of its products.
95
96::
97
98
99 from sqlalchemy_utils import aggregated
100
101
102 class Catalog(Base):
103 __tablename__ = 'catalog'
104 id = sa.Column(sa.Integer, primary_key=True)
105 name = sa.Column(sa.Unicode(255))
106
107 @aggregated('products', sa.Column(sa.Integer))
108 def net_worth(self):
109 return sa.func.sum(Product.price)
110
111 products = sa.orm.relationship('Product')
112
113
114 class Product(Base):
115 __tablename__ = 'product'
116 id = sa.Column(sa.Integer, primary_key=True)
117 name = sa.Column(sa.Unicode(255))
118 price = sa.Column(sa.Numeric)
119
120 catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
121
122
Now the net_worth column of the Catalog model will be updated automatically
whenever:
124
125* A new product is added to the catalog
126* A product is deleted from the catalog
127* The price of catalog product is changed
128
129
130::
131
132
133 from decimal import Decimal
134
135
136 product1 = Product(name='Some product', price=Decimal(1000))
137 product2 = Product(name='Some other product', price=Decimal(500))
138
139
140 catalog = Catalog(
141 name='My first catalog',
142 products=[
143 product1,
144 product2
145 ]
146 )
147 session.add(catalog)
148 session.commit()
149
150 session.refresh(catalog)
151 catalog.net_worth # 1500
152
153 session.delete(product2)
154 session.commit()
155 session.refresh(catalog)
156
157 catalog.net_worth # 1000
158
159 product1.price = 2000
160 session.commit()
161 session.refresh(catalog)
162
163 catalog.net_worth # 2000
164
165
166
167
168Multiple aggregates per class
169-----------------------------
170
171Sometimes you may need to define multiple aggregate values for same class. If
172you need to define lots of relationships pointing to same class, remember to
173define the relationships as viewonly when possible.
174
175
176::
177
178
179 from sqlalchemy_utils import aggregated
180
181
182 class Customer(Base):
183 __tablename__ = 'customer'
184 id = sa.Column(sa.Integer, primary_key=True)
185 name = sa.Column(sa.Unicode(255))
186
187 @aggregated('orders', sa.Column(sa.Integer))
188 def orders_sum(self):
189 return sa.func.sum(Order.price)
190
191 @aggregated('invoiced_orders', sa.Column(sa.Integer))
192 def invoiced_orders_sum(self):
193 return sa.func.sum(Order.price)
194
195 orders = sa.orm.relationship('Order')
196
197 invoiced_orders = sa.orm.relationship(
198 'Order',
199 primaryjoin=
200 'sa.and_(Order.customer_id == Customer.id, Order.invoiced)',
201 viewonly=True
202 )
203
204
205 class Order(Base):
206 __tablename__ = 'order'
207 id = sa.Column(sa.Integer, primary_key=True)
208 name = sa.Column(sa.Unicode(255))
209 price = sa.Column(sa.Numeric)
210 invoiced = sa.Column(sa.Boolean, default=False)
211 customer_id = sa.Column(sa.Integer, sa.ForeignKey(Customer.id))
212
213
214Many-to-Many aggregates
215-----------------------
216
Aggregate expressions also support many-to-many relationships. Typical use
cases include:
219
2201. Friend count of a user
2212. Group count where given user belongs to
222
223::
224
225
226 user_group = sa.Table('user_group', Base.metadata,
227 sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
228 sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
229 )
230
231
232 class User(Base):
233 __tablename__ = 'user'
234 id = sa.Column(sa.Integer, primary_key=True)
235 name = sa.Column(sa.Unicode(255))
236
237 @aggregated('groups', sa.Column(sa.Integer, default=0))
238 def group_count(self):
239 return sa.func.count('1')
240
241 groups = sa.orm.relationship(
242 'Group',
243 backref='users',
244 secondary=user_group
245 )
246
247
248 class Group(Base):
249 __tablename__ = 'group'
250 id = sa.Column(sa.Integer, primary_key=True)
251 name = sa.Column(sa.Unicode(255))
252
253
254
255 user = User(name='John Matrix')
256 user.groups = [Group(name='Group A'), Group(name='Group B')]
257
258 session.add(user)
259 session.commit()
260
261 session.refresh(user)
262 user.group_count # 2
263
264
265Multi-level aggregates
266----------------------
267
268Aggregates can span across multiple relationships. In the following example
269each Catalog has a net_worth which is the sum of all products in all
270categories.
271
272
273::
274
275
276 from sqlalchemy_utils import aggregated
277
278
279 class Catalog(Base):
280 __tablename__ = 'catalog'
281 id = sa.Column(sa.Integer, primary_key=True)
282 name = sa.Column(sa.Unicode(255))
283
284 @aggregated('categories.products', sa.Column(sa.Integer))
285 def net_worth(self):
286 return sa.func.sum(Product.price)
287
288 categories = sa.orm.relationship('Category')
289
290
291 class Category(Base):
292 __tablename__ = 'category'
293 id = sa.Column(sa.Integer, primary_key=True)
294 name = sa.Column(sa.Unicode(255))
295
296 catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
297
298 products = sa.orm.relationship('Product')
299
300
301 class Product(Base):
302 __tablename__ = 'product'
303 id = sa.Column(sa.Integer, primary_key=True)
304 name = sa.Column(sa.Unicode(255))
305 price = sa.Column(sa.Numeric)
306
307 category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
308
309
310Examples
311--------
312
313Average movie rating
314^^^^^^^^^^^^^^^^^^^^
315
316::
317
318
319 from sqlalchemy_utils import aggregated
320
321
322 class Movie(Base):
323 __tablename__ = 'movie'
324 id = sa.Column(sa.Integer, primary_key=True)
325 name = sa.Column(sa.Unicode(255))
326
327 @aggregated('ratings', sa.Column(sa.Numeric))
328 def avg_rating(self):
329 return sa.func.avg(Rating.stars)
330
331 ratings = sa.orm.relationship('Rating')
332
333
334 class Rating(Base):
335 __tablename__ = 'rating'
336 id = sa.Column(sa.Integer, primary_key=True)
337 stars = sa.Column(sa.Integer)
338
339 movie_id = sa.Column(sa.Integer, sa.ForeignKey(Movie.id))
340
341
    movie = Movie(name='Terminator 2')
343 movie.ratings.append(Rating(stars=5))
344 movie.ratings.append(Rating(stars=4))
345 movie.ratings.append(Rating(stars=3))
346 session.add(movie)
347 session.commit()
348
349 movie.avg_rating # 4
350
351
352
353TODO
354----
355
356* Special consideration should be given to `deadlocks`_.
357
358
359.. _deadlocks:
360 https://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
361
362"""
363
364from collections import defaultdict
365from weakref import WeakKeyDictionary
366
367import sqlalchemy as sa
368import sqlalchemy.event
369import sqlalchemy.orm
370from sqlalchemy.ext.declarative import declared_attr
371from sqlalchemy.sql.functions import _FunctionGenerator
372
373from .functions.orm import get_column_key
374from .relationships import (
375 chained_join,
376 path_to_relationships,
377 select_correlated_expression,
378)
379
# Aggregate declarations recorded by AggregatedAttribute.__get__: maps a class
# to a list of (fget, relationship, column) tuples. Keys are held weakly, so
# an entry goes away once its class is no longer referenced anywhere else.
aggregated_attrs = WeakKeyDictionary()
381
382
class AggregatedAttribute(declared_attr):
    """
    ``declared_attr`` subclass that records an aggregate declaration.

    Instances are produced by the ``aggregated`` decorator. Reading the
    attribute hands back the column that stores the aggregate and, as a side
    effect, notes the declaration in ``aggregated_attrs`` for the class the
    attribute was read from.
    """

    def __init__(self, fget, relationship, column, *args, **kwargs):
        """
        :param fget:
            The decorated function; its docstring is copied onto the
            attribute.
        :param relationship:
            Relationship path the aggregate is calculated over.
        :param column:
            Column object that stores the aggregate value.
        """
        super().__init__(fget, *args, **kwargs)
        self.relationship = relationship
        self.column = column
        self.__doc__ = fget.__doc__

    def __get__(desc, self, cls):
        # Every access appends one more (fget, relationship, column) entry
        # for ``cls``; the list is created on first use.
        declaration = (desc.fget, desc.relationship, desc.column)
        aggregated_attrs.setdefault(cls, []).append(declaration)
        return desc.column
397
398
def local_condition(prop, objects):
    """
    Build an ``IN`` condition from the values carried by ``objects``.

    :param prop:
        Relationship property; its ``local_remote_pairs``, ``secondary`` and
        ``mapper`` attributes are read.
    :param objects:
        Iterable of objects whose attribute (the one mapped to the fetched
        column on ``prop.mapper``) supplies the values.
    :returns:
        ``parent_column.in_(values)``, or ``None`` when no value could be
        collected.
    """
    pairs = prop.local_remote_pairs
    if prop.secondary is None:
        parent_column = pairs[0][0]
        fetched_column = pairs[0][1]
    else:
        # NOTE(review): with a secondary table both columns are taken from
        # the local side of the second pair -- presumably intentional, but
        # worth confirming against the many-to-many use case.
        parent_column = pairs[1][0]
        fetched_column = pairs[1][0]

    attribute_name = get_column_key(prop.mapper, fetched_column)

    collected = []
    for instance in objects:
        try:
            current = getattr(instance, attribute_name)
        except sa.orm.exc.ObjectDeletedError:
            # The attribute of this object can no longer be loaded; it
            # simply contributes no value.
            continue
        collected.append(current)

    if not collected:
        return None
    return parent_column.in_(collected)
419
420
def aggregate_expression(expr, class_):
    """
    Normalise the value returned by an aggregate function.

    :param expr:
        Either a ready-made SQL construct, a bare function generator (for
        example ``sa.func.count`` without a call), or any other callable.
    :param class_:
        Class handed to ``expr`` when it is a plain callable.
    :returns:
        ``expr`` itself when it already is a SQL construct, ``expr`` applied
        to the literal ``1`` when it is a function generator, otherwise
        ``expr(class_)``.
    """
    if isinstance(expr, sa.sql.visitors.Visitable):
        return expr
    if isinstance(expr, _FunctionGenerator):
        literal_one = sa.sql.text('1')
        return expr(literal_one)
    return expr(class_)
428
429
class AggregatedValue:
    """
    One aggregate column of ``class_`` together with the SQL needed to
    recalculate it.

    ``relationships`` holds the result of ``path_to_relationships`` in
    reversed order (presumably so that index 0 is the last hop of the path
    and index -1 the relationship declared on ``class_`` itself).
    """

    def __init__(self, class_, attr, path, expr):
        """
        :param class_: Class that owns the aggregate column.
        :param attr: Column that receives the aggregate value.
        :param path: Relationship path the aggregate is calculated over.
        :param expr: Aggregate expression, normalised through
            ``aggregate_expression``.
        """
        self.class_ = class_
        self.attr = attr
        self.path = path
        forward = path_to_relationships(path, class_)
        self.relationships = list(reversed(forward))
        self.expr = aggregate_expression(expr, class_)

    @property
    def aggregate_query(self):
        """Scalar subquery that computes the aggregate for a row."""
        target_class = self.relationships[0].mapper.class_
        correlated = select_correlated_expression(
            self.class_, self.expr, self.path, target_class
        )
        return correlated.scalar_subquery()

    def update_query(self, objects):
        """
        Build the UPDATE statement refreshing the aggregate column for the
        rows affected by ``objects``.

        :param objects: Objects of the class at the far end of the path.
        :returns: The UPDATE statement, or ``None`` when ``local_condition``
            yields no condition for ``objects``.
        """
        statement = self.class_.__table__.update().values(
            {self.attr: self.aggregate_query}
        )
        owner_property = self.relationships[-1].property

        if len(self.relationships) == 1:
            condition = local_condition(owner_property, objects)
            if condition is None:
                return None
            return statement.where(condition)

        # Builds query such as:
        #
        # UPDATE catalog SET product_count = (aggregate_query)
        # WHERE id IN (
        #     SELECT catalog_id
        #         FROM category
        #         INNER JOIN sub_category
        #             ON category.id = sub_category.category_id
        #         WHERE sub_category.id IN (product_sub_category_ids)
        #     )
        first_pair = owner_property.local_remote_pairs[0]
        local = first_pair[0]
        remote = first_pair[1]
        condition = local_condition(self.relationships[0].property, objects)
        if condition is None:
            return None

        joined = chained_join(*reversed(self.relationships))
        affected = sa.select(remote).select_from(joined).where(condition)
        return statement.where(local.in_(affected))
478
479
class AggregationManager:
    """
    Keeps track of declared aggregates and refreshes them when a session
    flushes.

    ``generator_registry`` maps the class found at
    ``relationships[0].mapper.class_`` of each aggregate to the list of
    ``AggregatedValue`` objects that have to be recalculated when instances
    of that class are present in a flushed session.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every registered aggregate."""
        self.generator_registry = defaultdict(list)

    def register_listeners(self):
        """
        Hook the manager into SQLAlchemy: the registry is rebuilt after
        mapper configuration and aggregate updates are issued after every
        session flush.
        """
        sa.event.listen(
            sa.orm.Mapper, 'after_configured', self.update_generator_registry
        )
        sa.event.listen(
            sa.orm.session.Session, 'after_flush', self.construct_aggregate_queries
        )

    def update_generator_registry(self):
        """
        Rebuild ``generator_registry`` from ``aggregated_attrs``.

        This listener can run more than once (SQLAlchemy emits
        ``after_configured`` after each mapper configuration pass), so the
        registry is rebuilt from scratch instead of being appended to.
        Appending would register every aggregate again on each pass and make
        every flush repeat the same UPDATE statements.
        """
        registry = defaultdict(list)
        for class_, attrs in aggregated_attrs.items():
            # A declaration recorded several times for the same class must
            # still produce a single AggregatedValue. Identity is used
            # because column objects overload ``==``.
            seen = set()
            for expr, path, column in attrs:
                identity = (id(expr), path, id(column))
                if identity in seen:
                    continue
                seen.add(identity)
                value = AggregatedValue(
                    class_=class_, attr=column, path=path, expr=expr(class_)
                )
                key = value.relationships[0].mapper.class_
                registry[key].append(value)
        # Swap in the finished registry only once everything was built, so a
        # failure above leaves the previous registry untouched.
        self.generator_registry = registry

    def construct_aggregate_queries(self, session, ctx):
        """
        Execute the aggregate UPDATE statements for a flushed session.

        :param session: Session being flushed; it is iterated to find the
            objects of registered classes and used to execute the queries.
        :param ctx: Flush context passed by the event; unused.
        """
        object_dict = defaultdict(list)
        for obj in session:
            for class_ in self.generator_registry:
                if isinstance(obj, class_):
                    object_dict[class_].append(obj)

        for class_, objects in object_dict.items():
            for aggregate_value in self.generator_registry[class_]:
                query = aggregate_value.update_query(objects)
                # ``update_query`` returns None when nothing needs updating.
                if query is not None:
                    session.execute(query)
516
517
# Module-level manager instance. Its event listeners are registered as soon
# as this module is imported.
manager = AggregationManager()
manager.register_listeners()
520
521
def aggregated(relationship, column):
    """
    Decorator factory that turns a function into an aggregated attribute.
    The decorated function should return an aggregate select expression.

    :param relationship:
        The relationship (path) the aggregate is calculated from. The class
        must define this relationship for the aggregate to be calculated.
    :param column:
        SQLAlchemy Column object defining the column that stores this
        aggregate attribute.
    :returns:
        A decorator that wraps the function in an ``AggregatedAttribute``.
    """

    def decorator(func):
        attribute = AggregatedAttribute(func, relationship, column)
        return attribute

    return decorator