Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy_utils/aggregates.py: 34%

102 statements

"""
SQLAlchemy-Utils provides a way of automatically calculating aggregate values
of related models and saving them to the parent model.

This solution is inspired by the RoR counter cache,
`counter_culture`_ and `stackoverflow reply by Michael Bayer`_.

Why?
----

You will often need to calculate some aggregate value for a given model
dynamically. Some simple examples include:

- Number of products in a catalog
- Average rating for a movie
- Latest forum post
- Total price of orders for a given customer

All of these aggregates can be elegantly implemented with the SQLAlchemy
column_property_ function. However, as your data grows, calculating these
values on the fly might start to hurt the performance of your application. The
more aggregates you use, the greater the performance penalty.

This module provides a way of calculating these values automatically and
efficiently at the time of modification rather than on the fly.
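
To make the difference between the two strategies concrete, here is a minimal
sketch that uses only the standard library's ``sqlite3`` module rather than
SQLAlchemy-Utils itself. The schema and the UPDATE statement are assumptions
chosen for illustration; they show the general idea of refreshing a stored
aggregate at write time, not the exact SQL this module emits.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.executescript('''
    CREATE TABLE thread (
        id INTEGER PRIMARY KEY,
        name TEXT,
        comment_count INTEGER DEFAULT 0
    );
    CREATE TABLE comment (
        id INTEGER PRIMARY KEY,
        content TEXT,
        thread_id INTEGER REFERENCES thread (id)
    );
''')

conn.execute(
    'INSERT INTO thread (id, name) VALUES (?, ?)',
    (1, 'SQLAlchemy development'),
)
conn.executemany(
    'INSERT INTO comment (content, thread_id) VALUES (?, ?)',
    [('Going good!', 1), ('Great new features!', 1)],
)

# On the fly: every read pays for the aggregate.
on_the_fly = conn.execute(
    'SELECT COUNT(*) FROM comment WHERE thread_id = ?', (1,)
).fetchone()[0]

# At modification time: one UPDATE with a correlated subquery refreshes
# the stored value, restricted to the parent rows that were affected.
conn.execute('''
    UPDATE thread
    SET comment_count = (
        SELECT COUNT(*) FROM comment WHERE comment.thread_id = thread.id
    )
    WHERE id IN (?)
''', (1,))

# Later reads are a plain column lookup.
cached = conn.execute(
    'SELECT comment_count FROM thread WHERE id = ?', (1,)
).fetchone()[0]
print(on_the_fly, cached)
```

Both values agree here; the stored column simply moves the cost of the
aggregate from every read to the (usually rarer) write.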

Features
--------

* Automatically updates aggregate columns when aggregated values change
* Supports aggregate values through an arbitrary number of levels of relations
* Highly optimized: uses a single query per transaction per aggregate column
* Aggregated columns can be of any data type and use any selectable scalar
  expression


.. _column_property:
    https://docs.sqlalchemy.org/en/latest/orm/mapped_sql_expr.html#using-column-property
.. _counter_culture: https://github.com/magnusvk/counter_culture
.. _stackoverflow reply by Michael Bayer:
    https://stackoverflow.com/a/13765857/520932

Simple aggregates
-----------------

::

    import sqlalchemy as sa
    from sqlalchemy_utils import aggregated


    class Thread(Base):
        __tablename__ = 'thread'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        @aggregated('comments', sa.Column(sa.Integer))
        def comment_count(self):
            return sa.func.count('1')

        comments = sa.orm.relationship(
            'Comment',
            backref='thread'
        )


    class Comment(Base):
        __tablename__ = 'comment'
        id = sa.Column(sa.Integer, primary_key=True)
        content = sa.Column(sa.UnicodeText)
        thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))


    thread = Thread(name='SQLAlchemy development')
    thread.comments.append(Comment(content='Going good!'))
    thread.comments.append(Comment(content='Great new features!'))

    session.add(thread)
    session.commit()

    thread.comment_count  # 2

Custom aggregate expressions
----------------------------

An aggregate expression can be virtually any SQL expression, not just a simple
function taking one parameter. You can try things such as subqueries and
different kinds of functions.

In the following example we have a Catalog of products where each catalog
knows the net worth of its products.

::

    from sqlalchemy_utils import aggregated


    class Catalog(Base):
        __tablename__ = 'catalog'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        @aggregated('products', sa.Column(sa.Integer))
        def net_worth(self):
            return sa.func.sum(Product.price)

        products = sa.orm.relationship('Product')


    class Product(Base):
        __tablename__ = 'product'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        price = sa.Column(sa.Numeric)

        catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))


Now the net_worth column of the Catalog model will be automatically updated
whenever:

* A new product is added to the catalog
* A product is deleted from the catalog
* The price of a catalog product is changed

::

    from decimal import Decimal


    product1 = Product(name='Some product', price=Decimal(1000))
    product2 = Product(name='Some other product', price=Decimal(500))


    catalog = Catalog(
        name='My first catalog',
        products=[
            product1,
            product2
        ]
    )
    session.add(catalog)
    session.commit()

    session.refresh(catalog)
    catalog.net_worth  # 1500

    session.delete(product2)
    session.commit()
    session.refresh(catalog)

    catalog.net_worth  # 1000

    product1.price = 2000
    session.commit()
    session.refresh(catalog)

    catalog.net_worth  # 2000

Multiple aggregates per class
-----------------------------

Sometimes you may need to define multiple aggregate values for the same class.
If you need to define lots of relationships pointing to the same class,
remember to define the relationships as viewonly when possible.

::

    from sqlalchemy_utils import aggregated


    class Customer(Base):
        __tablename__ = 'customer'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        @aggregated('orders', sa.Column(sa.Integer))
        def orders_sum(self):
            return sa.func.sum(Order.price)

        @aggregated('invoiced_orders', sa.Column(sa.Integer))
        def invoiced_orders_sum(self):
            return sa.func.sum(Order.price)

        orders = sa.orm.relationship('Order')

        # String arguments are evaluated lazily by SQLAlchemy, where
        # ``and_`` is available but the ``sa`` alias is not.
        invoiced_orders = sa.orm.relationship(
            'Order',
            primaryjoin=
                'and_(Order.customer_id == Customer.id, Order.invoiced)',
            viewonly=True
        )


    class Order(Base):
        __tablename__ = 'order'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        price = sa.Column(sa.Numeric)
        invoiced = sa.Column(sa.Boolean, default=False)
        customer_id = sa.Column(sa.Integer, sa.ForeignKey(Customer.id))

Many-to-Many aggregates
-----------------------

Aggregate expressions also support many-to-many relationships. Typical use
cases include:

1. Friend count of a user
2. Number of groups a given user belongs to

::

    user_group = sa.Table('user_group', Base.metadata,
        sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
        sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
    )


    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        @aggregated('groups', sa.Column(sa.Integer, default=0))
        def group_count(self):
            return sa.func.count('1')

        groups = sa.orm.relationship(
            'Group',
            backref='users',
            secondary=user_group
        )


    class Group(Base):
        __tablename__ = 'group'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))



    user = User(name='John Matrix')
    user.groups = [Group(name='Group A'), Group(name='Group B')]

    session.add(user)
    session.commit()

    session.refresh(user)
    user.group_count  # 2
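
As a rough illustration of what a many-to-many aggregate involves at the SQL
level, the following sketch again uses the standard library's ``sqlite3``
module instead of SQLAlchemy-Utils. The tables mirror the User and Group
example; the UPDATE statement is an assumed shape for such a query, written by
hand for this sketch, and may differ from what the module actually generates.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.executescript('''
    CREATE TABLE "user" (
        id INTEGER PRIMARY KEY,
        name TEXT,
        group_count INTEGER DEFAULT 0
    );
    CREATE TABLE "group" (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE user_group (
        user_id INTEGER REFERENCES "user" (id),
        group_id INTEGER REFERENCES "group" (id)
    );
    INSERT INTO "user" (id, name) VALUES (1, 'John Matrix');
    INSERT INTO "group" (id, name) VALUES (1, 'Group A'), (2, 'Group B');
    INSERT INTO user_group (user_id, group_id) VALUES (1, 1), (1, 2);
''')

# The count is taken over the association table, and only users linked
# to the groups that were just touched (ids 1 and 2) are refreshed.
conn.execute('''
    UPDATE "user"
    SET group_count = (
        SELECT COUNT(*)
        FROM user_group
        WHERE user_group.user_id = "user".id
    )
    WHERE id IN (
        SELECT user_id FROM user_group WHERE group_id IN (1, 2)
    )
''')

group_count = conn.execute(
    'SELECT group_count FROM "user" WHERE id = 1'
).fetchone()[0]
print(group_count)
```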

Multi-level aggregates
----------------------

Aggregates can span multiple relationships. In the following example each
Catalog has a net_worth which is the sum of the prices of all products in all
of its categories.

::

    from sqlalchemy_utils import aggregated


    class Catalog(Base):
        __tablename__ = 'catalog'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        @aggregated('categories.products', sa.Column(sa.Integer))
        def net_worth(self):
            return sa.func.sum(Product.price)

        categories = sa.orm.relationship('Category')


    class Category(Base):
        __tablename__ = 'category'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))

        products = sa.orm.relationship('Product')


    class Product(Base):
        __tablename__ = 'product'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        price = sa.Column(sa.Numeric)

        category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
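
A multi-level aggregate has to do two things: sum over a chain of joins, and
limit the update to the top-level rows reachable from whatever changed. The
sketch below shows one way this could look in plain SQL, run through the
standard library's ``sqlite3`` module. The sample rows and the variable
``changed_category_ids`` are made up for illustration, and the statement is an
approximation of the idea rather than the query this module builds.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.executescript('''
    CREATE TABLE catalog (
        id INTEGER PRIMARY KEY, name TEXT, net_worth INTEGER
    );
    CREATE TABLE category (
        id INTEGER PRIMARY KEY, name TEXT,
        catalog_id INTEGER REFERENCES catalog (id)
    );
    CREATE TABLE product (
        id INTEGER PRIMARY KEY, name TEXT, price NUMERIC,
        category_id INTEGER REFERENCES category (id)
    );
    INSERT INTO catalog (id, name) VALUES (1, 'First'), (2, 'Second');
    INSERT INTO category (id, name, catalog_id)
        VALUES (1, 'Books', 1), (2, 'Games', 1), (3, 'Music', 2);
    INSERT INTO product (name, price, category_id)
        VALUES ('A', 1000, 1), ('B', 500, 2), ('C', 70, 3);
''')

# Hypothetical input: ids of the categories whose products were modified.
changed_category_ids = (1, 2)
placeholders = ', '.join('?' for _ in changed_category_ids)

# The subquery in SET walks product -> category -> catalog; the WHERE
# clause restricts the update to catalogs owning a changed category.
conn.execute(f'''
    UPDATE catalog
    SET net_worth = (
        SELECT SUM(product.price)
        FROM product
        INNER JOIN category ON category.id = product.category_id
        WHERE category.catalog_id = catalog.id
    )
    WHERE id IN (
        SELECT catalog_id FROM category WHERE id IN ({placeholders})
    )
''', changed_category_ids)

rows = conn.execute(
    'SELECT id, net_worth FROM catalog ORDER BY id'
).fetchall()
print(rows)
```

Only the first catalog is recalculated; the second one keeps its previous
(here unset) value because none of its categories were in the changed set.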

Examples
--------

Average movie rating
^^^^^^^^^^^^^^^^^^^^

::

    from sqlalchemy_utils import aggregated


    class Movie(Base):
        __tablename__ = 'movie'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))

        @aggregated('ratings', sa.Column(sa.Numeric))
        def avg_rating(self):
            return sa.func.avg(Rating.stars)

        ratings = sa.orm.relationship('Rating')


    class Rating(Base):
        __tablename__ = 'rating'
        id = sa.Column(sa.Integer, primary_key=True)
        stars = sa.Column(sa.Integer)

        movie_id = sa.Column(sa.Integer, sa.ForeignKey(Movie.id))


    movie = Movie(name='Terminator 2')
    movie.ratings.append(Rating(stars=5))
    movie.ratings.append(Rating(stars=4))
    movie.ratings.append(Rating(stars=3))
    session.add(movie)
    session.commit()

    movie.avg_rating  # 4


TODO
----

* Special consideration should be given to `deadlocks`_.


.. _deadlocks:
    https://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html

"""

from collections import defaultdict
from weakref import WeakKeyDictionary

import sqlalchemy as sa
import sqlalchemy.event
import sqlalchemy.orm
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.sql.functions import _FunctionGenerator

from .functions.orm import get_column_key
from .relationships import (
    chained_join,
    path_to_relationships,
    select_correlated_expression,
)

# Maps each class to a list of (expression function, relationship path,
# column) tuples collected from its @aggregated attributes.
aggregated_attrs = WeakKeyDictionary()


class AggregatedAttribute(declared_attr):
    def __init__(self, fget, relationship, column, *args, **kwargs):
        super().__init__(fget, *args, **kwargs)
        self.__doc__ = fget.__doc__
        self.column = column
        self.relationship = relationship

    def __get__(desc, self, cls):
        # Record the aggregate for ``cls`` and hand back the plain column.
        value = (desc.fget, desc.relationship, desc.column)
        if cls not in aggregated_attrs:
            aggregated_attrs[cls] = [value]
        else:
            aggregated_attrs[cls].append(value)
        return desc.column


def local_condition(prop, objects):
    # Returns an IN condition on the parent column, or None (implicitly)
    # when no values could be collected from ``objects``.
    pairs = prop.local_remote_pairs
    if prop.secondary is not None:
        parent_column = pairs[1][0]
        fetched_column = pairs[1][0]
    else:
        parent_column = pairs[0][0]
        fetched_column = pairs[0][1]

    key = get_column_key(prop.mapper, fetched_column)

    values = []
    for obj in objects:
        try:
            values.append(getattr(obj, key))
        except sa.orm.exc.ObjectDeletedError:
            # Skip objects whose row no longer exists.
            pass

    if values:
        return parent_column.in_(values)


def aggregate_expression(expr, class_):
    if isinstance(expr, sa.sql.visitors.Visitable):
        # Already a SQL expression.
        return expr
    elif isinstance(expr, _FunctionGenerator):
        # A bare function generator such as ``sa.func.count``.
        return expr(sa.sql.text('1'))
    else:
        return expr(class_)


class AggregatedValue:
    def __init__(self, class_, attr, path, expr):
        self.class_ = class_
        self.attr = attr
        self.path = path
        self.relationships = list(reversed(path_to_relationships(path, class_)))
        self.expr = aggregate_expression(expr, class_)

    @property
    def aggregate_query(self):
        query = select_correlated_expression(
            self.class_, self.expr, self.path, self.relationships[0].mapper.class_
        )

        return query.scalar_subquery()

    def update_query(self, objects):
        table = self.class_.__table__
        query = table.update().values({self.attr: self.aggregate_query})
        if len(self.relationships) == 1:
            prop = self.relationships[-1].property
            condition = local_condition(prop, objects)
            if condition is not None:
                return query.where(condition)
        else:
            # Builds query such as:
            #
            # UPDATE catalog SET product_count = (aggregate_query)
            # WHERE id IN (
            #     SELECT catalog_id
            #       FROM category
            #       INNER JOIN sub_category
            #           ON category.id = sub_category.category_id
            #       WHERE sub_category.id IN (product_sub_category_ids)
            # )
            property_ = self.relationships[-1].property
            remote_pairs = property_.local_remote_pairs
            local = remote_pairs[0][0]
            remote = remote_pairs[0][1]
            condition = local_condition(self.relationships[0].property, objects)
            if condition is not None:
                return query.where(
                    local.in_(
                        sa.select(remote)
                        .select_from(chained_join(*reversed(self.relationships)))
                        .where(condition)
                    )
                )


class AggregationManager:
    def __init__(self):
        self.reset()

    def reset(self):
        self.generator_registry = defaultdict(list)

    def register_listeners(self):
        sa.event.listen(
            sa.orm.Mapper, 'after_configured', self.update_generator_registry
        )
        sa.event.listen(
            sa.orm.session.Session, 'after_flush', self.construct_aggregate_queries
        )

    def update_generator_registry(self):
        # Key each aggregate by the class at the far end of its path.
        for class_, attrs in aggregated_attrs.items():
            for expr, path, column in attrs:
                value = AggregatedValue(
                    class_=class_, attr=column, path=path, expr=expr(class_)
                )
                key = value.relationships[0].mapper.class_
                self.generator_registry[key].append(value)

    def construct_aggregate_queries(self, session, ctx):
        # Group the session's objects by registered class.
        object_dict = defaultdict(list)
        for obj in session:
            for class_ in self.generator_registry:
                if isinstance(obj, class_):
                    object_dict[class_].append(obj)

        # Execute one UPDATE per aggregate that has affected rows.
        for class_, objects in object_dict.items():
            for aggregate_value in self.generator_registry[class_]:
                query = aggregate_value.update_query(objects)
                if query is not None:
                    session.execute(query)


manager = AggregationManager()
manager.register_listeners()


def aggregated(relationship, column):
    """
    Decorator that generates an aggregated attribute. The decorated function
    should return an aggregate select expression.

    :param relationship:
        Defines the relationship from which the aggregate is calculated.
        The class needs to have the given relationship in order to calculate
        the aggregate.
    :param column:
        SQLAlchemy Column object. The column definition of this aggregate
        attribute.
    """

    def wraps(func):
        return AggregatedAttribute(func, relationship, column)

    return wraps