Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/flask_sqlalchemy/__init__.py: 19%

524 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# -*- coding: utf-8 -*- 

2from __future__ import absolute_import 

3 

4import functools 

5import os 

6import sys 

7import time 

8import warnings 

9from math import ceil 

10from operator import itemgetter 

11from threading import Lock 

12 

13import sqlalchemy 

14from flask import _app_ctx_stack, abort, current_app, request 

15from flask.signals import Namespace 

16from sqlalchemy import event, inspect, orm 

17from sqlalchemy.engine.url import make_url 

18from sqlalchemy.orm.exc import UnmappedClassError 

19from sqlalchemy.orm.session import Session as SessionBase 

20 

21from ._compat import itervalues, string_types, xrange 

22from .model import DefaultMeta 

23from .model import Model 

24from . import utils 

25 

26try: 

27 from sqlalchemy.orm import declarative_base 

28 from sqlalchemy.orm import DeclarativeMeta 

29except ImportError: 

30 # SQLAlchemy <= 1.3 

31 from sqlalchemy.ext.declarative import declarative_base 

32 from sqlalchemy.ext.declarative import DeclarativeMeta 

33 

34# Scope the session to the current greenlet if greenlet is available, 

35# otherwise fall back to the current thread. 

36try: 

37 from greenlet import getcurrent as _ident_func 

38except ImportError: 

39 try: 

40 from threading import get_ident as _ident_func 

41 except ImportError: 

42 # Python 2.7 

43 from thread import get_ident as _ident_func 

44 

__version__ = "2.5.1"

# Timer used to stamp recorded queries: ``time.time`` everywhere except
# Windows, where ``time.perf_counter`` (Python 3.3+) or the legacy
# ``time.clock`` (older Pythons) is used instead.
if sys.platform != 'win32':
    _timer = time.time
elif sys.version_info >= (3, 3):
    _timer = time.perf_counter
else:
    _timer = time.clock

# Signals (from ``flask.signals``) sent by the session hooks around commits.
_signals = Namespace()
models_committed = _signals.signal('models-committed')
before_models_committed = _signals.signal('before-models-committed')

59 

60 

def _sa_url_set(url, **kwargs):
    """Return ``url`` with the given attributes replaced.

    SQLAlchemy 1.4+ URLs are immutable and provide ``URL.set()``, which
    returns a new object. Older URLs have no ``set``; they are updated in
    place with ``setattr`` and the same object is returned.
    """
    try:
        return url.set(**kwargs)
    except AttributeError:
        pass

    # SQLAlchemy <= 1.3: mutate the URL object directly.
    for name in kwargs:
        setattr(url, name, kwargs[name])
    return url

70 

71 

def _sa_url_query_setdefault(url, **kwargs):
    """Return ``url`` with default query-string parameters filled in.

    Each keyword argument is added to the URL's query only when that key is
    not already present, so existing values always win.
    """
    query = dict(url.query)
    missing = {key: value for key, value in kwargs.items() if key not in query}
    query.update(missing)
    return _sa_url_set(url, query=query)

79 

80 

def _make_table(db):
    """Build the ``db.Table`` factory bound to ``db``.

    The returned callable mirrors :class:`sqlalchemy.Table` but lets callers
    omit the ``MetaData`` argument. When the second positional argument is
    not a ``MetaData`` instance (and ``metadata`` was not passed by keyword),
    ``db.metadata`` is inserted right after the table name. Every table's
    ``info`` dict gets a ``bind_key`` entry, which is ``None`` unless the
    caller set one.
    """
    def _make_table(*args, **kwargs):
        # Insert db.metadata unless the caller already supplied a MetaData.
        # Checking for MetaData, rather than for a Column, also covers
        # column-less tables and tables that start with a constraint.
        has_metadata = 'metadata' in kwargs or (
            len(args) > 1 and isinstance(args[1], sqlalchemy.MetaData)
        )
        if args and not has_metadata:
            args = (args[0], db.metadata) + args[1:]
        # Copy so a caller's ``info`` dict is neither mutated nor shared
        # between several tables.
        info = dict(kwargs.pop('info', None) or {})
        info.setdefault('bind_key', None)
        kwargs['info'] = info
        return sqlalchemy.Table(*args, **kwargs)
    return _make_table

90 

91 

def _set_default_query_class(d, cls):
    """Store ``cls`` under ``d['query_class']`` unless that key already exists."""
    d.setdefault('query_class', cls)

95 

96 

def _wrap_with_default_query_class(fn, cls):
    """Wrap a relationship factory so that ``query_class`` defaults to ``cls``.

    The wrapper sets ``query_class`` on the call's keyword arguments. When
    ``backref`` is given as a ``(name, kwargs)`` pair, it also sets it on
    the backref's keyword dict. Values set explicitly by the caller are kept.

    :param fn: callable such as ``sqlalchemy.orm.relationship``.
    :param cls: query class to use when none was specified.
    :return: a :func:`functools.wraps` wrapper around ``fn``.
    """
    @functools.wraps(fn)
    def newfn(*args, **kwargs):
        kwargs.setdefault('query_class', cls)
        backref = kwargs.get('backref')
        # ``backref=None`` means "no backref" and must not be indexed. A
        # plain string backref has no kwargs dict of its own and is passed
        # to ``fn`` unchanged.
        if backref is not None and not isinstance(backref, string_types):
            backref[1].setdefault('query_class', cls)
        return fn(*args, **kwargs)
    return newfn

108 

109 

def _include_sqlalchemy(obj, cls):
    """Expose SQLAlchemy's public API as attributes of ``obj``.

    Every name in ``sqlalchemy.__all__`` and ``sqlalchemy.orm.__all__`` that
    ``obj`` does not already have is copied onto it. ``Table`` is then
    replaced by the metadata-optional factory, the relationship helpers are
    wrapped to default to ``cls`` as their query class, and ``event`` is
    attached.
    """
    for module in (sqlalchemy, sqlalchemy.orm):
        for name in module.__all__:
            if hasattr(obj, name):
                continue
            setattr(obj, name, getattr(module, name))

    # Note: obj.Table does not attempt to be a SQLAlchemy Table class.
    obj.Table = _make_table(obj)
    for helper in ('relationship', 'relation', 'dynamic_loader'):
        wrapped = _wrap_with_default_query_class(getattr(obj, helper), cls)
        setattr(obj, helper, wrapped)
    obj.event = event

