1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import logging
21import os
22import time
23import traceback
24
25from sqlalchemy import event, exc
26
27from airflow.configuration import conf
28from airflow.utils.sqlalchemy import get_orm_mapper
29
30log = logging.getLogger(__name__)
31
32
def setup_event_handlers(engine):
    """
    Set up SQLAlchemy event handlers on ``engine``.

    Registers listeners that:

    * import all ORM models (``import_all_models``) before the ORM mapper is first configured;
    * record the id of the process that opened each new DBAPI connection, and refuse
      (via ``DisconnectionError``) to check a pooled connection out in a different process;
    * on SQLite, enable foreign-key enforcement and WAL journaling on every new connection;
    * on MySQL, set the session time zone to ``+00:00`` on every new connection;
    * when ``[debug] sqlalchemy_stats`` is enabled, log the duration, the issuing caller,
      a short call chain and the statement text of every executed query.

    :param engine: the SQLAlchemy engine to attach the listeners to.
    """
    from airflow.models import import_all_models

    event.listen(get_orm_mapper(), "before_configured", import_all_models, once=True)

    @event.listens_for(engine, "connect")
    def connect(dbapi_connection, connection_record):
        # Remember which process opened this connection; ``checkout`` compares against it.
        connection_record.info["pid"] = os.getpid()

    if engine.dialect.name == "sqlite":

        @event.listens_for(engine, "connect")
        def set_sqlite_pragma(dbapi_connection, connection_record):
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.execute("PRAGMA journal_mode=WAL")
            finally:
                # Always release the cursor, even if a PRAGMA fails.
                cursor.close()

    # Pin the MySQL session time zone to UTC so stored datetimes are coherent
    # (not required for postgres).
    if engine.dialect.name == "mysql":

        @event.listens_for(engine, "connect")
        def set_mysql_timezone(dbapi_connection, connection_record):
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute("SET time_zone = '+00:00'")
            finally:
                cursor.close()

    @event.listens_for(engine, "checkout")
    def checkout(dbapi_connection, connection_record, connection_proxy):
        pid = os.getpid()
        if connection_record.info["pid"] != pid:
            # The pooled connection was opened by another process (e.g. inherited across a
            # fork). Drop our references to it so this process neither reuses nor closes it,
            # then raise DisconnectionError, which per SQLAlchemy's checkout-event contract
            # makes the pool dispose of this connection and retrieve a fresh one.
            connection_record.connection = connection_proxy.connection = None
            raise exc.DisconnectionError(
                f"Connection record belongs to pid {connection_record.info['pid']}, "
                f"attempting to check out in pid {pid}"
            )

    if conf.getboolean("debug", "sqlalchemy_stats", fallback=False):

        @event.listens_for(engine, "before_cursor_execute")
        def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            # Start times are kept as a LIFO stack so each "after" event pops the start
            # time pushed by its matching "before" event.
            conn.info.setdefault("query_start_time", []).append(time.perf_counter())

        @event.listens_for(engine, "after_cursor_execute")
        def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            # Pop (and do not re-push) the matching start time; re-pushing would grow the
            # per-connection list by one entry for every query executed.
            total = time.perf_counter() - conn.info["query_start_time"].pop()
            # Walk the stack once. Drop this handler's own (innermost) frame and every
            # SQLAlchemy-internal frame, so what remains ends at the code that issued the query.
            stack = [f for f in traceback.extract_stack()[:-1] if "sqlalchemy" not in f.filename]
            if stack:
                caller = stack[-1]
                file_name = f"'{caller.name}':{caller.filename}:{caller.lineno}"
            else:
                file_name = "<unknown>"
            stack_info = ">".join(f"{f.filename.rpartition('/')[-1]}:{f.name}" for f in stack[-3:])
            log.info(
                "@SQLALCHEMY %s |$ %s |$ %s |$ %s ",
                total,
                file_name,
                stack_info,
                statement.replace("\n", " "),
            )