Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/utils/sqlalchemy.py: 30%
213 statements
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import copy
import datetime
import json
import logging
from typing import TYPE_CHECKING, Any, Iterable

import pendulum
from dateutil import relativedelta
from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, or_, true, tuple_
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.expression import ColumnOperators
from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText

from airflow import settings
from airflow.configuration import conf
from airflow.serialization.enums import Encoding

if TYPE_CHECKING:
    from kubernetes.client.models.v1_pod import V1Pod

log = logging.getLogger(__name__)

utc = pendulum.tz.timezone("UTC")

using_mysql = conf.get_mandatory_value("database", "sql_alchemy_conn").lower().startswith("mysql")


class UtcDateTime(TypeDecorator):
    """
    Almost equivalent to :class:`~sqlalchemy.types.TIMESTAMP` with
    ``timezone=True`` option, but it differs from that by:

    - Never silently accepts a naive :class:`~datetime.datetime`; it
      always raises :exc:`ValueError` unless the value is time zone aware.
    - A :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
      is always converted to UTC.
    - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`,
      it never returns a naive :class:`~datetime.datetime`, but a time zone
      aware value, even with SQLite or MySQL.
    - Always returns TIMESTAMP in UTC.
    """

    impl = TIMESTAMP(timezone=True)

    cache_ok = True

    def process_bind_param(self, value, dialect):
        if value is not None:
            if not isinstance(value, datetime.datetime):
                raise TypeError("expected datetime.datetime, not " + repr(value))
            elif value.tzinfo is None:
                raise ValueError("naive datetime is disallowed")
            # For MySQL we should store timestamps as naive values.
            # TIMESTAMP in MySQL is not timezone aware. In MySQL 5.6 a
            # timezone appended to the value is ignored, but in MySQL 5.7
            # inserting a timezone-aware value fails with 'invalid-date'.
            # See https://issues.apache.org/jira/browse/AIRFLOW-7001
            if using_mysql:
                from airflow.utils.timezone import make_naive

                return make_naive(value, timezone=utc)
            return value.astimezone(utc)
        return None

    def process_result_value(self, value, dialect):
        """
        Process DateTimes from the DB, making sure they are always returned in UTC.

        Not using timezone.convert_to_utc as that converts to the configured
        TIMEZONE while the DB might be running with some other setting. We
        assume UTC datetimes in the database.
        """
        if value is not None:
            if value.tzinfo is None:
                value = value.replace(tzinfo=utc)
            else:
                value = value.astimezone(utc)

        return value

    def load_dialect_impl(self, dialect):
        if dialect.name == "mssql":
            return mssql.DATETIME2(precision=6)
        elif dialect.name == "mysql":
            return mysql.TIMESTAMP(fsp=6)
        return super().load_dialect_impl(dialect)
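

# Illustrative sketch (not part of the original module): how UtcDateTime is
# typically used as a column type. The ``Event`` model and in-memory SQLite
# engine below are hypothetical examples, not Airflow models.
def _example_utc_datetime_usage():
    import datetime

    from sqlalchemy import Column, Integer, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Event(Base):
        __tablename__ = "example_event"
        id = Column(Integer, primary_key=True)
        created_at = Column(UtcDateTime)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # An aware datetime is accepted and normalized to UTC on the way in.
        session.add(Event(created_at=datetime.datetime.now(tz=datetime.timezone.utc)))
        session.commit()

        # A naive datetime would raise ValueError when the INSERT is flushed:
        # session.add(Event(created_at=datetime.datetime.now())); session.flush()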


class ExtendedJSON(TypeDecorator):
    """
    A version of the JSON column that uses the Airflow extended JSON
    serialization provided by airflow.serialization.
    """

    impl = Text

    cache_ok = True

    def db_supports_json(self):
        """Checks if the database supports JSON (i.e. is NOT MSSQL)."""
        return not conf.get("database", "sql_alchemy_conn").startswith("mssql")

    def load_dialect_impl(self, dialect) -> TypeEngine:
        if self.db_supports_json():
            return dialect.type_descriptor(JSON)
        return dialect.type_descriptor(UnicodeText)

    def process_bind_param(self, value, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        if value is None:
            return None

        # First, encode it into our custom JSON-targeted dict format
        value = BaseSerialization.serialize(value)

        # Then, if the database does not have native JSON support, encode it again as a string
        if not self.db_supports_json():
            value = json.dumps(value)

        return value

    def process_result_value(self, value, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        if value is None:
            return None

        # Deserialize from a string first if needed
        if not self.db_supports_json():
            value = json.loads(value)

        return BaseSerialization.deserialize(value)
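

# Illustrative sketch (not part of the original module): what ExtendedJSON
# stores. Values pass through Airflow's BaseSerialization on the way in, so
# types plain JSON cannot hold (e.g. datetimes) survive a round trip. Assumes
# airflow.serialization is importable in the current environment.
def _example_extended_json_round_trip():
    import datetime

    from airflow.serialization.serialized_objects import BaseSerialization

    original = {"when": datetime.datetime(2023, 6, 7, tzinfo=datetime.timezone.utc)}
    encoded = BaseSerialization.serialize(original)   # JSON-safe dict, as stored by the column
    decoded = BaseSerialization.deserialize(encoded)  # what process_result_value returns
    assert decoded["when"] == original["when"]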


def sanitize_for_serialization(obj: V1Pod):
    """
    Convert pod to dict.... but *safely*.

    When pod objects created with one k8s version are unpickled in a python
    env with a more recent k8s version (in which the object attrs may have
    changed) the unpickled obj may throw an error because the attr
    expected on new obj may not be there on the unpickled obj.

    This function still converts the pod to a dict; the only difference is
    it populates missing attrs with None. You may compare with
    https://github.com/kubernetes-client/python/blob/5a96bbcbe21a552cc1f9cda13e0522fafb0dbac8/kubernetes/client/api_client.py#L202

    If obj is None, return None.
    If obj is str, int, long, float, bool, return directly.
    If obj is datetime.datetime, datetime.date
        convert to string in iso8601 format.
    If obj is list, sanitize each element in the list.
    If obj is dict, return the dict.
    If obj is OpenAPI model, return the properties dict.

    :param obj: The data to serialize.
    :return: The serialized form of data.

    :meta private:
    """
    if obj is None:
        return None
    elif isinstance(obj, (float, bool, bytes, str, int)):
        return obj
    elif isinstance(obj, list):
        return [sanitize_for_serialization(sub_obj) for sub_obj in obj]
    elif isinstance(obj, tuple):
        return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj)
    elif isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()

    if isinstance(obj, dict):
        obj_dict = obj
    else:
        obj_dict = {
            obj.attribute_map[attr]: getattr(obj, attr)
            for attr, _ in obj.openapi_types.items()
            # below is the only line we change, and we just add default=None for getattr
            if getattr(obj, attr, None) is not None
        }

    return {key: sanitize_for_serialization(val) for key, val in obj_dict.items()}
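

# Illustrative sketch (not part of the original module): sanitizing a V1Pod
# into a plain dict. Requires the kubernetes client to be installed; the pod
# spec below is a minimal hypothetical example.
def _example_sanitize_pod():
    from kubernetes.client import models as k8s

    pod = k8s.V1Pod(
        metadata=k8s.V1ObjectMeta(name="example", namespace="default"),
        spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base", image="busybox")]),
    )
    pod_dict = sanitize_for_serialization(pod)
    # Keys follow the OpenAPI attribute_map, e.g. {"metadata": {"name": "example", ...}, ...}
    assert pod_dict["metadata"]["name"] == "example"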


def ensure_pod_is_valid_after_unpickling(pod: V1Pod) -> V1Pod | None:
    """
    Convert pod to JSON and back so that the pod is safe to use.

    The pod_override in executor_config is a V1Pod object.
    Such objects created with one k8s version, when unpickled in
    an env with an upgraded k8s version, may blow up when
    `to_dict` is called, because openapi client code gen calls
    getattr on all attrs in openapi_types for each object, and when
    new attrs are added to that list, getattr will fail.

    Here we re-serialize it to ensure it is not going to blow up.

    :meta private:
    """
    try:
        # if to_dict works, the pod is fine
        pod.to_dict()
        return pod
    except AttributeError:
        pass
    try:
        from kubernetes.client.models.v1_pod import V1Pod
    except ImportError:
        return None
    if not isinstance(pod, V1Pod):
        return None
    try:
        from airflow.kubernetes.pod_generator import PodGenerator

        # now we actually reserialize / deserialize the pod
        pod_dict = sanitize_for_serialization(pod)
        return PodGenerator.deserialize_model_dict(pod_dict)
    except Exception:
        return None


class ExecutorConfigType(PickleType):
    """
    Adds special handling for K8s executor config. If we unpickle a k8s object that was
    pickled under an earlier k8s library version, then the unpickled object may throw an error
    when to_dict is called. To be more tolerant of version changes we convert to JSON using
    Airflow's serializer before pickling.
    """

    cache_ok = True

    def bind_processor(self, dialect):

        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().bind_processor(dialect)

        def process(value):
            val_copy = copy.copy(value)
            if isinstance(val_copy, dict) and "pod_override" in val_copy:
                val_copy["pod_override"] = BaseSerialization.serialize(val_copy["pod_override"])
            return super_process(val_copy)

        return process

    def result_processor(self, dialect, coltype):
        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().result_processor(dialect, coltype)

        def process(value):
            value = super_process(value)  # unpickle

            if isinstance(value, dict) and "pod_override" in value:
                pod_override = value["pod_override"]

                if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
                    # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
                    value["pod_override"] = BaseSerialization.deserialize(pod_override)
                else:
                    # backcompat path
                    # we no longer pickle raw pods but this code may be reached
                    # when accessing executor configs created in a prior version
                    new_pod = ensure_pod_is_valid_after_unpickling(pod_override)
                    if new_pod:
                        value["pod_override"] = new_pod
            return value

        return process

    def compare_values(self, x, y):
        """
        The TaskInstance.executor_config attribute is a pickled object that may contain
        kubernetes objects. If the installed library version has changed since the
        object was originally pickled, due to the underlying ``__eq__`` method on these
        objects (which converts them to JSON), we may encounter attribute errors. In this
        case we should replace the stored object.

        From https://github.com/apache/airflow/pull/24356 we use our serializer to store
        k8s objects, but there could still be raw pickled k8s objects in the database,
        stored from an earlier version, so we still compare them defensively here.
        """
        if self.comparator:
            return self.comparator(x, y)
        else:
            try:
                return x == y
            except AttributeError:
                return False
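

# Illustrative sketch (not part of the original module): how a model might
# declare a column with this type, similar in spirit to
# TaskInstance.executor_config. The ExampleTask model below is hypothetical.
def _example_executor_config_column():
    from sqlalchemy import Column, Integer
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class ExampleTask(Base):
        __tablename__ = "example_task"
        id = Column(Integer, primary_key=True)
        # Values such as {"pod_override": V1Pod(...)} have the pod serialized with
        # Airflow's BaseSerialization before the dict is pickled into the column.
        executor_config = Column(ExecutorConfigType())

    return ExampleTask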


class Interval(TypeDecorator):
    """Base class representing a time interval."""

    impl = Text

    cache_ok = True

    attr_keys = {
        datetime.timedelta: ("days", "seconds", "microseconds"),
        relativedelta.relativedelta: (
            "years",
            "months",
            "days",
            "leapdays",
            "hours",
            "minutes",
            "seconds",
            "microseconds",
            "year",
            "month",
            "day",
            "hour",
            "minute",
            "second",
            "microsecond",
        ),
    }

    def process_bind_param(self, value, dialect):
        if isinstance(value, tuple(self.attr_keys)):
            attrs = {key: getattr(value, key) for key in self.attr_keys[type(value)]}
            return json.dumps({"type": type(value).__name__, "attrs": attrs})
        return json.dumps(value)

    def process_result_value(self, value, dialect):
        if not value:
            return value
        data = json.loads(value)
        if isinstance(data, dict):
            type_map = {key.__name__: key for key in self.attr_keys}
            return type_map[data["type"]](**data["attrs"])
        return data
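

# Illustrative sketch (not part of the original module): how Interval encodes
# a timedelta as JSON and restores it. Calling the processors directly with a
# dialect of None is only for demonstration; SQLAlchemy normally does this.
def _example_interval_round_trip():
    import datetime

    interval = Interval()
    stored = interval.process_bind_param(datetime.timedelta(days=1, seconds=30), dialect=None)
    # stored == '{"type": "timedelta", "attrs": {"days": 1, "seconds": 30, "microseconds": 0}}'
    restored = interval.process_result_value(stored, dialect=None)
    assert restored == datetime.timedelta(days=1, seconds=30)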


def skip_locked(session: Session) -> dict[str, Any]:
    """
    Return kwargs for passing to `with_for_update()` suitable for the current DB engine version.

    We do this as we document the fact that on DB engines that don't support this construct, we do not
    support/recommend running HA scheduler. If a user ignores this and tries anyway everything will still
    work, just slightly slower in some circumstances.

    Specifically don't emit SKIP LOCKED for MySQL < 8, or MariaDB, neither of which support this construct.

    See https://jira.mariadb.org/browse/MDEV-13115
    """
    dialect = session.bind.dialect

    if dialect.name != "mysql" or dialect.supports_for_update_of:
        return {"skip_locked": True}
    else:
        return {}


def nowait(session: Session) -> dict[str, Any]:
    """
    Return kwargs for passing to `with_for_update()` suitable for the current DB engine version.

    We do this as we document the fact that on DB engines that don't support this construct, we do not
    support/recommend running HA scheduler. If a user ignores this and tries anyway everything will still
    work, just slightly slower in some circumstances.

    Specifically don't emit NOWAIT for MySQL < 8, or MariaDB, neither of which support this construct.

    See https://jira.mariadb.org/browse/MDEV-13115
    """
    dialect = session.bind.dialect

    if dialect.name != "mysql" or dialect.supports_for_update_of:
        return {"nowait": True}
    else:
        return {}
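

# Illustrative sketch (not part of the original module): feeding the kwargs
# from skip_locked() (or nowait()) into a locked query. The TaskInstance query
# below is a hypothetical example; only the pattern matters.
def _example_skip_locked_usage(session):
    from airflow.models.taskinstance import TaskInstance

    query = session.query(TaskInstance).filter(TaskInstance.state == "scheduled")
    # On engines without SKIP LOCKED support this stays a plain SELECT ... FOR UPDATE.
    return query.with_for_update(**skip_locked(session=session))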


def nulls_first(col, session: Session):
    """Specify *NULLS FIRST* for the column ordering.

    This is only done for Postgres, currently the only backend that supports it.
    Other databases do not need it since NULL values are considered lower than
    any other values, and appear first when the order is ASC (ascending).
    """
    if session.bind.dialect.name == "postgresql":
        return nullsfirst(col)
    else:
        return col


USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)


def with_row_locks(query, session: Session, **kwargs):
    """
    Apply with_for_update to an SQLAlchemy query, if row level locking is in use.

    :param query: An SQLAlchemy Query object
    :param session: ORM Session
    :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
    :return: updated query
    """
    dialect = session.bind.dialect

    # Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8) does not support it.
    if USE_ROW_LEVEL_LOCKING and (dialect.name != "mysql" or dialect.supports_for_update_of):
        return query.with_for_update(**kwargs)
    else:
        return query
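

# Illustrative sketch (not part of the original module): combining
# with_row_locks with the helpers above, mirroring the pattern used by the
# scheduler. The DagRun query is a hypothetical example.
def _example_with_row_locks(session):
    from airflow.models.dagrun import DagRun

    query = session.query(DagRun).filter(DagRun.state == "running")
    # Becomes SELECT ... FOR UPDATE SKIP LOCKED where supported, otherwise a plain query.
    return with_row_locks(query, session=session, **skip_locked(session=session))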


class CommitProhibitorGuard:
    """Context manager class that powers prohibit_commit."""

    expected_commit = False

    def __init__(self, session: Session):
        self.session = session

    def _validate_commit(self, _):
        if self.expected_commit:
            self.expected_commit = False
            return
        raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!")

    def __enter__(self):
        event.listen(self.session, "before_commit", self._validate_commit)
        return self

    def __exit__(self, *exc_info):
        event.remove(self.session, "before_commit", self._validate_commit)

    def commit(self):
        """
        Commit the session.

        This is the required way to commit when the guard is in scope.
        """
        self.expected_commit = True
        self.session.commit()


def prohibit_commit(session):
    """
    Return a context manager that will disallow any commit that isn't done via the context manager.

    The aim of this is to ensure that transaction lifetime is strictly controlled, which is especially
    important in the core scheduler loop. Any commit on the session that is _not_ via this context manager
    will result in a RuntimeError.

    Example usage:

    .. code:: python

        with prohibit_commit(session) as guard:
            # ... do something with session
            guard.commit()

            # This would throw an error
            # session.commit()
    """
    return CommitProhibitorGuard(session)


def is_lock_not_available_error(error: OperationalError):
    """Check if the Error is about not being able to acquire lock."""
    # DB specific error codes:
    # Postgres: 55P03
    # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT
    #   is set.'
    # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction'
    #   (when NOWAIT isn't available)
    db_err_code = getattr(error.orig, "pgcode", None) or error.orig.args[0]

    # We could test if error.orig is an instance of
    # psycopg2.errors.LockNotAvailable/_mysql_exceptions.OperationalError, but that involves
    # importing it. This doesn't.
    if db_err_code in ("55P03", 1205, 3572):
        return True
    return False
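

# Illustrative sketch (not part of the original module): retrying when a row
# lock could not be acquired. The critical_section() callable is hypothetical.
def _example_lock_retry(session, critical_section):
    from sqlalchemy.exc import OperationalError

    try:
        with prohibit_commit(session) as guard:
            critical_section(session)
            guard.commit()
    except OperationalError as e:
        if is_lock_not_available_error(e):
            session.rollback()  # another scheduler holds the lock; try again later
        else:
            raise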


def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators:
    """Generates a tuple-in-collection operator to use in ``.filter()``.

    For most SQL backends, this generates a simple ``([col, ...]) IN [condition]``
    clause. This however does not work with MSSQL, where we need to expand to
    ``(c1 = v1a AND c2 = v2a ...) OR (c1 = v1b AND c2 = v2b ...) ...`` manually.

    :meta private:
    """
    if settings.engine.dialect.name != "mssql":
        return tuple_(*columns).in_(collection)
    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in collection]
    if not clauses:
        return false()
    return or_(*clauses)


def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators:
    """Generates a tuple-not-in-collection operator to use in ``.filter()``.

    This is similar to ``tuple_in_condition`` except generating ``NOT IN``.

    :meta private:
    """
    if settings.engine.dialect.name != "mssql":
        return tuple_(*columns).not_in(collection)
    clauses = [or_(*(c != v for c, v in zip(columns, values))) for values in collection]
    if not clauses:
        return true()
    return and_(*clauses)
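

# Illustrative sketch (not part of the original module): filtering task
# instances by (dag_id, task_id) pairs with tuple_in_condition. The query and
# key list below are hypothetical examples.
def _example_tuple_in_condition(session):
    from airflow.models.taskinstance import TaskInstance

    keys = [("example_dag", "extract"), ("example_dag", "load")]
    condition = tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), keys)
    # On MSSQL this expands to ORed equality pairs instead of a tuple IN clause.
    return session.query(TaskInstance).filter(condition).all()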