121 

122 

class _DebugQueryTuple(tuple):
    """One recorded query, laid out as
    ``(statement, parameters, start_time, end_time, context)``.

    Each field can also be read as a named attribute.
    """

    @property
    def statement(self):
        return self[0]

    @property
    def parameters(self):
        return self[1]

    @property
    def start_time(self):
        return self[2]

    @property
    def end_time(self):
        return self[3]

    @property
    def context(self):
        return self[4]

    @property
    def duration(self):
        """Elapsed time between ``start_time`` and ``end_time``."""
        return self.end_time - self.start_time

    def __repr__(self):
        fields = (self.statement, self.parameters, self.duration)
        return '<query statement="%s" parameters=%r duration=%.03f>' % fields

140 

141 

def _calling_context(app_path):
    """Describe where in the application the current call originated.

    Walks the stack outward from the caller's frame. Returns
    ``"filename:lineno (funcname)"`` for the first frame whose module
    ``__name__`` is ``app_path`` or a submodule of it, or ``'<unknown>'``
    when no such frame exists.

    :param app_path: import name of the application package.
    """
    frm = sys._getframe(1)
    # Inspect every frame, including the outermost one (whose ``f_back`` is
    # None), e.g. module-level code of a script running as ``__main__``.
    while frm is not None:
        name = frm.f_globals.get('__name__')
        if name and (name == app_path or name.startswith(app_path + '.')):
            return '%s:%s (%s)' % (
                frm.f_code.co_filename,
                frm.f_lineno,
                frm.f_code.co_name,
            )
        frm = frm.f_back
    return '<unknown>'

155 

156 

class SignallingSession(SessionBase):
    """The signalling session is the default session that Flask-SQLAlchemy
    uses. It extends the default session system with bind selection and
    modification tracking.

    If you want to use a different session you can override the
    :meth:`SQLAlchemy.create_session` function.

    .. versionadded:: 2.0

    .. versionadded:: 2.1
        The `binds` option was added, which allows a session to be joined
        to an external transaction.
    """

    def __init__(self, db, autocommit=False, autoflush=True, **options):
        #: The application that this session belongs to.
        self.app = app = db.get_app()
        track_modifications = app.config['SQLALCHEMY_TRACK_MODIFICATIONS']
        bind = options.pop('bind', None) or db.engine
        # Only build the default bind map when the caller did not pass
        # ``binds``; otherwise ``db.get_binds(app)`` is wasted work.
        # NOTE(review): get_binds is defined elsewhere and presumably
        # resolves an engine per configured bind.
        if 'binds' in options:
            binds = options.pop('binds')
        else:
            binds = db.get_binds(app)

        # ``None`` (unset) keeps tracking on, matching init_app's default.
        if track_modifications is None or track_modifications:
            _SessionSignalEvents.register(self)

        SessionBase.__init__(
            self, autocommit=autocommit, autoflush=autoflush,
            bind=bind, binds=binds, **options
        )

    def get_bind(self, mapper=None, clause=None, bind=None, **kwargs):
        """Return the engine or connection for a given model or
        table, using the ``__bind_key__`` if it is set.

        An explicit ``bind`` is returned as-is. Otherwise, if the mapped
        table's ``info`` has a non-None ``bind_key``, the engine for that
        bind is returned. Everything else is delegated to
        ``Session.get_bind``. Any extra keyword arguments (newer SQLAlchemy
        versions pass some) are forwarded to it.
        """
        # An explicitly requested bind always wins.
        if bind is not None:
            return bind

        # mapper is None if someone tries to just get a connection
        if mapper is not None:
            try:
                # SA >= 1.3
                persist_selectable = mapper.persist_selectable
            except AttributeError:
                # SA < 1.3
                persist_selectable = mapper.mapped_table

            info = getattr(persist_selectable, 'info', {})
            bind_key = info.get('bind_key')
            if bind_key is not None:
                state = get_state(self.app)
                return state.db.get_engine(self.app, bind=bind_key)
        return SessionBase.get_bind(self, mapper, clause, **kwargs)

206 

207 

class _SessionSignalEvents(object):
    """Session event listeners that collect model changes and emit the
    ``before_models_committed`` / ``models_committed`` signals.

    Pending changes live in ``session._model_changes``, a dict that maps an
    identity key (or ``id()`` for objects without an identity yet) to a
    ``(model, operation)`` pair. Each handler does nothing for a session
    that has no ``_model_changes`` attribute.
    """

    # (session event, handler name), in the order listeners are attached.
    _LISTENERS = (
        ('before_flush', 'record_ops'),
        ('before_commit', 'record_ops'),
        ('before_commit', 'before_commit'),
        ('after_commit', 'after_commit'),
        ('after_rollback', 'after_rollback'),
    )

    # Sentinel meaning "this session has no _model_changes attribute".
    _UNREGISTERED = object()

    @classmethod
    def register(cls, session):
        if not hasattr(session, '_model_changes'):
            session._model_changes = {}
        for event_name, handler_name in cls._LISTENERS:
            event.listen(session, event_name, getattr(cls, handler_name))

    @classmethod
    def unregister(cls, session):
        if hasattr(session, '_model_changes'):
            del session._model_changes
        for event_name, handler_name in cls._LISTENERS:
            event.remove(session, event_name, getattr(cls, handler_name))

    @staticmethod
    def _changes_of(session):
        # Returns the _UNREGISTERED sentinel when the attribute is missing.
        return getattr(
            session, '_model_changes', _SessionSignalEvents._UNREGISTERED
        )

    @staticmethod
    def record_ops(session, flush_context=None, instances=None):
        changes = _SessionSignalEvents._changes_of(session)
        if changes is _SessionSignalEvents._UNREGISTERED:
            return

        batches = (
            (session.new, 'insert'),
            (session.dirty, 'update'),
            (session.deleted, 'delete'),
        )
        for targets, operation in batches:
            for target in targets:
                state = inspect(target)
                if state.has_identity:
                    key = state.identity_key
                else:
                    key = id(target)
                changes[key] = (target, operation)

    @staticmethod
    def before_commit(session):
        changes = _SessionSignalEvents._changes_of(session)
        if changes is not _SessionSignalEvents._UNREGISTERED and changes:
            before_models_committed.send(
                session.app, changes=list(changes.values())
            )

    @staticmethod
    def after_commit(session):
        changes = _SessionSignalEvents._changes_of(session)
        if changes is _SessionSignalEvents._UNREGISTERED or not changes:
            return
        models_committed.send(session.app, changes=list(changes.values()))
        changes.clear()

    @staticmethod
    def after_rollback(session):
        changes = _SessionSignalEvents._changes_of(session)
        if changes is not _SessionSignalEvents._UNREGISTERED:
            changes.clear()

