1# -*- coding: utf-8 -*-
2from __future__ import absolute_import
3
4import functools
5import os
6import sys
7import time
8import warnings
9from math import ceil
10from operator import itemgetter
11from threading import Lock
12
13import sqlalchemy
14from flask import _app_ctx_stack, abort, current_app, request
15from flask.signals import Namespace
16from sqlalchemy import event, inspect, orm
17from sqlalchemy.engine.url import make_url
18from sqlalchemy.orm.exc import UnmappedClassError
19from sqlalchemy.orm.session import Session as SessionBase
20
21from ._compat import itervalues, string_types, xrange
22from .model import DefaultMeta
23from .model import Model
24from . import utils
25
26try:
27 from sqlalchemy.orm import declarative_base
28 from sqlalchemy.orm import DeclarativeMeta
29except ImportError:
30 # SQLAlchemy <= 1.3
31 from sqlalchemy.ext.declarative import declarative_base
32 from sqlalchemy.ext.declarative import DeclarativeMeta
33
34# Scope the session to the current greenlet if greenlet is available,
35# otherwise fall back to the current thread.
36try:
37 from greenlet import getcurrent as _ident_func
38except ImportError:
39 try:
40 from threading import get_ident as _ident_func
41 except ImportError:
42 # Python 2.7
43 from thread import get_ident as _ident_func
44
45__version__ = "2.5.1"
46
47# the best timer function for the platform
48if sys.platform == 'win32':
49 if sys.version_info >= (3, 3):
50 _timer = time.perf_counter
51 else:
52 _timer = time.clock
53else:
54 _timer = time.time
55
56_signals = Namespace()
57models_committed = _signals.signal('models-committed')
58before_models_committed = _signals.signal('before-models-committed')
59
60
61def _sa_url_set(url, **kwargs):
62 try:
63 url = url.set(**kwargs)
64 except AttributeError:
65 # SQLAlchemy <= 1.3
66 for key, value in kwargs.items():
67 setattr(url, key, value)
68
69 return url
70
71
72def _sa_url_query_setdefault(url, **kwargs):
73 query = dict(url.query)
74
75 for key, value in kwargs.items():
76 query.setdefault(key, value)
77
78 return _sa_url_set(url, query=query)
79
80
81def _make_table(db):
82 def _make_table(*args, **kwargs):
83 if len(args) > 1 and isinstance(args[1], db.Column):
84 args = (args[0], db.metadata) + args[1:]
85 info = kwargs.pop('info', None) or {}
86 info.setdefault('bind_key', None)
87 kwargs['info'] = info
88 return sqlalchemy.Table(*args, **kwargs)
89 return _make_table
90
91
92def _set_default_query_class(d, cls):
93 if 'query_class' not in d:
94 d['query_class'] = cls
95
96
97def _wrap_with_default_query_class(fn, cls):
98 @functools.wraps(fn)
99 def newfn(*args, **kwargs):
100 _set_default_query_class(kwargs, cls)
101 if "backref" in kwargs:
102 backref = kwargs['backref']
103 if isinstance(backref, string_types):
104 backref = (backref, {})
105 _set_default_query_class(backref[1], cls)
106 return fn(*args, **kwargs)
107 return newfn
108
109
110def _include_sqlalchemy(obj, cls):
111 for module in sqlalchemy, sqlalchemy.orm:
112 for key in module.__all__:
113 if not hasattr(obj, key):
114 setattr(obj, key, getattr(module, key))
115 # Note: obj.Table does not attempt to be a SQLAlchemy Table class.
116 obj.Table = _make_table(obj)
117 obj.relationship = _wrap_with_default_query_class(obj.relationship, cls)
118 obj.relation = _wrap_with_default_query_class(obj.relation, cls)
119 obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader, cls)
120 obj.event = event
121
122
123class _DebugQueryTuple(tuple):
124 statement = property(itemgetter(0))
125 parameters = property(itemgetter(1))
126 start_time = property(itemgetter(2))
127 end_time = property(itemgetter(3))
128 context = property(itemgetter(4))
129
130 @property
131 def duration(self):
132 return self.end_time - self.start_time
133
134 def __repr__(self):
135 return '<query statement="%s" parameters=%r duration=%.03f>' % (
136 self.statement,
137 self.parameters,
138 self.duration
139 )
140
141
142def _calling_context(app_path):
143 frm = sys._getframe(1)
144 while frm.f_back is not None:
145 name = frm.f_globals.get('__name__')
146 if name and (name == app_path or name.startswith(app_path + '.')):
147 funcname = frm.f_code.co_name
148 return '%s:%s (%s)' % (
149 frm.f_code.co_filename,
150 frm.f_lineno,
151 funcname
152 )
153 frm = frm.f_back
154 return '<unknown>'
155
156
157class SignallingSession(SessionBase):
158 """The signalling session is the default session that Flask-SQLAlchemy
159 uses. It extends the default session system with bind selection and
160 modification tracking.
161
162 If you want to use a different session you can override the
163 :meth:`SQLAlchemy.create_session` function.
164
165 .. versionadded:: 2.0
166
167 .. versionadded:: 2.1
168 The `binds` option was added, which allows a session to be joined
169 to an external transaction.
170 """
171
172 def __init__(self, db, autocommit=False, autoflush=True, **options):
173 #: The application that this session belongs to.
174 self.app = app = db.get_app()
175 track_modifications = app.config['SQLALCHEMY_TRACK_MODIFICATIONS']
176 bind = options.pop('bind', None) or db.engine
177 binds = options.pop('binds', db.get_binds(app))
178
179 if track_modifications is None or track_modifications:
180 _SessionSignalEvents.register(self)
181
182 SessionBase.__init__(
183 self, autocommit=autocommit, autoflush=autoflush,
184 bind=bind, binds=binds, **options
185 )
186
187 def get_bind(self, mapper=None, clause=None):
188 """Return the engine or connection for a given model or
189 table, using the ``__bind_key__`` if it is set.
190 """
191 # mapper is None if someone tries to just get a connection
192 if mapper is not None:
193 try:
194 # SA >= 1.3
195 persist_selectable = mapper.persist_selectable
196 except AttributeError:
197 # SA < 1.3
198 persist_selectable = mapper.mapped_table
199
200 info = getattr(persist_selectable, 'info', {})
201 bind_key = info.get('bind_key')
202 if bind_key is not None:
203 state = get_state(self.app)
204 return state.db.get_engine(self.app, bind=bind_key)
205 return SessionBase.get_bind(self, mapper, clause)
206
207
208class _SessionSignalEvents(object):
209 @classmethod
210 def register(cls, session):
211 if not hasattr(session, '_model_changes'):
212 session._model_changes = {}
213
214 event.listen(session, 'before_flush', cls.record_ops)
215 event.listen(session, 'before_commit', cls.record_ops)
216 event.listen(session, 'before_commit', cls.before_commit)
217 event.listen(session, 'after_commit', cls.after_commit)
218 event.listen(session, 'after_rollback', cls.after_rollback)
219
220 @classmethod
221 def unregister(cls, session):
222 if hasattr(session, '_model_changes'):
223 del session._model_changes
224
225 event.remove(session, 'before_flush', cls.record_ops)
226 event.remove(session, 'before_commit', cls.record_ops)
227 event.remove(session, 'before_commit', cls.before_commit)
228 event.remove(session, 'after_commit', cls.after_commit)
229 event.remove(session, 'after_rollback', cls.after_rollback)
230
231 @staticmethod
232 def record_ops(session, flush_context=None, instances=None):
233 try:
234 d = session._model_changes
235 except AttributeError:
236 return
237
238 for targets, operation in ((session.new, 'insert'), (session.dirty, 'update'), (session.deleted, 'delete')):
239 for target in targets:
240 state = inspect(target)
241 key = state.identity_key if state.has_identity else id(target)
242 d[key] = (target, operation)
243
244 @staticmethod
245 def before_commit(session):
246 try:
247 d = session._model_changes
248 except AttributeError:
249 return
250
251 if d:
252 before_models_committed.send(session.app, changes=list(d.values()))
253
254 @staticmethod
255 def after_commit(session):
256 try:
257 d = session._model_changes
258 except AttributeError:
259 return
260
261 if d:
262 models_committed.send(session.app, changes=list(d.values()))
263 d.clear()
264
265 @staticmethod
266 def after_rollback(session):
267 try:
268 d = session._model_changes
269 except AttributeError:
270 return
271
272 d.clear()
273
274
275class _EngineDebuggingSignalEvents(object):
276 """Sets up handlers for two events that let us track the execution time of
277 queries."""
278
279 def __init__(self, engine, import_name):
280 self.engine = engine
281 self.app_package = import_name
282
283 def register(self):
284 event.listen(
285 self.engine, 'before_cursor_execute', self.before_cursor_execute
286 )
287 event.listen(
288 self.engine, 'after_cursor_execute', self.after_cursor_execute
289 )
290
291 def before_cursor_execute(
292 self, conn, cursor, statement, parameters, context, executemany
293 ):
294 if current_app:
295 context._query_start_time = _timer()
296
297 def after_cursor_execute(
298 self, conn, cursor, statement, parameters, context, executemany
299 ):
300 if current_app:
301 try:
302 queries = _app_ctx_stack.top.sqlalchemy_queries
303 except AttributeError:
304 queries = _app_ctx_stack.top.sqlalchemy_queries = []
305
306 queries.append(_DebugQueryTuple((
307 statement, parameters, context._query_start_time, _timer(),
308 _calling_context(self.app_package)
309 )))
310
311
312def get_debug_queries():
313 """In debug mode Flask-SQLAlchemy will log all the SQL queries sent
314 to the database. This information is available until the end of request
315 which makes it possible to easily ensure that the SQL generated is the
316 one expected on errors or in unittesting. If you don't want to enable
317 the DEBUG mode for your unittests you can also enable the query
318 recording by setting the ``'SQLALCHEMY_RECORD_QUERIES'`` config variable
319 to `True`. This is automatically enabled if Flask is in testing mode.
320
321 The value returned will be a list of named tuples with the following
322 attributes:
323
324 `statement`
325 The SQL statement issued
326
327 `parameters`
328 The parameters for the SQL statement
329
330 `start_time` / `end_time`
331 Time the query started / the results arrived. Please keep in mind
332 that the timer function used depends on your platform. These
333 values are only useful for sorting or comparing. They do not
334 necessarily represent an absolute timestamp.
335
336 `duration`
337 Time the query took in seconds
338
339 `context`
340 A string giving a rough estimation of where in your application
341 query was issued. The exact format is undefined so don't try
342 to reconstruct filename or function name.
343 """
344 return getattr(_app_ctx_stack.top, 'sqlalchemy_queries', [])
345
346
347class Pagination(object):
348 """Internal helper class returned by :meth:`BaseQuery.paginate`. You
349 can also construct it from any other SQLAlchemy query object if you are
350 working with other libraries. Additionally it is possible to pass `None`
351 as query object in which case the :meth:`prev` and :meth:`next` will
352 no longer work.
353 """
354
355 def __init__(self, query, page, per_page, total, items):
356 #: the unlimited query object that was used to create this
357 #: pagination object.
358 self.query = query
359 #: the current page number (1 indexed)
360 self.page = page
361 #: the number of items to be displayed on a page.
362 self.per_page = per_page
363 #: the total number of items matching the query
364 self.total = total
365 #: the items for the current page
366 self.items = items
367
368 @property
369 def pages(self):
370 """The total number of pages"""
371 if self.per_page == 0:
372 pages = 0
373 else:
374 pages = int(ceil(self.total / float(self.per_page)))
375 return pages
376
377 def prev(self, error_out=False):
378 """Returns a :class:`Pagination` object for the previous page."""
379 assert self.query is not None, 'a query object is required ' \
380 'for this method to work'
381 return self.query.paginate(self.page - 1, self.per_page, error_out)
382
383 @property
384 def prev_num(self):
385 """Number of the previous page."""
386 if not self.has_prev:
387 return None
388 return self.page - 1
389
390 @property
391 def has_prev(self):
392 """True if a previous page exists"""
393 return self.page > 1
394
395 def next(self, error_out=False):
396 """Returns a :class:`Pagination` object for the next page."""
397 assert self.query is not None, 'a query object is required ' \
398 'for this method to work'
399 return self.query.paginate(self.page + 1, self.per_page, error_out)
400
401 @property
402 def has_next(self):
403 """True if a next page exists."""
404 return self.page < self.pages
405
406 @property
407 def next_num(self):
408 """Number of the next page"""
409 if not self.has_next:
410 return None
411 return self.page + 1
412
413 def iter_pages(self, left_edge=2, left_current=2,
414 right_current=5, right_edge=2):
415 """Iterates over the page numbers in the pagination. The four
416 parameters control the thresholds how many numbers should be produced
417 from the sides. Skipped page numbers are represented as `None`.
418 This is how you could render such a pagination in the templates:
419
420 .. sourcecode:: html+jinja
421
422 {% macro render_pagination(pagination, endpoint) %}
423 <div class=pagination>
424 {%- for page in pagination.iter_pages() %}
425 {% if page %}
426 {% if page != pagination.page %}
427 <a href="{{ url_for(endpoint, page=page) }}">{{ page }}</a>
428 {% else %}
429 <strong>{{ page }}</strong>
430 {% endif %}
431 {% else %}
432 <span class=ellipsis>…</span>
433 {% endif %}
434 {%- endfor %}
435 </div>
436 {% endmacro %}
437 """
438 last = 0
439 for num in xrange(1, self.pages + 1):
440 if num <= left_edge or \
441 (num > self.page - left_current - 1 and
442 num < self.page + right_current) or \
443 num > self.pages - right_edge:
444 if last + 1 != num:
445 yield None
446 yield num
447 last = num
448
449
450class BaseQuery(orm.Query):
451 """SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with convenience methods for querying in a web application.
452
453 This is the default :attr:`~Model.query` object used for models, and exposed as :attr:`~SQLAlchemy.Query`.
454 Override the query class for an individual model by subclassing this and setting :attr:`~Model.query_class`.
455 """
456
457 def get_or_404(self, ident, description=None):
458 """Like :meth:`get` but aborts with 404 if not found instead of returning ``None``."""
459
460 rv = self.get(ident)
461 if rv is None:
462 abort(404, description=description)
463 return rv
464
465 def first_or_404(self, description=None):
466 """Like :meth:`first` but aborts with 404 if not found instead of returning ``None``."""
467
468 rv = self.first()
469 if rv is None:
470 abort(404, description=description)
471 return rv
472
473 def paginate(self, page=None, per_page=None, error_out=True, max_per_page=None):
474 """Returns ``per_page`` items from page ``page``.
475
476 If ``page`` or ``per_page`` are ``None``, they will be retrieved from
477 the request query. If ``max_per_page`` is specified, ``per_page`` will
478 be limited to that value. If there is no request or they aren't in the
479 query, they default to 1 and 20 respectively.
480
481 When ``error_out`` is ``True`` (default), the following rules will
482 cause a 404 response:
483
484 * No items are found and ``page`` is not 1.
485 * ``page`` is less than 1, or ``per_page`` is negative.
486 * ``page`` or ``per_page`` are not ints.
487
488 When ``error_out`` is ``False``, ``page`` and ``per_page`` default to
489 1 and 20 respectively.
490
491 Returns a :class:`Pagination` object.
492 """
493
494 if request:
495 if page is None:
496 try:
497 page = int(request.args.get('page', 1))
498 except (TypeError, ValueError):
499 if error_out:
500 abort(404)
501
502 page = 1
503
504 if per_page is None:
505 try:
506 per_page = int(request.args.get('per_page', 20))
507 except (TypeError, ValueError):
508 if error_out:
509 abort(404)
510
511 per_page = 20
512 else:
513 if page is None:
514 page = 1
515
516 if per_page is None:
517 per_page = 20
518
519 if max_per_page is not None:
520 per_page = min(per_page, max_per_page)
521
522 if page < 1:
523 if error_out:
524 abort(404)
525 else:
526 page = 1
527
528 if per_page < 0:
529 if error_out:
530 abort(404)
531 else:
532 per_page = 20
533
534 items = self.limit(per_page).offset((page - 1) * per_page).all()
535
536 if not items and page != 1 and error_out:
537 abort(404)
538
539 total = self.order_by(None).count()
540
541 return Pagination(self, page, per_page, total, items)
542
543
544class _QueryProperty(object):
545 def __init__(self, sa):
546 self.sa = sa
547
548 def __get__(self, obj, type):
549 try:
550 mapper = orm.class_mapper(type)
551 if mapper:
552 return type.query_class(mapper, session=self.sa.session())
553 except UnmappedClassError:
554 return None
555
556
557def _record_queries(app):
558 if app.debug:
559 return True
560 rq = app.config['SQLALCHEMY_RECORD_QUERIES']
561 if rq is not None:
562 return rq
563 return bool(app.config.get('TESTING'))
564
565
566class _EngineConnector(object):
567
568 def __init__(self, sa, app, bind=None):
569 self._sa = sa
570 self._app = app
571 self._engine = None
572 self._connected_for = None
573 self._bind = bind
574 self._lock = Lock()
575
576 def get_uri(self):
577 if self._bind is None:
578 return self._app.config['SQLALCHEMY_DATABASE_URI']
579 binds = self._app.config.get('SQLALCHEMY_BINDS') or ()
580 assert self._bind in binds, \
581 'Bind %r is not specified. Set it in the SQLALCHEMY_BINDS ' \
582 'configuration variable' % self._bind
583 return binds[self._bind]
584
585 def get_engine(self):
586 with self._lock:
587 uri = self.get_uri()
588 echo = self._app.config['SQLALCHEMY_ECHO']
589 if (uri, echo) == self._connected_for:
590 return self._engine
591
592 sa_url = make_url(uri)
593 sa_url, options = self.get_options(sa_url, echo)
594 self._engine = rv = self._sa.create_engine(sa_url, options)
595
596 if _record_queries(self._app):
597 _EngineDebuggingSignalEvents(self._engine,
598 self._app.import_name).register()
599
600 self._connected_for = (uri, echo)
601
602 return rv
603
604 def get_options(self, sa_url, echo):
605 options = {}
606
607 options = self._sa.apply_pool_defaults(self._app, options)
608 sa_url, options = self._sa.apply_driver_hacks(self._app, sa_url, options)
609
610 if echo:
611 options['echo'] = echo
612
613 # Give the config options set by a developer explicitly priority
614 # over decisions FSA makes.
615 options.update(self._app.config['SQLALCHEMY_ENGINE_OPTIONS'])
616
617 # Give options set in SQLAlchemy.__init__() ultimate priority
618 options.update(self._sa._engine_options)
619
620 return sa_url, options
621
622
623def get_state(app):
624 """Gets the state for the application"""
625 assert 'sqlalchemy' in app.extensions, \
626 'The sqlalchemy extension was not registered to the current ' \
627 'application. Please make sure to call init_app() first.'
628 return app.extensions['sqlalchemy']
629
630
631class _SQLAlchemyState(object):
632 """Remembers configuration for the (db, app) tuple."""
633
634 def __init__(self, db):
635 self.db = db
636 self.connectors = {}
637
638
639class SQLAlchemy(object):
640 """This class is used to control the SQLAlchemy integration to one
641 or more Flask applications. Depending on how you initialize the
642 object it is usable right away or will attach as needed to a
643 Flask application.
644
645 There are two usage modes which work very similarly. One is binding
646 the instance to a very specific Flask application::
647
648 app = Flask(__name__)
649 db = SQLAlchemy(app)
650
651 The second possibility is to create the object once and configure the
652 application later to support it::
653
654 db = SQLAlchemy()
655
656 def create_app():
657 app = Flask(__name__)
658 db.init_app(app)
659 return app
660
661 The difference between the two is that in the first case methods like
662 :meth:`create_all` and :meth:`drop_all` will work all the time but in
663 the second case a :meth:`flask.Flask.app_context` has to exist.
664
665 By default Flask-SQLAlchemy will apply some backend-specific settings
666 to improve your experience with them.
667
668 As of SQLAlchemy 0.6 SQLAlchemy
669 will probe the library for native unicode support. If it detects
670 unicode it will let the library handle that, otherwise do that itself.
671 Sometimes this detection can fail in which case you might want to set
672 ``use_native_unicode`` (or the ``SQLALCHEMY_NATIVE_UNICODE`` configuration
673 key) to ``False``. Note that the configuration key overrides the
674 value you pass to the constructor. Direct support for ``use_native_unicode``
675 and SQLALCHEMY_NATIVE_UNICODE are deprecated as of v2.4 and will be removed
676 in v3.0. ``engine_options`` and ``SQLALCHEMY_ENGINE_OPTIONS`` may be used
677 instead.
678
679 This class also provides access to all the SQLAlchemy functions and classes
680 from the :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` modules. So you can
681 declare models like this::
682
683 class User(db.Model):
684 username = db.Column(db.String(80), unique=True)
685 pw_hash = db.Column(db.String(80))
686
687 You can still use :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly, but
688 note that Flask-SQLAlchemy customizations are available only through an
689 instance of this :class:`SQLAlchemy` class. Query classes default to
690 :class:`BaseQuery` for `db.Query`, `db.Model.query_class`, and the default
691 query_class for `db.relationship` and `db.backref`. If you use these
692 interfaces through :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly,
693 the default query class will be that of :mod:`sqlalchemy`.
694
695 .. admonition:: Check types carefully
696
697 Don't perform type or `isinstance` checks against `db.Table`, which
698 emulates `Table` behavior but is not a class. `db.Table` exposes the
699 `Table` interface, but is a function which allows omission of metadata.
700
701 The ``session_options`` parameter, if provided, is a dict of parameters
702 to be passed to the session constructor. See :class:`~sqlalchemy.orm.session.Session`
703 for the standard options.
704
705 The ``engine_options`` parameter, if provided, is a dict of parameters
706 to be passed to create engine. See :func:`~sqlalchemy.create_engine`
707 for the standard options. The values given here will be merged with and
708 override anything set in the ``'SQLALCHEMY_ENGINE_OPTIONS'`` config
709 variable or othewise set by this library.
710
711 .. versionadded:: 0.10
712 The `session_options` parameter was added.
713
714 .. versionadded:: 0.16
715 `scopefunc` is now accepted on `session_options`. It allows specifying
716 a custom function which will define the SQLAlchemy session's scoping.
717
718 .. versionadded:: 2.1
719 The `metadata` parameter was added. This allows for setting custom
720 naming conventions among other, non-trivial things.
721
722 The `query_class` parameter was added, to allow customisation
723 of the query class, in place of the default of :class:`BaseQuery`.
724
725 The `model_class` parameter was added, which allows a custom model
726 class to be used in place of :class:`Model`.
727
728 .. versionchanged:: 2.1
729 Utilise the same query class across `session`, `Model.query` and `Query`.
730
731 .. versionadded:: 2.4
732 The `engine_options` parameter was added.
733
734 .. versionchanged:: 2.4
735 The `use_native_unicode` parameter was deprecated.
736
737 .. versionchanged:: 2.4.3
738 ``COMMIT_ON_TEARDOWN`` is deprecated and will be removed in
739 version 3.1. Call ``db.session.commit()`` directly instead.
740 """
741
742 #: Default query class used by :attr:`Model.query` and other queries.
743 #: Customize this by passing ``query_class`` to :func:`SQLAlchemy`.
744 #: Defaults to :class:`BaseQuery`.
745 Query = None
746
747 def __init__(self, app=None, use_native_unicode=True, session_options=None,
748 metadata=None, query_class=BaseQuery, model_class=Model,
749 engine_options=None):
750
751 self.use_native_unicode = use_native_unicode
752 self.Query = query_class
753 self.session = self.create_scoped_session(session_options)
754 self.Model = self.make_declarative_base(model_class, metadata)
755 self._engine_lock = Lock()
756 self.app = app
757 self._engine_options = engine_options or {}
758 _include_sqlalchemy(self, query_class)
759
760 if app is not None:
761 self.init_app(app)
762
763 @property
764 def metadata(self):
765 """The metadata associated with ``db.Model``."""
766
767 return self.Model.metadata
768
769 def create_scoped_session(self, options=None):
770 """Create a :class:`~sqlalchemy.orm.scoping.scoped_session`
771 on the factory from :meth:`create_session`.
772
773 An extra key ``'scopefunc'`` can be set on the ``options`` dict to
774 specify a custom scope function. If it's not provided, Flask's app
775 context stack identity is used. This will ensure that sessions are
776 created and removed with the request/response cycle, and should be fine
777 in most cases.
778
779 :param options: dict of keyword arguments passed to session class in
780 ``create_session``
781 """
782
783 if options is None:
784 options = {}
785
786 scopefunc = options.pop('scopefunc', _ident_func)
787 options.setdefault('query_cls', self.Query)
788 return orm.scoped_session(
789 self.create_session(options), scopefunc=scopefunc
790 )
791
792 def create_session(self, options):
793 """Create the session factory used by :meth:`create_scoped_session`.
794
795 The factory **must** return an object that SQLAlchemy recognizes as a session,
796 or registering session events may raise an exception.
797
798 Valid factories include a :class:`~sqlalchemy.orm.session.Session`
799 class or a :class:`~sqlalchemy.orm.session.sessionmaker`.
800
801 The default implementation creates a ``sessionmaker`` for :class:`SignallingSession`.
802
803 :param options: dict of keyword arguments passed to session class
804 """
805
806 return orm.sessionmaker(class_=SignallingSession, db=self, **options)
807
808 def make_declarative_base(self, model, metadata=None):
809 """Creates the declarative base that all models will inherit from.
810
811 :param model: base model class (or a tuple of base classes) to pass
812 to :func:`~sqlalchemy.ext.declarative.declarative_base`. Or a class
813 returned from ``declarative_base``, in which case a new base class
814 is not created.
815 :param metadata: :class:`~sqlalchemy.MetaData` instance to use, or
816 none to use SQLAlchemy's default.
817
818 .. versionchanged 2.3.0::
819 ``model`` can be an existing declarative base in order to support
820 complex customization such as changing the metaclass.
821 """
822 if not isinstance(model, DeclarativeMeta):
823 model = declarative_base(
824 cls=model,
825 name='Model',
826 metadata=metadata,
827 metaclass=DefaultMeta
828 )
829
830 # if user passed in a declarative base and a metaclass for some reason,
831 # make sure the base uses the metaclass
832 if metadata is not None and model.metadata is not metadata:
833 model.metadata = metadata
834
835 if not getattr(model, 'query_class', None):
836 model.query_class = self.Query
837
838 model.query = _QueryProperty(self)
839 return model
840
841 def init_app(self, app):
842 """This callback can be used to initialize an application for the
843 use with this database setup. Never use a database in the context
844 of an application not initialized that way or connections will
845 leak.
846 """
847 if (
848 'SQLALCHEMY_DATABASE_URI' not in app.config and
849 'SQLALCHEMY_BINDS' not in app.config
850 ):
851 warnings.warn(
852 'Neither SQLALCHEMY_DATABASE_URI nor SQLALCHEMY_BINDS is set. '
853 'Defaulting SQLALCHEMY_DATABASE_URI to "sqlite:///:memory:".'
854 )
855
856 app.config.setdefault('SQLALCHEMY_DATABASE_URI', 'sqlite:///:memory:')
857 app.config.setdefault('SQLALCHEMY_BINDS', None)
858 app.config.setdefault('SQLALCHEMY_NATIVE_UNICODE', None)
859 app.config.setdefault('SQLALCHEMY_ECHO', False)
860 app.config.setdefault('SQLALCHEMY_RECORD_QUERIES', None)
861 app.config.setdefault('SQLALCHEMY_POOL_SIZE', None)
862 app.config.setdefault('SQLALCHEMY_POOL_TIMEOUT', None)
863 app.config.setdefault('SQLALCHEMY_POOL_RECYCLE', None)
864 app.config.setdefault('SQLALCHEMY_MAX_OVERFLOW', None)
865 app.config.setdefault('SQLALCHEMY_COMMIT_ON_TEARDOWN', False)
866 track_modifications = app.config.setdefault(
867 'SQLALCHEMY_TRACK_MODIFICATIONS', None
868 )
869 app.config.setdefault('SQLALCHEMY_ENGINE_OPTIONS', {})
870
871 if track_modifications is None:
872 warnings.warn(FSADeprecationWarning(
873 'SQLALCHEMY_TRACK_MODIFICATIONS adds significant overhead and '
874 'will be disabled by default in the future. Set it to True '
875 'or False to suppress this warning.'
876 ))
877
878 # Deprecation warnings for config keys that should be replaced by SQLALCHEMY_ENGINE_OPTIONS.
879 utils.engine_config_warning(app.config, '3.0', 'SQLALCHEMY_POOL_SIZE', 'pool_size')
880 utils.engine_config_warning(app.config, '3.0', 'SQLALCHEMY_POOL_TIMEOUT', 'pool_timeout')
881 utils.engine_config_warning(app.config, '3.0', 'SQLALCHEMY_POOL_RECYCLE', 'pool_recycle')
882 utils.engine_config_warning(app.config, '3.0', 'SQLALCHEMY_MAX_OVERFLOW', 'max_overflow')
883
884 app.extensions['sqlalchemy'] = _SQLAlchemyState(self)
885
886 @app.teardown_appcontext
887 def shutdown_session(response_or_exc):
888 if app.config['SQLALCHEMY_COMMIT_ON_TEARDOWN']:
889 warnings.warn(
890 "'COMMIT_ON_TEARDOWN' is deprecated and will be"
891 " removed in version 3.1. Call"
892 " 'db.session.commit()'` directly instead.",
893 DeprecationWarning,
894 )
895
896 if response_or_exc is None:
897 self.session.commit()
898
899 self.session.remove()
900 return response_or_exc
901
902 def apply_pool_defaults(self, app, options):
903 """
904 .. versionchanged:: 2.5
905 Returns the ``options`` dict, for consistency with
906 :meth:`apply_driver_hacks`.
907 """
908 def _setdefault(optionkey, configkey):
909 value = app.config[configkey]
910 if value is not None:
911 options[optionkey] = value
912 _setdefault('pool_size', 'SQLALCHEMY_POOL_SIZE')
913 _setdefault('pool_timeout', 'SQLALCHEMY_POOL_TIMEOUT')
914 _setdefault('pool_recycle', 'SQLALCHEMY_POOL_RECYCLE')
915 _setdefault('max_overflow', 'SQLALCHEMY_MAX_OVERFLOW')
916 return options
917
918 def apply_driver_hacks(self, app, sa_url, options):
919 """This method is called before engine creation and used to inject
920 driver specific hacks into the options. The `options` parameter is
921 a dictionary of keyword arguments that will then be used to call
922 the :func:`sqlalchemy.create_engine` function.
923
924 The default implementation provides some saner defaults for things
925 like pool sizes for MySQL and sqlite. Also it injects the setting of
926 `SQLALCHEMY_NATIVE_UNICODE`.
927
928 .. versionchanged:: 2.5
929 Returns ``(sa_url, options)``. SQLAlchemy 1.4 made the URL
930 immutable, so any changes to it must now be passed back up
931 to the original caller.
932 """
933 if sa_url.drivername.startswith('mysql'):
934 sa_url = _sa_url_query_setdefault(sa_url, charset="utf8")
935
936 if sa_url.drivername != 'mysql+gaerdbms':
937 options.setdefault('pool_size', 10)
938 options.setdefault('pool_recycle', 7200)
939 elif sa_url.drivername == 'sqlite':
940 pool_size = options.get('pool_size')
941 detected_in_memory = False
942 if sa_url.database in (None, '', ':memory:'):
943 detected_in_memory = True
944 from sqlalchemy.pool import StaticPool
945 options['poolclass'] = StaticPool
946 if 'connect_args' not in options:
947 options['connect_args'] = {}
948 options['connect_args']['check_same_thread'] = False
949
950 # we go to memory and the pool size was explicitly set
951 # to 0 which is fail. Let the user know that
952 if pool_size == 0:
953 raise RuntimeError('SQLite in memory database with an '
954 'empty queue not possible due to data '
955 'loss.')
956 # if pool size is None or explicitly set to 0 we assume the
957 # user did not want a queue for this sqlite connection and
958 # hook in the null pool.
959 elif not pool_size:
960 from sqlalchemy.pool import NullPool
961 options['poolclass'] = NullPool
962
963 # if it's not an in memory database we make the path absolute.
964 if not detected_in_memory:
965 sa_url = _sa_url_set(
966 sa_url, database=os.path.join(app.root_path, sa_url.database)
967 )
968
969 unu = app.config['SQLALCHEMY_NATIVE_UNICODE']
970 if unu is None:
971 unu = self.use_native_unicode
972 if not unu:
973 options['use_native_unicode'] = False
974
975 if app.config['SQLALCHEMY_NATIVE_UNICODE'] is not None:
976 warnings.warn(
977 "The 'SQLALCHEMY_NATIVE_UNICODE' config option is deprecated and will be removed in"
978 " v3.0. Use 'SQLALCHEMY_ENGINE_OPTIONS' instead.",
979 DeprecationWarning
980 )
981 if not self.use_native_unicode:
982 warnings.warn(
983 "'use_native_unicode' is deprecated and will be removed in v3.0."
984 " Use the 'engine_options' parameter instead.",
985 DeprecationWarning
986 )
987
988 return sa_url, options
989
990 @property
991 def engine(self):
992 """Gives access to the engine. If the database configuration is bound
993 to a specific application (initialized with an application) this will
994 always return a database connection. If however the current application
995 is used this might raise a :exc:`RuntimeError` if no application is
996 active at the moment.
997 """
998 return self.get_engine()
999
1000 def make_connector(self, app=None, bind=None):
1001 """Creates the connector for a given state and bind."""
1002 return _EngineConnector(self, self.get_app(app), bind)
1003
1004 def get_engine(self, app=None, bind=None):
1005 """Returns a specific engine."""
1006
1007 app = self.get_app(app)
1008 state = get_state(app)
1009
1010 with self._engine_lock:
1011 connector = state.connectors.get(bind)
1012
1013 if connector is None:
1014 connector = self.make_connector(app, bind)
1015 state.connectors[bind] = connector
1016
1017 return connector.get_engine()
1018
1019 def create_engine(self, sa_url, engine_opts):
1020 """
1021 Override this method to have final say over how the SQLAlchemy engine
1022 is created.
1023
1024 In most cases, you will want to use ``'SQLALCHEMY_ENGINE_OPTIONS'``
1025 config variable or set ``engine_options`` for :func:`SQLAlchemy`.
1026 """
1027 return sqlalchemy.create_engine(sa_url, **engine_opts)
1028
1029 def get_app(self, reference_app=None):
1030 """Helper method that implements the logic to look up an
1031 application."""
1032
1033 if reference_app is not None:
1034 return reference_app
1035
1036 if current_app:
1037 return current_app._get_current_object()
1038
1039 if self.app is not None:
1040 return self.app
1041
1042 raise RuntimeError(
1043 'No application found. Either work inside a view function or push'
1044 ' an application context. See'
1045 ' http://flask-sqlalchemy.pocoo.org/contexts/.'
1046 )
1047
1048 def get_tables_for_bind(self, bind=None):
1049 """Returns a list of all tables relevant for a bind."""
1050 result = []
1051 for table in itervalues(self.Model.metadata.tables):
1052 if table.info.get('bind_key') == bind:
1053 result.append(table)
1054 return result
1055
1056 def get_binds(self, app=None):
1057 """Returns a dictionary with a table->engine mapping.
1058
1059 This is suitable for use of sessionmaker(binds=db.get_binds(app)).
1060 """
1061 app = self.get_app(app)
1062 binds = [None] + list(app.config.get('SQLALCHEMY_BINDS') or ())
1063 retval = {}
1064 for bind in binds:
1065 engine = self.get_engine(app, bind)
1066 tables = self.get_tables_for_bind(bind)
1067 retval.update(dict((table, engine) for table in tables))
1068 return retval
1069
1070 def _execute_for_all_tables(self, app, bind, operation, skip_tables=False):
1071 app = self.get_app(app)
1072
1073 if bind == '__all__':
1074 binds = [None] + list(app.config.get('SQLALCHEMY_BINDS') or ())
1075 elif isinstance(bind, string_types) or bind is None:
1076 binds = [bind]
1077 else:
1078 binds = bind
1079
1080 for bind in binds:
1081 extra = {}
1082 if not skip_tables:
1083 tables = self.get_tables_for_bind(bind)
1084 extra['tables'] = tables
1085 op = getattr(self.Model.metadata, operation)
1086 op(bind=self.get_engine(app, bind), **extra)
1087
1088 def create_all(self, bind='__all__', app=None):
1089 """Creates all tables.
1090
1091 .. versionchanged:: 0.12
1092 Parameters were added
1093 """
1094 self._execute_for_all_tables(app, bind, 'create_all')
1095
1096 def drop_all(self, bind='__all__', app=None):
1097 """Drops all tables.
1098
1099 .. versionchanged:: 0.12
1100 Parameters were added
1101 """
1102 self._execute_for_all_tables(app, bind, 'drop_all')
1103
1104 def reflect(self, bind='__all__', app=None):
1105 """Reflects tables from the database.
1106
1107 .. versionchanged:: 0.12
1108 Parameters were added
1109 """
1110 self._execute_for_all_tables(app, bind, 'reflect', skip_tables=True)
1111
1112 def __repr__(self):
1113 return '<%s engine=%r>' % (
1114 self.__class__.__name__,
1115 self.engine.url if self.app or current_app else None
1116 )
1117
1118
1119class _BoundDeclarativeMeta(DefaultMeta):
1120 def __init__(cls, name, bases, d):
1121 warnings.warn(FSADeprecationWarning(
1122 '"_BoundDeclarativeMeta" has been renamed to "DefaultMeta". The'
1123 ' old name will be removed in 3.0.'
1124 ), stacklevel=3)
1125 super(_BoundDeclarativeMeta, cls).__init__(name, bases, d)
1126
1127
1128class FSADeprecationWarning(DeprecationWarning):
1129 pass
1130
1131
1132warnings.simplefilter('always', FSADeprecationWarning)