Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/utils/sqlalchemy.py: 34%
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import contextlib
import copy
import datetime
import json
import logging
from importlib import metadata
from typing import TYPE_CHECKING, Any, Generator, Iterable, overload

from dateutil import relativedelta
from packaging import version
from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst, tuple_
from sqlalchemy.dialects import mysql
from sqlalchemy.types import JSON, Text, TypeDecorator

from airflow.configuration import conf
from airflow.serialization.enums import Encoding
from airflow.utils.timezone import make_naive, utc

if TYPE_CHECKING:
    from kubernetes.client.models.v1_pod import V1Pod
    from sqlalchemy.exc import OperationalError
    from sqlalchemy.orm import Query, Session
    from sqlalchemy.sql import ColumnElement, Select
    from sqlalchemy.sql.expression import ColumnOperators
    from sqlalchemy.types import TypeEngine

log = logging.getLogger(__name__)


class UtcDateTime(TypeDecorator):
    """
    Similar to :class:`~sqlalchemy.types.TIMESTAMP` with the ``timezone=True`` option, with some differences.

    - Never silently accepts a naive :class:`~datetime.datetime`; it always
      raises :exc:`ValueError` unless the value is time zone aware.
    - A :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
      is always converted to UTC.
    - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`,
      it never returns a naive :class:`~datetime.datetime`, but a time zone
      aware value, even with SQLite or MySQL.
    - Always returns TIMESTAMP in UTC.
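
    A minimal usage sketch (illustrative only; ``Base``, ``MyModel``, and
    ``created_at`` are hypothetical names, not part of this module):

    .. code:: python

        from sqlalchemy import Column, Integer
        from sqlalchemy.orm import declarative_base

        from airflow.utils import timezone
        from airflow.utils.sqlalchemy import UtcDateTime

        Base = declarative_base()


        class MyModel(Base):
            __tablename__ = "my_model"
            id = Column(Integer, primary_key=True)
            created_at = Column(UtcDateTime, default=timezone.utcnow)

    Binding a naive ``datetime.datetime`` to such a column raises ``ValueError``;
    values read back are always UTC-aware.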
61 """
63 impl = TIMESTAMP(timezone=True)
65 cache_ok = True
67 def process_bind_param(self, value, dialect):
68 if not isinstance(value, datetime.datetime):
69 if value is None:
70 return None
71 raise TypeError(f"expected datetime.datetime, not {value!r}")
72 elif value.tzinfo is None:
73 raise ValueError("naive datetime is disallowed")
74 elif dialect.name == "mysql":
75 # For mysql versions prior 8.0.19 we should send timestamps as naive values in UTC
76 # see: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-literals.html
77 return make_naive(value, timezone=utc)
78 return value.astimezone(utc)
80 def process_result_value(self, value, dialect):
81 """
82 Process DateTimes from the DB making sure to always return UTC.
84 Not using timezone.convert_to_utc as that converts to configured TIMEZONE
85 while the DB might be running with some other setting. We assume UTC
86 datetimes in the database.
87 """
88 if value is not None:
89 if value.tzinfo is None:
90 value = value.replace(tzinfo=utc)
91 else:
92 value = value.astimezone(utc)
94 return value
96 def load_dialect_impl(self, dialect):
97 if dialect.name == "mysql":
98 return mysql.TIMESTAMP(fsp=6)
99 return super().load_dialect_impl(dialect)


class ExtendedJSON(TypeDecorator):
    """
    A version of the JSON column that uses the Airflow extended JSON serialization.

    See airflow.serialization.
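
    A minimal sketch of declaring a column with this type (``data`` is a
    hypothetical column name inside a hypothetical declarative model class):

    .. code:: python

        from sqlalchemy import Column

        from airflow.utils.sqlalchemy import ExtendedJSON

        data = Column(ExtendedJSON)

    Values are run through ``BaseSerialization.serialize`` and ``deserialize``
    on the way into and out of the database.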
107 """
109 impl = Text
111 cache_ok = True
113 def load_dialect_impl(self, dialect) -> TypeEngine:
114 return dialect.type_descriptor(JSON)
116 def process_bind_param(self, value, dialect):
117 from airflow.serialization.serialized_objects import BaseSerialization
119 if value is None:
120 return None
122 return BaseSerialization.serialize(value)
124 def process_result_value(self, value, dialect):
125 from airflow.serialization.serialized_objects import BaseSerialization
127 if value is None:
128 return None
130 return BaseSerialization.deserialize(value)


def sanitize_for_serialization(obj: V1Pod):
    """
    Convert pod to dict.... but *safely*.

    When pod objects created with one k8s version are unpickled in a Python
    env with a more recent k8s version (in which the object attrs may have
    changed) the unpickled obj may throw an error because the attr
    expected on new obj may not be there on the unpickled obj.

    This function still converts the pod to a dict; the only difference is
    it populates missing attrs with None. You may compare with
    https://github.com/kubernetes-client/python/blob/5a96bbcbe21a552cc1f9cda13e0522fafb0dbac8/kubernetes/client/api_client.py#L202

    If obj is None, return None.
    If obj is str, int, float, bool, return directly.
    If obj is datetime.datetime, datetime.date
        convert to string in iso8601 format.
    If obj is list, sanitize each element in the list.
    If obj is dict, return the dict.
    If obj is OpenAPI model, return the properties dict.

    :param obj: The data to serialize.
    :return: The serialized form of data.

    :meta private:
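
    A minimal usage sketch (assumes the ``kubernetes`` client library is
    installed; the pod built here is purely illustrative):

    .. code:: python

        from kubernetes.client import models as k8s

        pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="example", namespace="default"))
        pod_dict = sanitize_for_serialization(pod)
        # pod_dict is a plain dict; attributes missing on an unpickled object
        # are tolerated instead of raising AttributeError.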
158 """
159 if obj is None:
160 return None
161 elif isinstance(obj, (float, bool, bytes, str, int)):
162 return obj
163 elif isinstance(obj, list):
164 return [sanitize_for_serialization(sub_obj) for sub_obj in obj]
165 elif isinstance(obj, tuple):
166 return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj)
167 elif isinstance(obj, (datetime.datetime, datetime.date)):
168 return obj.isoformat()
170 if isinstance(obj, dict):
171 obj_dict = obj
172 else:
173 obj_dict = {
174 obj.attribute_map[attr]: getattr(obj, attr)
175 for attr, _ in obj.openapi_types.items()
176 # below is the only line we change, and we just add default=None for getattr
177 if getattr(obj, attr, None) is not None
178 }
180 return {key: sanitize_for_serialization(val) for key, val in obj_dict.items()}


def ensure_pod_is_valid_after_unpickling(pod: V1Pod) -> V1Pod | None:
    """
    Convert pod to json and back so that pod is safe.

    The pod_override in executor_config is a V1Pod object.
    Such objects created with one k8s version, when unpickled in
    an env with an upgraded k8s version, may blow up when
    `to_dict` is called, because openapi client code gen calls
    getattr on all attrs in openapi_types for each object, and when
    new attrs are added to that list, getattr will fail.

    Here we re-serialize it to ensure it is not going to blow up.

    :meta private:
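
    A minimal usage sketch (``old_pod`` is assumed to be a ``V1Pod`` recovered
    from a pickled executor_config):

    .. code:: python

        new_pod = ensure_pod_is_valid_after_unpickling(old_pod)
        if new_pod is not None:
            old_pod = new_pod  # use the re-serialized, safe copy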
197 """
198 try:
199 # if to_dict works, the pod is fine
200 pod.to_dict()
201 return pod
202 except AttributeError:
203 pass
204 try:
205 from kubernetes.client.models.v1_pod import V1Pod
206 except ImportError:
207 return None
208 if not isinstance(pod, V1Pod):
209 return None
210 try:
211 try:
212 from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
213 except ImportError:
214 from airflow.kubernetes.pre_7_4_0_compatibility.pod_generator import ( # type: ignore[assignment]
215 PodGenerator,
216 )
217 # now we actually reserialize / deserialize the pod
218 pod_dict = sanitize_for_serialization(pod)
219 return PodGenerator.deserialize_model_dict(pod_dict)
220 except Exception:
221 return None


class ExecutorConfigType(PickleType):
    """
    Adds special handling for K8s executor config.

    If we unpickle a k8s object that was pickled under an earlier k8s library version, then
    the unpickled object may throw an error when to_dict is called. To be more tolerant of
    version changes we convert to JSON using Airflow's serializer before pickling.
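
    A sketch of how a column of this type might be declared (Airflow's
    TaskInstance model uses this type for ``executor_config``; the snippet below
    is illustrative, not copied from that model):

    .. code:: python

        from sqlalchemy import Column

        from airflow.utils.sqlalchemy import ExecutorConfigType

        executor_config = Column(ExecutorConfigType())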
231 """
233 cache_ok = True
235 def bind_processor(self, dialect):
236 from airflow.serialization.serialized_objects import BaseSerialization
238 super_process = super().bind_processor(dialect)
240 def process(value):
241 val_copy = copy.copy(value)
242 if isinstance(val_copy, dict) and "pod_override" in val_copy:
243 val_copy["pod_override"] = BaseSerialization.serialize(val_copy["pod_override"])
244 return super_process(val_copy)
246 return process
248 def result_processor(self, dialect, coltype):
249 from airflow.serialization.serialized_objects import BaseSerialization
251 super_process = super().result_processor(dialect, coltype)
253 def process(value):
254 value = super_process(value) # unpickle
256 if isinstance(value, dict) and "pod_override" in value:
257 pod_override = value["pod_override"]
259 if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
260 # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
261 value["pod_override"] = BaseSerialization.deserialize(pod_override)
262 else:
263 # backcompat path
264 # we no longer pickle raw pods but this code may be reached
265 # when accessing executor configs created in a prior version
266 new_pod = ensure_pod_is_valid_after_unpickling(pod_override)
267 if new_pod:
268 value["pod_override"] = new_pod
269 return value
271 return process
273 def compare_values(self, x, y):
274 """
275 Compare x and y using self.comparator if available. Else, use __eq__.
277 The TaskInstance.executor_config attribute is a pickled object that may contain kubernetes objects.
279 If the installed library version has changed since the object was originally pickled,
280 due to the underlying ``__eq__`` method on these objects (which converts them to JSON),
281 we may encounter attribute errors. In this case we should replace the stored object.
283 From https://github.com/apache/airflow/pull/24356 we use our serializer to store
284 k8s objects, but there could still be raw pickled k8s objects in the database,
285 stored from earlier version, so we still compare them defensively here.
286 """
287 if self.comparator:
288 return self.comparator(x, y)
289 else:
290 try:
291 return x == y
292 except AttributeError:
293 return False


class Interval(TypeDecorator):
    """Base class representing a time interval."""

    impl = Text

    cache_ok = True

    attr_keys = {
        datetime.timedelta: ("days", "seconds", "microseconds"),
        relativedelta.relativedelta: (
            "years",
            "months",
            "days",
            "leapdays",
            "hours",
            "minutes",
            "seconds",
            "microseconds",
            "year",
            "month",
            "day",
            "hour",
            "minute",
            "second",
            "microsecond",
        ),
    }

    def process_bind_param(self, value, dialect):
        if isinstance(value, tuple(self.attr_keys)):
            attrs = {key: getattr(value, key) for key in self.attr_keys[type(value)]}
            return json.dumps({"type": type(value).__name__, "attrs": attrs})
        return json.dumps(value)

    def process_result_value(self, value, dialect):
        if not value:
            return value
        data = json.loads(value)
        if isinstance(data, dict):
            type_map = {key.__name__: key for key in self.attr_keys}
            return type_map[data["type"]](**data["attrs"])
        return data


def nulls_first(col, session: Session) -> ColumnElement:
    """Specify *NULLS FIRST* to the column ordering.

    This is only done for Postgres, currently the only backend that supports it.
    Other databases do not need it since NULL values are considered lower than
    any other values, and appear first when the order is ASC (ascending).
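
    A minimal usage sketch (``MyModel.start_date`` is a hypothetical column, not
    part of this module):

    .. code:: python

        query = session.query(MyModel).order_by(nulls_first(MyModel.start_date, session))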
346 """
347 if session.bind.dialect.name == "postgresql":
348 return nullsfirst(col)
349 else:
350 return col
353USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)


def with_row_locks(
    query: Query,
    session: Session,
    *,
    nowait: bool = False,
    skip_locked: bool = False,
    **kwargs,
) -> Query:
    """
    Apply with_for_update to the SQLAlchemy query if row level locking is in use.

    This wrapper is needed so we don't use the syntax on unsupported database
    engines. In particular, MySQL (prior to 8.0) and MariaDB do not support
    row locking, and on those backends we neither support nor recommend running
    an HA scheduler. If a user ignores this and tries anyway, everything will
    still work, just slightly slower in some circumstances.

    See https://jira.mariadb.org/browse/MDEV-13115

    :param query: An SQLAlchemy Query object
    :param session: ORM Session
    :param nowait: If set to True, will pass NOWAIT to supported database backends.
    :param skip_locked: If set to True, will pass SKIP LOCKED to supported database backends.
    :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
    :return: updated query
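
    A minimal usage sketch (the model and filter are illustrative):

    .. code:: python

        from airflow.models import TaskInstance

        query = session.query(TaskInstance).filter(TaskInstance.state == "scheduled")
        tis = with_row_locks(query, session=session, skip_locked=True).all()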
381 """
382 dialect = session.bind.dialect
384 # Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8) does not support it.
385 if not USE_ROW_LEVEL_LOCKING:
386 return query
387 if dialect.name == "mysql" and not dialect.supports_for_update_of:
388 return query
389 if nowait:
390 kwargs["nowait"] = True
391 if skip_locked:
392 kwargs["skip_locked"] = True
393 return query.with_for_update(**kwargs)


@contextlib.contextmanager
def lock_rows(query: Query, session: Session) -> Generator[None, None, None]:
    """Lock database rows during the context manager block.

    This is a convenient method for ``with_row_locks`` when we don't need the
    locked rows.

    :meta private:
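
    A minimal usage sketch (``query`` is assumed to be a SQLAlchemy ``Query``
    selecting the rows to lock):

    .. code:: python

        with lock_rows(query, session):
            ...  # work to do while the rows are locked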
404 """
405 locked_rows = with_row_locks(query, session)
406 yield
407 del locked_rows


class CommitProhibitorGuard:
    """Context manager class that powers prohibit_commit."""

    expected_commit = False

    def __init__(self, session: Session):
        self.session = session

    def _validate_commit(self, _):
        if self.expected_commit:
            self.expected_commit = False
            return
        raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!")

    def __enter__(self):
        event.listen(self.session, "before_commit", self._validate_commit)
        return self

    def __exit__(self, *exc_info):
        event.remove(self.session, "before_commit", self._validate_commit)

    def commit(self):
        """
        Commit the session.

        This is the required way to commit when the guard is in scope.
        """
        self.expected_commit = True
        self.session.commit()


def prohibit_commit(session):
    """
    Return a context manager that will disallow any commit that isn't done via the context manager.

    The aim of this is to ensure that transaction lifetime is strictly controlled, which is especially
    important in the core scheduler loop. Any commit on the session that is _not_ via this context manager
    will result in a ``RuntimeError``.

    Example usage:

    .. code:: python

        with prohibit_commit(session) as guard:
            # ... do something with session
            guard.commit()

            # This would throw an error
            # session.commit()
    """
    return CommitProhibitorGuard(session)


def is_lock_not_available_error(error: OperationalError):
    """Check if the error is about not being able to acquire a lock."""
    # DB specific error codes:
    # Postgres: 55P03
    # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT
    # is set.'
    # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction'
    # (when NOWAIT isn't available)
    db_err_code = getattr(error.orig, "pgcode", None) or error.orig.args[0]

    # We could test if error.orig is an instance of
    # psycopg2.errors.LockNotAvailable/_mysql_exceptions.OperationalError, but that involves
    # importing it. This doesn't.
    if db_err_code in ("55P03", 1205, 3572):
        return True
    return False


@overload
def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators: ...


@overload
def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Select,
    *,
    session: Session,
) -> ColumnOperators: ...


def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any] | Select,
    *,
    session: Session | None = None,
) -> ColumnOperators:
    """
    Generate a tuple-in-collection operator to use in ``.where()``.

    For most SQL backends, this generates a simple ``([col, ...]) IN [condition]``
    clause.

    :meta private:
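
    A minimal usage sketch (the model columns and key list are illustrative):

    .. code:: python

        from sqlalchemy import select

        from airflow.models import TaskInstance

        keys = [("my_dag", "my_task"), ("my_dag", "other_task")]
        stmt = select(TaskInstance).where(
            tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), keys)
        )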
510 """
511 return tuple_(*columns).in_(collection)


@overload
def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators: ...


@overload
def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Select,
    *,
    session: Session,
) -> ColumnOperators: ...


def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any] | Select,
    *,
    session: Session | None = None,
) -> ColumnOperators:
    """
    Generate a tuple-not-in-collection operator to use in ``.where()``.

    This is similar to ``tuple_in_condition`` except it generates ``NOT IN``.

    :meta private:
    """
    return tuple_(*columns).not_in(collection)


def get_orm_mapper():
    """Get the correct ORM mapper for the installed SQLAlchemy version."""
    import sqlalchemy.orm.mapper

    return sqlalchemy.orm.mapper if is_sqlalchemy_v1() else sqlalchemy.orm.Mapper


def is_sqlalchemy_v1() -> bool:
    return version.parse(metadata.version("sqlalchemy")).major == 1