273 

274 

class _EngineDebuggingSignalEvents(object):
    """Engine listeners that time each cursor execution.

    Each finished execution is appended as a :class:`_DebugQueryTuple` to
    the ``sqlalchemy_queries`` list on the current app context.
    """

    def __init__(self, engine, import_name):
        self.engine = engine
        # Package name _calling_context uses to find the app's own frame.
        self.app_package = import_name

    def register(self):
        """Attach both cursor-execute listeners to ``self.engine``."""
        for hook in ('before_cursor_execute', 'after_cursor_execute'):
            event.listen(self.engine, hook, getattr(self, hook))

    def before_cursor_execute(
        self, conn, cursor, statement, parameters, context, executemany
    ):
        """Stamp the execution context with a start time (app context only)."""
        if not current_app:
            return
        context._query_start_time = _timer()

    def after_cursor_execute(
        self, conn, cursor, statement, parameters, context, executemany
    ):
        """Record the finished query on the app context (app context only)."""
        if not current_app:
            return

        ctx = _app_ctx_stack.top
        try:
            queries = ctx.sqlalchemy_queries
        except AttributeError:
            queries = ctx.sqlalchemy_queries = []

        started = context._query_start_time
        finished = _timer()
        origin = _calling_context(self.app_package)
        queries.append(_DebugQueryTuple(
            (statement, parameters, started, finished, origin)
        ))

310 

311 

def get_debug_queries():
    """Return the SQL queries recorded on the current application context.

    Recording is active for an app in any of these cases:

    * it runs in debug mode;
    * ``SQLALCHEMY_RECORD_QUERIES`` is truthy;
    * ``SQLALCHEMY_RECORD_QUERIES`` is left as ``None`` and ``TESTING`` is
      enabled.

    The records are stored on the app context, so they stay available
    until that context is torn down. This makes it easy to check the
    generated SQL in tests or while debugging.

    Each entry is a tuple with these attributes:

    `statement`
        The SQL statement issued
    `parameters`
        The parameters for the SQL statement
    `start_time` / `end_time`
        Timer readings taken before execution and after the results
        arrived. The timer function depends on the platform, so these
        values are only useful for comparing or sorting, not as absolute
        timestamps.
    `duration`
        ``end_time - start_time``, in seconds
    `context`
        A rough, human-readable hint of where in the application the query
        was issued. Its format is undefined, so don't try to parse it.

    A new empty list is returned when nothing has been recorded.
    """
    app_ctx = _app_ctx_stack.top
    return getattr(app_ctx, 'sqlalchemy_queries', [])

345 

346 

