Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/utils/sqlalchemy.py: 35%
175 statements
coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import copy
import datetime
import json
import logging
from typing import Any, Iterable

import pendulum
from dateutil import relativedelta
from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, or_, true, tuple_
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.expression import ColumnOperators
from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText

from airflow import settings
from airflow.configuration import conf
from airflow.serialization.enums import Encoding

log = logging.getLogger(__name__)

utc = pendulum.tz.timezone("UTC")

using_mysql = conf.get_mandatory_value("database", "sql_alchemy_conn").lower().startswith("mysql")


class UtcDateTime(TypeDecorator):
    """
    Almost equivalent to :class:`~sqlalchemy.types.TIMESTAMP` with the
    ``timezone=True`` option, but it differs in that it:

    - Never silently accepts a naive :class:`~datetime.datetime`; it always
      raises :exc:`ValueError` unless the value is time zone aware.
    - Always converts the value's :attr:`~datetime.datetime.tzinfo` to UTC.
    - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`,
      never returns a naive :class:`~datetime.datetime`, but always a time
      zone aware value, even with SQLite or MySQL.
    - Always returns the TIMESTAMP in UTC.
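
    Illustrative usage on a model column (a sketch only; ``MyModel`` and
    ``Base`` are assumptions, not part of this module):

    .. code:: python

        from sqlalchemy import Column

        class MyModel(Base):
            created_at = Column(UtcDateTime)

        # Binding a naive datetime raises ValueError; aware values are
        # converted to UTC before being stored.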
61 """
63 impl = TIMESTAMP(timezone=True)
65 cache_ok = True
67 def process_bind_param(self, value, dialect):
68 if value is not None:
69 if not isinstance(value, datetime.datetime):
70 raise TypeError("expected datetime.datetime, not " + repr(value))
71 elif value.tzinfo is None:
72 raise ValueError("naive datetime is disallowed")
73 # For mysql we should store timestamps as naive values
74 # Timestamp in MYSQL is not timezone aware. In MySQL 5.6
75 # timezone added at the end is ignored but in MySQL 5.7
76 # inserting timezone value fails with 'invalid-date'
77 # See https://issues.apache.org/jira/browse/AIRFLOW-7001
78 if using_mysql:
79 from airflow.utils.timezone import make_naive
81 return make_naive(value, timezone=utc)
82 return value.astimezone(utc)
83 return None

    def process_result_value(self, value, dialect):
        """
        Process DateTimes from the DB, making sure they are always returned as UTC.

        We do not use timezone.convert_to_utc as that converts to the configured
        TIMEZONE, while the DB might be running with some other setting. We
        assume UTC datetimes in the database.
        """
        if value is not None:
            if value.tzinfo is None:
                value = value.replace(tzinfo=utc)
            else:
                value = value.astimezone(utc)

        return value

    def load_dialect_impl(self, dialect):
        if dialect.name == "mssql":
            return mssql.DATETIME2(precision=6)
        elif dialect.name == "mysql":
            return mysql.TIMESTAMP(fsp=6)
        return super().load_dialect_impl(dialect)


class ExtendedJSON(TypeDecorator):
    """
    A version of the JSON column that uses the Airflow extended JSON
    serialization provided by ``airflow.serialization``.
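
    Illustrative column definition (a sketch only; ``MyModel`` and ``Base``
    are assumptions, not part of this module):

    .. code:: python

        from sqlalchemy import Column

        class MyModel(Base):
            data = Column(ExtendedJSON)

        # Values are passed through BaseSerialization on the way in and out,
        # and stored as a JSON string on databases without native JSON support.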
113 """
115 impl = Text
117 cache_ok = True
119 def db_supports_json(self):
120 """Checks if the database supports JSON (i.e. is NOT MSSQL)"""
121 return not conf.get("database", "sql_alchemy_conn").startswith("mssql")
123 def load_dialect_impl(self, dialect) -> TypeEngine:
124 if self.db_supports_json():
125 return dialect.type_descriptor(JSON)
126 return dialect.type_descriptor(UnicodeText)
128 def process_bind_param(self, value, dialect):
129 from airflow.serialization.serialized_objects import BaseSerialization
131 if value is None:
132 return None
134 # First, encode it into our custom JSON-targeted dict format
135 value = BaseSerialization.serialize(value)
137 # Then, if the database does not have native JSON support, encode it again as a string
138 if not self.db_supports_json():
139 value = json.dumps(value)
141 return value
143 def process_result_value(self, value, dialect):
144 from airflow.serialization.serialized_objects import BaseSerialization
146 if value is None:
147 return None
149 # Deserialize from a string first if needed
150 if not self.db_supports_json():
151 value = json.loads(value)
153 return BaseSerialization.deserialize(value)


class ExecutorConfigType(PickleType):
    """
    Adds special handling for the Kubernetes executor config. If we unpickle a k8s object that was
    pickled under an earlier k8s library version, the unpickled object may throw an error when
    ``to_dict`` is called. To be more tolerant of version changes, we convert to JSON using
    Airflow's serializer before pickling.
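
    Illustrative origin of such a value (a sketch only; the operator, task id,
    and callable names are hypothetical):

    .. code:: python

        from kubernetes.client import models as k8s

        PythonOperator(
            task_id="example",
            python_callable=do_work,
            executor_config={"pod_override": k8s.V1Pod(...)},
        )

    The ``pod_override`` entry is converted to Airflow's serialized dict format
    before pickling, so it can still be loaded after a k8s library upgrade.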
162 """
164 cache_ok = True
166 def bind_processor(self, dialect):
168 from airflow.serialization.serialized_objects import BaseSerialization
170 super_process = super().bind_processor(dialect)
172 def process(value):
173 val_copy = copy.copy(value)
174 if isinstance(val_copy, dict) and "pod_override" in val_copy:
175 val_copy["pod_override"] = BaseSerialization.serialize(val_copy["pod_override"])
176 return super_process(val_copy)
178 return process
180 def result_processor(self, dialect, coltype):
181 from airflow.serialization.serialized_objects import BaseSerialization
183 super_process = super().result_processor(dialect, coltype)
185 def process(value):
186 value = super_process(value) # unpickle
188 if isinstance(value, dict) and "pod_override" in value:
189 pod_override = value["pod_override"]
191 # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
192 if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
193 value["pod_override"] = BaseSerialization.deserialize(pod_override)
194 return value
196 return process

    def compare_values(self, x, y):
        """
        The TaskInstance.executor_config attribute is a pickled object that may contain
        Kubernetes objects. If the installed library version has changed since the
        object was originally pickled, due to the underlying ``__eq__`` method on these
        objects (which converts them to JSON), we may encounter attribute errors. In this
        case we should replace the stored object.

        From https://github.com/apache/airflow/pull/24356 we use our serializer to store
        k8s objects, but there could still be raw pickled k8s objects in the database,
        stored from an earlier version, so we still compare them defensively here.
        """
        if self.comparator:
            return self.comparator(x, y)
        else:
            try:
                return x == y
            except AttributeError:
                return False


class Interval(TypeDecorator):
    """Base class representing a time interval."""

    impl = Text

    cache_ok = True

    attr_keys = {
        datetime.timedelta: ("days", "seconds", "microseconds"),
        relativedelta.relativedelta: (
            "years",
            "months",
            "days",
            "leapdays",
            "hours",
            "minutes",
            "seconds",
            "microseconds",
            "year",
            "month",
            "day",
            "hour",
            "minute",
            "second",
            "microsecond",
        ),
    }
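
    # Example of the stored representation (illustrative): a
    # datetime.timedelta(days=1, seconds=30) is written to the database as
    #   '{"type": "timedelta", "attrs": {"days": 1, "seconds": 30, "microseconds": 0}}'
    # and process_result_value() rebuilds the original object from that JSON.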

    def process_bind_param(self, value, dialect):
        if isinstance(value, tuple(self.attr_keys)):
            attrs = {key: getattr(value, key) for key in self.attr_keys[type(value)]}
            return json.dumps({"type": type(value).__name__, "attrs": attrs})
        return json.dumps(value)

    def process_result_value(self, value, dialect):
        if not value:
            return value
        data = json.loads(value)
        if isinstance(data, dict):
            type_map = {key.__name__: key for key in self.attr_keys}
            return type_map[data["type"]](**data["attrs"])
        return data


def skip_locked(session: Session) -> dict[str, Any]:
    """
    Return kwargs for passing to `with_for_update()` suitable for the current DB engine version.

    We do this as we document the fact that on DB engines that don't support this construct, we do not
    support/recommend running an HA scheduler. If a user ignores this and tries anyway, everything will
    still work, just slightly slower in some circumstances.

    Specifically, don't emit SKIP LOCKED for MySQL < 8 or MariaDB, neither of which support this construct.

    See https://jira.mariadb.org/browse/MDEV-13115
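
    Typical use (illustrative; ``MyModel`` is a hypothetical mapped class):

    .. code:: python

        query = session.query(MyModel).with_for_update(**skip_locked(session=session))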
274 """
275 dialect = session.bind.dialect
277 if dialect.name != "mysql" or dialect.supports_for_update_of:
278 return {"skip_locked": True}
279 else:
280 return {}


def nowait(session: Session) -> dict[str, Any]:
    """
    Return kwargs for passing to `with_for_update()` suitable for the current DB engine version.

    We do this as we document the fact that on DB engines that don't support this construct, we do not
    support/recommend running an HA scheduler. If a user ignores this and tries anyway, everything will
    still work, just slightly slower in some circumstances.

    Specifically, don't emit NOWAIT for MySQL < 8 or MariaDB, neither of which support this construct.

    See https://jira.mariadb.org/browse/MDEV-13115
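
    Typical use (illustrative; ``MyModel`` is a hypothetical mapped class):

    .. code:: python

        query = session.query(MyModel).with_for_update(**nowait(session=session))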
294 """
295 dialect = session.bind.dialect
297 if dialect.name != "mysql" or dialect.supports_for_update_of:
298 return {"nowait": True}
299 else:
300 return {}


def nulls_first(col, session: Session):
    """
    Add a ``NULLS FIRST`` construct to the column ordering. Currently only PostgreSQL supports it.

    In MySQL and SQLite, NULL values are considered lower than any non-NULL value, so NULL values
    already appear first when the order is ASC (ascending).
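
    Illustrative use (``MyModel`` and its column are hypothetical):

    .. code:: python

        query = session.query(MyModel).order_by(nulls_first(MyModel.start_date, session=session))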
308 """
309 if session.bind.dialect.name == "postgresql":
310 return nullsfirst(col)
311 else:
312 return col
315USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)


def with_row_locks(query, session: Session, **kwargs):
    """
    Apply with_for_update to an SQLAlchemy query if row level locking is in use.

    :param query: An SQLAlchemy Query object
    :param session: ORM Session
    :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
    :return: updated query
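
    Illustrative use (``MyModel`` is hypothetical):

    .. code:: python

        query = with_row_locks(
            session.query(MyModel),
            session=session,
            **skip_locked(session=session),
        )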
326 """
327 dialect = session.bind.dialect
329 # Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8) does not support it.
330 if USE_ROW_LEVEL_LOCKING and (dialect.name != "mysql" or dialect.supports_for_update_of):
331 return query.with_for_update(**kwargs)
332 else:
333 return query


class CommitProhibitorGuard:
    """Context manager class that powers prohibit_commit."""

    expected_commit = False

    def __init__(self, session: Session):
        self.session = session

    def _validate_commit(self, _):
        if self.expected_commit:
            self.expected_commit = False
            return
        raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!")

    def __enter__(self):
        event.listen(self.session, "before_commit", self._validate_commit)
        return self

    def __exit__(self, *exc_info):
        event.remove(self.session, "before_commit", self._validate_commit)

    def commit(self):
        """
        Commit the session.

        This is the required way to commit when the guard is in scope.
        """
        self.expected_commit = True
        self.session.commit()


def prohibit_commit(session):
    """
    Return a context manager that will disallow any commit that isn't done via the context manager.

    The aim of this is to ensure that transaction lifetime is strictly controlled, which is especially
    important in the core scheduler loop. Any commit on the session that is _not_ via this context
    manager will result in a ``RuntimeError``.

    Example usage:

    .. code:: python

        with prohibit_commit(session) as guard:
            # ... do something with session
            guard.commit()

            # This would throw an error
            # session.commit()
    """
    return CommitProhibitorGuard(session)


def is_lock_not_available_error(error: OperationalError):
    """Check if the error is about not being able to acquire a lock."""
    # DB specific error codes:
    # Postgres: 55P03
    # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT
    #   is set.'
    # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction'
    #   (when NOWAIT isn't available)
    db_err_code = getattr(error.orig, "pgcode", None) or error.orig.args[0]

    # We could test if error.orig is an instance of
    # psycopg2.errors.LockNotAvailable / _mysql_exceptions.OperationalError, but that would involve
    # importing them. Checking the error code doesn't.
    if db_err_code in ("55P03", 1205, 3572):
        return True
    return False


def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators:
    """Generates a tuple-in-collection operator to use in ``.filter()``.

    For most SQL backends, this generates a simple ``(col, ...) IN (values)``
    clause. This however does not work with MSSQL, where we need to expand to
    ``(c1 = v1a AND c2 = v2a ...) OR (c1 = v1b AND c2 = v2b ...) ...`` manually.
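
    Illustrative use (``MyModel`` and its columns are hypothetical):

    .. code:: python

        query = query.filter(
            tuple_in_condition((MyModel.dag_id, MyModel.run_id), [("a", "1"), ("b", "2")])
        )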

    :meta private:
    """
    if settings.engine.dialect.name != "mssql":
        return tuple_(*columns).in_(collection)
    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in collection]
    if not clauses:
        return false()
    return or_(*clauses)


def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators:
    """Generates a tuple-not-in-collection operator to use in ``.filter()``.

    This is similar to ``tuple_in_condition`` except generating ``NOT IN``.

    :meta private:
    """
    if settings.engine.dialect.name != "mssql":
        return tuple_(*columns).not_in(collection)
    clauses = [or_(*(c != v for c, v in zip(columns, values))) for values in collection]
    if not clauses:
        return true()
    return and_(*clauses)