class Pagination(object):
    """Internal helper class returned by :meth:`BaseQuery.paginate`. You
    can also construct it from any other SQLAlchemy query object if you are
    working with other libraries. Additionally it is possible to pass `None`
    as query object in which case the :meth:`prev` and :meth:`next` will
    no longer work.
    """

    def __init__(self, query, page, per_page, total, items):
        #: the unlimited query object that was used to create this
        #: pagination object.
        self.query = query
        #: the current page number (1 indexed)
        self.page = page
        #: the number of items to be displayed on a page.
        self.per_page = per_page
        #: the total number of items matching the query
        self.total = total
        #: the items for the current page
        self.items = items

    def _require_query(self):
        # Explicit raise instead of ``assert`` so the check also runs
        # under ``python -O``; the exception type stays AssertionError.
        if self.query is None:
            raise AssertionError(
                'a query object is required for this method to work'
            )

    @property
    def pages(self):
        """The total number of pages.

        This is 0 when ``per_page`` is 0 or ``total`` is ``None``.
        """
        if self.per_page == 0 or self.total is None:
            return 0
        # Exact integer ceiling division. Float division loses precision
        # once ``total`` exceeds 2**53.
        return int(-(-self.total // self.per_page))

    def prev(self, error_out=False):
        """Returns a :class:`Pagination` object for the previous page.

        :raises AssertionError: if this pagination was built without a query.
        """
        self._require_query()
        return self.query.paginate(self.page - 1, self.per_page, error_out)

    @property
    def prev_num(self):
        """Number of the previous page."""
        if not self.has_prev:
            return None
        return self.page - 1

    @property
    def has_prev(self):
        """True if a previous page exists"""
        return self.page > 1

    def next(self, error_out=False):
        """Returns a :class:`Pagination` object for the next page.

        :raises AssertionError: if this pagination was built without a query.
        """
        self._require_query()
        return self.query.paginate(self.page + 1, self.per_page, error_out)

    @property
    def has_next(self):
        """True if a next page exists."""
        return self.page < self.pages

    @property
    def next_num(self):
        """Number of the next page"""
        if not self.has_next:
            return None
        return self.page + 1

    def iter_pages(self, left_edge=2, left_current=2,
                   right_current=5, right_edge=2):
        """Iterates over the page numbers in the pagination. The four
        parameters control the thresholds how many numbers should be produced
        from the sides. Skipped page numbers are represented as `None`.
        This is how you could render such a pagination in the templates:

        .. sourcecode:: html+jinja

            {% macro render_pagination(pagination, endpoint) %}
              <div class=pagination>
              {%- for page in pagination.iter_pages() %}
                {% if page %}
                  {% if page != pagination.page %}
                    <a href="{{ url_for(endpoint, page=page) }}">{{ page }}</a>
                  {% else %}
                    <strong>{{ page }}</strong>
                  {% endif %}
                {% else %}
                  <span class=ellipsis>…</span>
                {% endif %}
              {%- endfor %}
              </div>
            {% endmacro %}
        """
        last = 0
        for num in xrange(1, self.pages + 1):
            if num <= left_edge or \
               (num > self.page - left_current - 1 and
                num < self.page + right_current) or \
               num > self.pages - right_edge:
                if last + 1 != num:
                    yield None
                yield num
                last = num

448 

449 

class BaseQuery(orm.Query):
    """SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with convenience methods for querying in a web application.

    This is the default :attr:`~Model.query` object used for models, and exposed as :attr:`~SQLAlchemy.Query`.
    Override the query class for an individual model by subclassing this and setting :attr:`~Model.query_class`.
    """

    def get_or_404(self, ident, description=None):
        """Like :meth:`get` but aborts with 404 if not found instead of returning ``None``."""

        rv = self.get(ident)
        if rv is None:
            abort(404, description=description)
        return rv

    def first_or_404(self, description=None):
        """Like :meth:`first` but aborts with 404 if not found instead of returning ``None``."""

        rv = self.first()
        if rv is None:
            abort(404, description=description)
        return rv

    def paginate(self, page=None, per_page=None, error_out=True, max_per_page=None):
        """Returns ``per_page`` items from page ``page``.

        If ``page`` or ``per_page`` are ``None``, they will be retrieved from
        the request query. If ``max_per_page`` is specified, ``per_page`` will
        be limited to that value (``max_per_page`` should be non-negative). If
        there is no request or they aren't in the query, they default to 1
        and 20 respectively.

        When ``error_out`` is ``True`` (default), the following rules will
        cause a 404 response:

        * No items are found and ``page`` is not 1.
        * ``page`` is less than 1, or ``per_page`` is negative.
        * ``page`` or ``per_page`` are not ints.

        When ``error_out`` is ``False``, an invalid ``page`` or ``per_page``
        is replaced by 1 or 20 respectively. ``max_per_page`` is applied
        after that replacement, so it also caps the fallback of 20.

        Returns a :class:`Pagination` object.
        """

        if request:
            if page is None:
                try:
                    page = int(request.args.get('page', 1))
                except (TypeError, ValueError):
                    if error_out:
                        abort(404)

                    page = 1

            if per_page is None:
                try:
                    per_page = int(request.args.get('per_page', 20))
                except (TypeError, ValueError):
                    if error_out:
                        abort(404)

                    per_page = 20
        else:
            if page is None:
                page = 1

            if per_page is None:
                per_page = 20

        if page < 1:
            if error_out:
                abort(404)
            else:
                page = 1

        if per_page < 0:
            if error_out:
                abort(404)
            else:
                per_page = 20

        # Cap only after validation so the fallback of 20 chosen above
        # cannot exceed ``max_per_page``.
        if max_per_page is not None:
            per_page = min(per_page, max_per_page)

        items = self.limit(per_page).offset((page - 1) * per_page).all()

        if not items and page != 1 and error_out:
            abort(404)

        # order_by(None) removes ORDER BY clauses before counting.
        total = self.order_by(None).count()

        return Pagination(self, page, per_page, total, items)

542 

543 

class _QueryProperty(object):
    """Descriptor behind ``Model.query``.

    It builds a ``query_class`` query for the owning mapped class, using the
    extension's scoped session. Accessing it on an unmapped class yields
    ``None``.
    """

    def __init__(self, sa):
        self.sa = sa

    def __get__(self, obj, type):
        try:
            mapper = orm.class_mapper(type)
            return (
                type.query_class(mapper, session=self.sa.session())
                if mapper else None
            )
        except UnmappedClassError:
            return None

555 

556 

def _record_queries(app):
    """Decide whether query recording is enabled for ``app``.

    Debug mode always enables it. Otherwise an explicit
    ``SQLALCHEMY_RECORD_QUERIES`` value is returned unchanged. When that
    value is ``None``, the result follows the ``TESTING`` config flag.
    """
    if app.debug:
        return True
    configured = app.config['SQLALCHEMY_RECORD_QUERIES']
    return bool(app.config.get('TESTING')) if configured is None else configured

564 

565 

class _EngineConnector(object):
    """Creates and caches the SQLAlchemy engine for one bind of one app.

    :meth:`get_engine` rebuilds the engine whenever the configured URI or
    the ``SQLALCHEMY_ECHO`` value differs from the one it was last built
    for. ``self._lock`` serialises those calls.
    """

    def __init__(self, sa, app, bind=None):
        self._sa = sa
        self._app = app
        self._engine = None
        # (uri, echo) pair the cached engine was created for.
        self._connected_for = None
        self._bind = bind
        self._lock = Lock()

    def get_uri(self):
        """Return the configured database URI for this connector's bind.

        :raises AssertionError: if ``bind`` is not a key of
            ``SQLALCHEMY_BINDS``.
        """
        if self._bind is None:
            return self._app.config['SQLALCHEMY_DATABASE_URI']
        binds = self._app.config.get('SQLALCHEMY_BINDS') or ()
        # Explicit raise instead of ``assert`` so the check also runs under
        # ``python -O``; the exception type stays AssertionError.
        if self._bind not in binds:
            raise AssertionError(
                'Bind %r is not specified. Set it in the SQLALCHEMY_BINDS '
                'configuration variable' % (self._bind,)
            )
        return binds[self._bind]

    def get_engine(self):
        """Return the engine for the current configuration, creating it if needed."""
        with self._lock:
            uri = self.get_uri()
            echo = self._app.config['SQLALCHEMY_ECHO']
            if (uri, echo) == self._connected_for:
                return self._engine

            sa_url = make_url(uri)
            sa_url, options = self.get_options(sa_url, echo)
            self._engine = rv = self._sa.create_engine(sa_url, options)

            if _record_queries(self._app):
                _EngineDebuggingSignalEvents(self._engine,
                                             self._app.import_name).register()

            self._connected_for = (uri, echo)

            return rv

    def get_options(self, sa_url, echo):
        """Return ``(sa_url, options)`` to pass to ``create_engine``.

        The options are built in this order: pool defaults, driver hacks
        and ``echo``. ``SQLALCHEMY_ENGINE_OPTIONS`` is applied next, and
        the ``engine_options`` given to :class:`SQLAlchemy` last. Each of
        those last two overrides anything set before it.
        """
        options = self._sa.apply_pool_defaults(self._app, {})
        sa_url, options = self._sa.apply_driver_hacks(self._app, sa_url, options)

        if echo:
            options['echo'] = echo

        # Give the config options set by a developer explicitly priority
        # over decisions FSA makes.
        options.update(self._app.config['SQLALCHEMY_ENGINE_OPTIONS'])

        # Give options set in SQLAlchemy.__init__() ultimate priority
        options.update(self._sa._engine_options)

        return sa_url, options

621 

622 

def get_state(app):
    """Gets the state for the application.

    Returns the object stored in ``app.extensions['sqlalchemy']`` by
    :meth:`SQLAlchemy.init_app`.

    :raises AssertionError: if the extension was not registered on ``app``.
    """
    # Explicit raise instead of ``assert`` so the check also runs under
    # ``python -O``; the exception type stays AssertionError.
    if 'sqlalchemy' not in app.extensions:
        raise AssertionError(
            'The sqlalchemy extension was not registered to the current '
            'application. Please make sure to call init_app() first.'
        )
    return app.extensions['sqlalchemy']

629 

630 

class _SQLAlchemyState(object):
    """Per-application state kept in ``app.extensions['sqlalchemy']``."""

    def __init__(self, db):
        # The SQLAlchemy extension instance this state belongs to.
        self.db = db
        # Engine connectors for this app, filled in elsewhere.
        self.connectors = dict()

637 

638 

639class SQLAlchemy(object): 

640 """This class is used to control the SQLAlchemy integration to one 

641 or more Flask applications. Depending on how you initialize the 

642 object it is usable right away or will attach as needed to a 

643 Flask application. 

644 

645 There are two usage modes which work very similarly. One is binding 

646 the instance to a very specific Flask application:: 

647 

648 app = Flask(__name__) 

649 db = SQLAlchemy(app) 

650 

651 The second possibility is to create the object once and configure the 

652 application later to support it:: 

653 

654 db = SQLAlchemy() 

655 

656 def create_app(): 

657 app = Flask(__name__) 

658 db.init_app(app) 

659 return app 

660 

661 The difference between the two is that in the first case methods like 

662 :meth:`create_all` and :meth:`drop_all` will work all the time but in 

663 the second case a :meth:`flask.Flask.app_context` has to exist. 

664 

665 By default Flask-SQLAlchemy will apply some backend-specific settings 

666 to improve your experience with them. 

667 

668 As of SQLAlchemy 0.6 SQLAlchemy 

669 will probe the library for native unicode support. If it detects 

670 unicode it will let the library handle that, otherwise do that itself. 

671 Sometimes this detection can fail in which case you might want to set 

672 ``use_native_unicode`` (or the ``SQLALCHEMY_NATIVE_UNICODE`` configuration 

673 key) to ``False``. Note that the configuration key overrides the 

674 value you pass to the constructor. Direct support for ``use_native_unicode`` 

675 and SQLALCHEMY_NATIVE_UNICODE are deprecated as of v2.4 and will be removed 

676 in v3.0. ``engine_options`` and ``SQLALCHEMY_ENGINE_OPTIONS`` may be used 

677 instead. 

678 

679 This class also provides access to all the SQLAlchemy functions and classes 

680 from the :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` modules. So you can 

681 declare models like this:: 

682 

683 class User(db.Model): 

684 username = db.Column(db.String(80), unique=True) 

685 pw_hash = db.Column(db.String(80)) 

686 

687 You can still use :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly, but 

688 note that Flask-SQLAlchemy customizations are available only through an 

689 instance of this :class:`SQLAlchemy` class. Query classes default to 

690 :class:`BaseQuery` for `db.Query`, `db.Model.query_class`, and the default 

691 query_class for `db.relationship` and `db.backref`. If you use these 

692 interfaces through :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly, 

693 the default query class will be that of :mod:`sqlalchemy`. 

694 

695 .. admonition:: Check types carefully 

696 

697 Don't perform type or `isinstance` checks against `db.Table`, which 

698 emulates `Table` behavior but is not a class. `db.Table` exposes the 

699 `Table` interface, but is a function which allows omission of metadata. 

700 

701 The ``session_options`` parameter, if provided, is a dict of parameters 

702 to be passed to the session constructor. See :class:`~sqlalchemy.orm.session.Session` 

703 for the standard options. 

704 

705 The ``engine_options`` parameter, if provided, is a dict of parameters 

706 to be passed to create engine. See :func:`~sqlalchemy.create_engine` 

707 for the standard options. The values given here will be merged with and 

708 override anything set in the ``'SQLALCHEMY_ENGINE_OPTIONS'`` config 

709 variable or othewise set by this library. 

710 

711 .. versionadded:: 0.10 

712 The `session_options` parameter was added. 

713 

714 .. versionadded:: 0.16 

715 `scopefunc` is now accepted on `session_options`. It allows specifying 

716 a custom function which will define the SQLAlchemy session's scoping. 

717 

718 .. versionadded:: 2.1 

719 The `metadata` parameter was added. This allows for setting custom 

720 naming conventions among other, non-trivial things. 

721 

722 The `query_class` parameter was added, to allow customisation 

723 of the query class, in place of the default of :class:`BaseQuery`. 

724 

725 The `model_class` parameter was added, which allows a custom model 

726 class to be used in place of :class:`Model`. 

727 

728 .. versionchanged:: 2.1 

729 Utilise the same query class across `session`, `Model.query` and `Query`. 

730 

731 .. versionadded:: 2.4 

732 The `engine_options` parameter was added. 

733 

734 .. versionchanged:: 2.4 

735 The `use_native_unicode` parameter was deprecated. 

736 

737 .. versionchanged:: 2.4.3 

738 ``COMMIT_ON_TEARDOWN`` is deprecated and will be removed in 

739 version 3.1. Call ``db.session.commit()`` directly instead. 

740 """ 

741 

742 #: Default query class used by :attr:`Model.query` and other queries. 

743 #: Customize this by passing ``query_class`` to :func:`SQLAlchemy`. 

744 #: Defaults to :class:`BaseQuery`. 

745 Query = None 

746 

def __init__(self, app=None, use_native_unicode=True, session_options=None,
             metadata=None, query_class=BaseQuery, model_class=Model,
             engine_options=None):
    """Set up the session, model base and SQLAlchemy API shortcuts.

    If ``app`` is given, :meth:`init_app` is called on it immediately.
    """
    # Deprecated since 2.4 (see the class docstring).
    self.use_native_unicode = use_native_unicode
    # ``Query`` must be set before the session and model base are built,
    # because create_scoped_session and make_declarative_base both read it.
    self.Query = query_class
    self.session = self.create_scoped_session(session_options)
    self.Model = self.make_declarative_base(model_class, metadata)
    self._engine_lock = Lock()
    self.app = app
    # Highest-priority engine options (applied last in get_options).
    self._engine_options = engine_options or {}
    # Runs last: it skips any name already defined on this instance.
    _include_sqlalchemy(self, query_class)

    if app is None:
        return
    self.init_app(app)

762 

@property
def metadata(self):
    """Shortcut for ``self.Model.metadata``, the MetaData shared by ``db.Model``."""
    base = self.Model
    return base.metadata

768 

def create_scoped_session(self, options=None):
    """Create a :class:`~sqlalchemy.orm.scoping.scoped_session`
    on the factory from :meth:`create_session`.

    An extra key ``'scopefunc'`` can be set on the ``options`` dict to
    specify a custom scope function. If it's not provided, the current
    greenlet (when greenlet is installed) or thread identity is used. A
    ``query_cls`` defaulting to :attr:`Query` is passed to the factory.

    The ``options`` dict is copied first, so the caller's dict is never
    modified.

    :param options: dict of keyword arguments passed to session class in
        ``create_session``
    """
    # Work on a copy: popping ``scopefunc`` and adding ``query_cls`` must
    # not leak into a dict the caller may reuse, e.g. for another instance.
    options = {} if options is None else dict(options)

    scopefunc = options.pop('scopefunc', _ident_func)
    options.setdefault('query_cls', self.Query)
    return orm.scoped_session(
        self.create_session(options), scopefunc=scopefunc
    )

791 

792 def create_session(self, options): 

793 """Create the session factory used by :meth:`create_scoped_session`. 

794 

795 The factory **must** return an object that SQLAlchemy recognizes as a session, 

796 or registering session events may raise an exception. 

797 

798 Valid factories include a :class:`~sqlalchemy.orm.session.Session` 

799 class or a :class:`~sqlalchemy.orm.session.sessionmaker`. 

800 

801 The default implementation creates a ``sessionmaker`` for :class:`SignallingSession`. 

802 

803 :param options: dict of keyword arguments passed to session class 

804 """ 

805 

806 return orm.sessionmaker(class_=SignallingSession, db=self, **options) 

807 

def make_declarative_base(self, model, metadata=None):
    """Creates the declarative base that all models will inherit from.

    :param model: base model class (or a tuple of base classes) to pass
        to :func:`~sqlalchemy.ext.declarative.declarative_base`. Or a class
        returned from ``declarative_base``, in which case a new base class
        is not created.
    :param metadata: :class:`~sqlalchemy.MetaData` instance to use, or
        none to use SQLAlchemy's default.

    The returned base gets ``query_class`` set to :attr:`Query` (unless it
    already has a truthy one) and a ``query`` descriptor bound to this
    extension.

    .. versionchanged:: 2.3.0
        ``model`` can be an existing declarative base in order to support
        complex customization such as changing the metaclass.
    """
    if isinstance(model, DeclarativeMeta):
        base = model
    else:
        base = declarative_base(
            cls=model,
            name='Model',
            metadata=metadata,
            metaclass=DefaultMeta
        )

    # If an existing declarative base was passed together with a different
    # MetaData, make the base use the MetaData that was asked for.
    if metadata is not None and base.metadata is not metadata:
        base.metadata = metadata

    if not getattr(base, 'query_class', None):
        base.query_class = self.Query

    base.query = _QueryProperty(self)
    return base

840 

def init_app(self, app):
    """This callback can be used to initialize an application for the
    use with this database setup. Never use a database in the context
    of an application not initialized that way or connections will
    leak.

    This method does the following:

    * fills in defaults for every ``SQLALCHEMY_*`` config key;
    * warns about unset or deprecated keys;
    * stores a :class:`_SQLAlchemyState` in ``app.extensions['sqlalchemy']``;
    * registers an app-context teardown handler that always removes the
      scoped session.
    """
    if (
        'SQLALCHEMY_DATABASE_URI' not in app.config and
        'SQLALCHEMY_BINDS' not in app.config
    ):
        warnings.warn(
            'Neither SQLALCHEMY_DATABASE_URI nor SQLALCHEMY_BINDS is set. '
            'Defaulting SQLALCHEMY_DATABASE_URI to "sqlite:///:memory:".'
        )

    app.config.setdefault('SQLALCHEMY_DATABASE_URI', 'sqlite:///:memory:')
    app.config.setdefault('SQLALCHEMY_BINDS', None)
    app.config.setdefault('SQLALCHEMY_NATIVE_UNICODE', None)
    app.config.setdefault('SQLALCHEMY_ECHO', False)
    app.config.setdefault('SQLALCHEMY_RECORD_QUERIES', None)
    app.config.setdefault('SQLALCHEMY_POOL_SIZE', None)
    app.config.setdefault('SQLALCHEMY_POOL_TIMEOUT', None)
    app.config.setdefault('SQLALCHEMY_POOL_RECYCLE', None)
    app.config.setdefault('SQLALCHEMY_MAX_OVERFLOW', None)
    app.config.setdefault('SQLALCHEMY_COMMIT_ON_TEARDOWN', False)
    track_modifications = app.config.setdefault(
        'SQLALCHEMY_TRACK_MODIFICATIONS', None
    )
    app.config.setdefault('SQLALCHEMY_ENGINE_OPTIONS', {})

    if track_modifications is None:
        warnings.warn(FSADeprecationWarning(
            'SQLALCHEMY_TRACK_MODIFICATIONS adds significant overhead and '
            'will be disabled by default in the future. Set it to True '
            'or False to suppress this warning.'
        ))

    # Deprecation warnings for config keys that should be replaced by SQLALCHEMY_ENGINE_OPTIONS.
    utils.engine_config_warning(app.config, '3.0', 'SQLALCHEMY_POOL_SIZE', 'pool_size')
    utils.engine_config_warning(app.config, '3.0', 'SQLALCHEMY_POOL_TIMEOUT', 'pool_timeout')
    utils.engine_config_warning(app.config, '3.0', 'SQLALCHEMY_POOL_RECYCLE', 'pool_recycle')
    utils.engine_config_warning(app.config, '3.0', 'SQLALCHEMY_MAX_OVERFLOW', 'max_overflow')

    app.extensions['sqlalchemy'] = _SQLAlchemyState(self)

    @app.teardown_appcontext
    def shutdown_session(response_or_exc):
        try:
            if app.config['SQLALCHEMY_COMMIT_ON_TEARDOWN']:
                warnings.warn(
                    "'COMMIT_ON_TEARDOWN' is deprecated and will be"
                    " removed in version 3.1. Call"
                    " 'db.session.commit()' directly instead.",
                    DeprecationWarning,
                )

                if response_or_exc is None:
                    self.session.commit()
        finally:
            # Always release the scoped session, even if the commit (or a
            # warning escalated to an error) raised; otherwise it leaks.
            self.session.remove()
        return response_or_exc

901 

def apply_pool_defaults(self, app, options):
    """Copy the legacy pool-related config values into ``options``.

    This covers ``SQLALCHEMY_POOL_*`` and ``SQLALCHEMY_MAX_OVERFLOW``. Only
    keys whose config value is not ``None`` are written. ``options`` is
    modified in place and also returned.

    .. versionchanged:: 2.5
        Returns the ``options`` dict, for consistency with
        :meth:`apply_driver_hacks`.
    """
    legacy_keys = (
        ('pool_size', 'SQLALCHEMY_POOL_SIZE'),
        ('pool_timeout', 'SQLALCHEMY_POOL_TIMEOUT'),
        ('pool_recycle', 'SQLALCHEMY_POOL_RECYCLE'),
        ('max_overflow', 'SQLALCHEMY_MAX_OVERFLOW'),
    )
    for option_key, config_key in legacy_keys:
        value = app.config[config_key]
        if value is not None:
            options[option_key] = value
    return options

917 

def apply_driver_hacks(self, app, sa_url, options):
    """Adjust the engine URL and options for specific database drivers.

    Called right before engine creation. ``options`` is the dict of keyword
    arguments that will later be passed to :func:`sqlalchemy.create_engine`;
    it is modified in place.

    * MySQL drivers: ``charset=utf8`` is defaulted in the URL query and,
      except for ``mysql+gaerdbms``, ``pool_size`` defaults to 10 and
      ``pool_recycle`` to 7200.
    * SQLite in memory (database ``None``, ``''`` or ``':memory:'``):
      ``StaticPool`` is used and ``check_same_thread`` is set to ``False``
      in ``connect_args``; an explicit ``pool_size`` of 0 raises
      :exc:`RuntimeError`.
    * SQLite file: ``NullPool`` is used when ``pool_size`` is missing or 0,
      and the database path is passed through
      ``os.path.join(app.root_path, ...)``.
    * ``use_native_unicode=False`` is added when ``SQLALCHEMY_NATIVE_UNICODE``
      (or, if that is ``None``, the ``use_native_unicode`` flag) is falsy.
      A :exc:`DeprecationWarning` is emitted when the config value is not
      ``None``, and another when the flag is falsy.

    .. versionchanged:: 2.5
        Returns ``(sa_url, options)``. SQLAlchemy 1.4 made the URL
        immutable, so any changes to it must now be passed back up
        to the original caller.
    """
    driver = sa_url.drivername

    if driver.startswith('mysql'):
        sa_url = _sa_url_query_setdefault(sa_url, charset="utf8")
        if driver != 'mysql+gaerdbms':
            options.setdefault('pool_size', 10)
            options.setdefault('pool_recycle', 7200)
    elif driver == 'sqlite':
        pool_size = options.get('pool_size')
        in_memory = sa_url.database in (None, '', ':memory:')

        if in_memory:
            from sqlalchemy.pool import StaticPool
            options['poolclass'] = StaticPool
            connect_args = options.setdefault('connect_args', {})
            connect_args['check_same_thread'] = False

            # An explicit pool size of 0 with an in-memory database would
            # lose data, so refuse it loudly.
            if pool_size == 0:
                raise RuntimeError('SQLite in memory database with an '
                                   'empty queue not possible due to data '
                                   'loss.')
        else:
            # No (or a zero) pool size on a file database: assume no queue
            # is wanted and use NullPool.
            if not pool_size:
                from sqlalchemy.pool import NullPool
                options['poolclass'] = NullPool

            # Resolve the database file path against the application root.
            sa_url = _sa_url_set(
                sa_url, database=os.path.join(app.root_path, sa_url.database)
            )

    configured_unicode = app.config['SQLALCHEMY_NATIVE_UNICODE']
    if configured_unicode is None:
        effective_unicode = self.use_native_unicode
    else:
        effective_unicode = configured_unicode
    if not effective_unicode:
        options['use_native_unicode'] = False

    if configured_unicode is not None:
        warnings.warn(
            "The 'SQLALCHEMY_NATIVE_UNICODE' config option is deprecated and will be removed in"
            " v3.0. Use 'SQLALCHEMY_ENGINE_OPTIONS' instead.",
            DeprecationWarning
        )
    if not self.use_native_unicode:
        warnings.warn(
            "'use_native_unicode' is deprecated and will be removed in v3.0."
            " Use the 'engine_options' parameter instead.",
            DeprecationWarning
        )

    return sa_url, options

989 

@property
def engine(self):
    """The engine of the default bind, as returned by ``get_engine()``.

    When the extension was constructed with an application this is always
    available. Otherwise the current application is looked up, which may
    raise :exc:`RuntimeError` if no application is active.
    """
    default_engine = self.get_engine()
    return default_engine

999 

def make_connector(self, app=None, bind=None):
    """Build a new engine connector for ``bind``.

    ``app`` is resolved through :meth:`get_app`, so it may be omitted when
    an application is bound or active.
    """
    target_app = self.get_app(app)
    return _EngineConnector(self, target_app, bind)

1003 

def get_engine(self, app=None, bind=None):
    """Return the engine for ``bind`` (``None`` is the default bind).

    Connectors are cached per bind in the app's extension state. A missing
    connector is created with :meth:`make_connector` and stored, all while
    holding ``self._engine_lock``.
    """
    target_app = self.get_app(app)
    state = get_state(target_app)

    with self._engine_lock:
        cache = state.connectors
        connector = cache.get(bind)
        if connector is None:
            connector = self.make_connector(target_app, bind)
            cache[bind] = connector
        return connector.get_engine()

1018 

def create_engine(self, sa_url, engine_opts):
    """Build the SQLAlchemy engine from a URL and keyword options.

    This is the last hook before engine creation; override it to take full
    control of how the engine is built. For ordinary tuning, prefer the
    ``'SQLALCHEMY_ENGINE_OPTIONS'`` config variable or the
    ``engine_options`` argument of :func:`SQLAlchemy`.

    :param sa_url: URL handed to :func:`sqlalchemy.create_engine`.
    :param engine_opts: mapping expanded as keyword arguments to that call.
    """
    engine = sqlalchemy.create_engine(sa_url, **engine_opts)
    return engine

1028 

def get_app(self, reference_app=None):
    """Resolve which Flask application to use.

    Preference order: ``reference_app`` if given, then the application of
    the active context (``current_app``), then ``self.app``.

    :raises RuntimeError: if none of those is available.
    """
    if reference_app is not None:
        return reference_app

    if current_app:
        # Unwrap the context-local proxy so callers get the real app object.
        return current_app._get_current_object()

    bound_app = self.app
    if bound_app is None:
        raise RuntimeError(
            'No application found. Either work inside a view function or push'
            ' an application context. See'
            ' http://flask-sqlalchemy.pocoo.org/contexts/.'
        )
    return bound_app

1047 

def get_tables_for_bind(self, bind=None):
    """Return the list of tables whose ``info['bind_key']`` equals ``bind``.

    Tables without a ``bind_key`` entry match ``bind=None``. The order
    follows iteration over ``self.Model.metadata.tables``.
    """
    return [
        table
        for table in itervalues(self.Model.metadata.tables)
        if table.info.get('bind_key') == bind
    ]

1055 

def get_binds(self, app=None):
    """Return a dictionary mapping each table to the engine of its bind.

    The default bind (``None``) is visited first, then every key of
    ``SQLALCHEMY_BINDS``. :meth:`get_engine` is called for each bind even
    when no table uses it.

    This is suitable for use of sessionmaker(binds=db.get_binds(app)).
    """
    app = self.get_app(app)
    bind_keys = [None] + list(app.config.get('SQLALCHEMY_BINDS') or ())
    mapping = {}
    for bind in bind_keys:
        engine = self.get_engine(app, bind)
        mapping.update(
            {table: engine for table in self.get_tables_for_bind(bind)}
        )
    return mapping

1069 

def _execute_for_all_tables(self, app, bind, operation, skip_tables=False):
    """Call a ``self.Model.metadata`` method once per selected bind.

    :param app: application to use; resolved with :meth:`get_app`.
    :param bind: ``'__all__'`` for the default bind plus every key of
        ``SQLALCHEMY_BINDS``; a single bind key (string) or ``None`` for
        the default bind; otherwise an iterable of bind keys.
    :param operation: name of the metadata method to call, e.g.
        ``'create_all'``; it is called with ``bind=<engine>``.
    :param skip_tables: when false, the tables of each bind are also passed
        as ``tables=``; when true (as ``reflect`` does) they are not.
    """
    app = self.get_app(app)

    if bind == '__all__':
        bind_keys = [None] + list(app.config.get('SQLALCHEMY_BINDS') or ())
    elif bind is None or isinstance(bind, string_types):
        bind_keys = [bind]
    else:
        bind_keys = bind

    # The metadata method does not depend on the bind; look it up once.
    op = getattr(self.Model.metadata, operation)
    for bind_key in bind_keys:
        extra = {}
        if not skip_tables:
            extra['tables'] = self.get_tables_for_bind(bind_key)
        op(bind=self.get_engine(app, bind_key), **extra)

1087 

def create_all(self, bind='__all__', app=None):
    """Create all tables for the selected bind(s).

    :param bind: ``'__all__'`` (default) for every configured bind, a single
        bind key or ``None`` for the default bind, or a list of bind keys.
    :param app: application to use; resolved with :meth:`get_app`.

    .. versionchanged:: 0.12
        Parameters were added
    """
    operation = 'create_all'
    self._execute_for_all_tables(app, bind, operation)

1095 

def drop_all(self, bind='__all__', app=None):
    """Drop all tables for the selected bind(s).

    :param bind: ``'__all__'`` (default) for every configured bind, a single
        bind key or ``None`` for the default bind, or a list of bind keys.
    :param app: application to use; resolved with :meth:`get_app`.

    .. versionchanged:: 0.12
        Parameters were added
    """
    operation = 'drop_all'
    self._execute_for_all_tables(app, bind, operation)

1103 

def reflect(self, bind='__all__', app=None):
    """Reflect tables from the database for the selected bind(s).

    Unlike :meth:`create_all`, no per-bind table list is passed to the
    metadata operation (``skip_tables=True``).

    :param bind: ``'__all__'`` (default) for every configured bind, a single
        bind key or ``None`` for the default bind, or a list of bind keys.
    :param app: application to use; resolved with :meth:`get_app`.

    .. versionchanged:: 0.12
        Parameters were added
    """
    operation = 'reflect'
    self._execute_for_all_tables(app, bind, operation, skip_tables=True)

1111 

def __repr__(self):
    # Only touch ``self.engine`` when an application can be resolved (a
    # bound app or an active app context); otherwise report ``None``.
    if self.app or current_app:
        url = self.engine.url
    else:
        url = None
    return '<%s engine=%r>' % (self.__class__.__name__, url)

1117 

1118 

class _BoundDeclarativeMeta(DefaultMeta):
    """Deprecated alias of ``DefaultMeta``.

    Behaves like ``DefaultMeta`` but emits an ``FSADeprecationWarning``
    each time a class is created with it.
    """

    def __init__(cls, name, bases, d):
        notice = FSADeprecationWarning(
            '"_BoundDeclarativeMeta" has been renamed to "DefaultMeta". The'
            ' old name will be removed in 3.0.'
        )
        # stacklevel=3 attributes the warning to a frame above this
        # metaclass; NOTE(review): exact target frame depends on how the
        # class is created — confirm if the reported location matters.
        warnings.warn(notice, stacklevel=3)
        super(_BoundDeclarativeMeta, cls).__init__(name, bases, d)

1126 

1127 

class FSADeprecationWarning(DeprecationWarning):
    """Deprecation warning category raised by Flask-SQLAlchemy itself.

    A dedicated subclass allows a warnings filter to target only this
    extension's deprecations without affecting other
    ``DeprecationWarning``s.
    """

1130 

1131 

# Always show this extension's own deprecation warnings, which the default
# warning filters might otherwise hide.
warnings.simplefilter(action='always', category=FSADeprecationWarning)