#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import contextlib
import copy
import datetime
import logging
from collections.abc import Generator
from importlib import metadata
from typing import TYPE_CHECKING, Any

from packaging import version
from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import JSON, Text, TypeDecorator

from airflow._shared.timezones.timezone import make_naive, utc
from airflow.configuration import conf
from airflow.serialization.enums import Encoding

if TYPE_CHECKING:
    from collections.abc import Iterable

    from kubernetes.client.models.v1_pod import V1Pod
    from sqlalchemy.exc import OperationalError
    from sqlalchemy.orm import Session
    from sqlalchemy.sql import Select
    from sqlalchemy.sql.elements import ColumnElement
    from sqlalchemy.types import TypeEngine

    from airflow.typing_compat import Self


log = logging.getLogger(__name__)

try:
    from sqlalchemy.orm import mapped_column
except ImportError:
    # fallback for SQLAlchemy < 2.0
    def mapped_column(*args, **kwargs):  # type: ignore[misc]
        from sqlalchemy import Column

        return Column(*args, **kwargs)


def get_dialect_name(session: Session) -> str | None:
    """Safely get the name of the dialect associated with the given session."""
    if (bind := session.get_bind()) is None:
        raise ValueError("No bind/engine is associated with the provided Session")
    return getattr(bind.dialect, "name", None)


class UtcDateTime(TypeDecorator):
71 """
72 Similar to :class:`~sqlalchemy.types.TIMESTAMP` with ``timezone=True`` option, with some differences.
73
74 - Never silently take naive :class:`~datetime.datetime`, instead it
75 always raise :exc:`ValueError` unless time zone aware value.
76 - :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
77 is always converted to UTC.
78 - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`,
79 it never return naive :class:`~datetime.datetime`, but time zone
80 aware value, even with SQLite or MySQL.
81 - Always returns TIMESTAMP in UTC.
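
    A minimal usage sketch (the column name is illustrative):

    .. code:: python

        from sqlalchemy import Column

        start_date = Column(UtcDateTime())

        # Binding a naive datetime raises ValueError; aware values are
        # normalized to UTC before being sent to the database.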
82 """
83
84 impl = TIMESTAMP(timezone=True)
85
86 cache_ok = True
87
88 def process_bind_param(self, value, dialect):
89 if not isinstance(value, datetime.datetime):
90 if value is None:
91 return None
92 raise TypeError(f"expected datetime.datetime, not {value!r}")
93 if value.tzinfo is None:
94 raise ValueError("naive datetime is disallowed")
95 if dialect.name == "mysql":
96 # For mysql versions prior 8.0.19 we should send timestamps as naive values in UTC
97 # see: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-literals.html
98 return make_naive(value, timezone=utc)
99 return value.astimezone(utc)
100
101 def process_result_value(self, value, dialect):
102 """
103 Process DateTimes from the DB making sure to always return UTC.
104
105 Not using timezone.convert_to_utc as that converts to configured TIMEZONE
106 while the DB might be running with some other setting. We assume UTC
107 datetimes in the database.
108 """
109 if value is not None:
110 if value.tzinfo is None:
111 value = value.replace(tzinfo=utc)
112 else:
113 value = value.astimezone(utc)
114
115 return value
116
117 def load_dialect_impl(self, dialect):
118 if dialect.name == "mysql":
119 return mysql.TIMESTAMP(fsp=6)
120 return super().load_dialect_impl(dialect)
121
122
123class ExtendedJSON(TypeDecorator):
124 """
125 A version of the JSON column that uses the Airflow extended JSON serialization.
126
127 See airflow.serialization.
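
    A minimal declaration sketch (the attribute name is illustrative):

    .. code:: python

        from sqlalchemy import Column

        data = Column(ExtendedJSON)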
128 """
129
130 impl = Text
131
132 cache_ok = True
133
134 should_evaluate_none = True
135
136 def load_dialect_impl(self, dialect) -> TypeEngine:
137 if dialect.name == "postgresql":
138 return dialect.type_descriptor(JSONB)
139 return dialect.type_descriptor(JSON)
140
141 def process_bind_param(self, value, dialect):
142 from airflow.serialization.serialized_objects import BaseSerialization
143
144 if value is None:
145 return None
146
147 return BaseSerialization.serialize(value)
148
149 def process_result_value(self, value, dialect):
150 from airflow.serialization.serialized_objects import BaseSerialization
151
152 if value is None:
153 return None
154
155 return BaseSerialization.deserialize(value)
156
157
158def sanitize_for_serialization(obj: V1Pod):
159 """
160 Convert pod to dict.... but *safely*.
161
162 When pod objects created with one k8s version are unpickled in a python
163 env with a more recent k8s version (in which the object attrs may have
164 changed) the unpickled obj may throw an error because the attr
165 expected on new obj may not be there on the unpickled obj.
166
167 This function still converts the pod to a dict; the only difference is
168 it populates missing attrs with None. You may compare with
169 https://github.com/kubernetes-client/python/blob/5a96bbcbe21a552cc1f9cda13e0522fafb0dbac8/kubernetes/client/api_client.py#L202
170
171 If obj is None, return None.
172 If obj is str, int, long, float, bool, return directly.
173 If obj is datetime.datetime, datetime.date
174 convert to string in iso8601 format.
175 If obj is list, sanitize each element in the list.
176 If obj is dict, return the dict.
177 If obj is OpenAPI model, return the properties dict.
178
179 :param obj: The data to serialize.
180 :return: The serialized form of data.
181
182 :meta private:
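
    Example (hypothetical input):

    .. code:: python

        sanitize_for_serialization(
            {"when": datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)}
        )
        # -> {"when": "2024-01-01T00:00:00+00:00"}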
183 """
184 if obj is None:
185 return None
186 if isinstance(obj, (float, bool, bytes, str, int)):
187 return obj
188 if isinstance(obj, list):
189 return [sanitize_for_serialization(sub_obj) for sub_obj in obj]
190 if isinstance(obj, tuple):
191 return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj)
192 if isinstance(obj, (datetime.datetime, datetime.date)):
193 return obj.isoformat()
194
195 if isinstance(obj, dict):
196 obj_dict = obj
197 else:
198 obj_dict = {
199 obj.attribute_map[attr]: getattr(obj, attr)
200 for attr, _ in obj.openapi_types.items()
201 # below is the only line we change, and we just add default=None for getattr
202 if getattr(obj, attr, None) is not None
203 }
204
205 return {key: sanitize_for_serialization(val) for key, val in obj_dict.items()}


def ensure_pod_is_valid_after_unpickling(pod: V1Pod) -> V1Pod | None:
    """
    Convert pod to JSON and back so that the pod is safe to use.

    The pod_override in executor_config is a V1Pod object.
    Such objects, created with one k8s version and unpickled in
    an env with an upgraded k8s version, may blow up when
    ``to_dict`` is called, because the OpenAPI client codegen calls
    getattr on all attrs in openapi_types for each object, and when
    new attrs are added to that list, getattr will fail.

    Here we re-serialize the pod to ensure it is not going to blow up.

    :meta private:
    """
    try:
        # if to_dict works, the pod is fine
        pod.to_dict()
        return pod
    except AttributeError:
        pass
    try:
        from kubernetes.client.models.v1_pod import V1Pod
    except ImportError:
        return None
    if not isinstance(pod, V1Pod):
        return None
    try:
        from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator

        # now we actually reserialize / deserialize the pod
        pod_dict = sanitize_for_serialization(pod)
        return PodGenerator.deserialize_model_dict(pod_dict)
    except Exception:
        return None


class ExecutorConfigType(PickleType):
246 """
247 Adds special handling for K8s executor config.
248
249 If we unpickle a k8s object that was pickled under an earlier k8s library version, then
250 the unpickled object may throw an error when to_dict is called. To be more tolerant of
251 version changes we convert to JSON using Airflow's serializer before pickling.
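
    A column using this type is declared like any other ``PickleType`` column
    (sketch; the attribute name is illustrative):

    .. code:: python

        from sqlalchemy import Column

        executor_config = Column(ExecutorConfigType())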
252 """
253
254 cache_ok = True
255
256 def bind_processor(self, dialect):
257 from airflow.serialization.serialized_objects import BaseSerialization
258
259 super_process = super().bind_processor(dialect)
260
261 def process(value):
262 val_copy = copy.copy(value)
263 if isinstance(val_copy, dict) and "pod_override" in val_copy:
264 val_copy["pod_override"] = BaseSerialization.serialize(val_copy["pod_override"])
265 return super_process(val_copy)
266
267 return process
268
269 def result_processor(self, dialect, coltype):
270 from airflow.serialization.serialized_objects import BaseSerialization
271
272 super_process = super().result_processor(dialect, coltype)
273
274 def process(value):
275 value = super_process(value) # unpickle
276
277 if isinstance(value, dict) and "pod_override" in value:
278 pod_override = value["pod_override"]
279
280 if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
281 # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
282 value["pod_override"] = BaseSerialization.deserialize(pod_override)
283 else:
284 # backcompat path
285 # we no longer pickle raw pods but this code may be reached
286 # when accessing executor configs created in a prior version
287 new_pod = ensure_pod_is_valid_after_unpickling(pod_override)
288 if new_pod:
289 value["pod_override"] = new_pod
290 return value
291
292 return process
293
294 def compare_values(self, x, y):
295 """
296 Compare x and y using self.comparator if available. Else, use __eq__.
297
298 The TaskInstance.executor_config attribute is a pickled object that may contain kubernetes objects.
299
300 If the installed library version has changed since the object was originally pickled,
301 due to the underlying ``__eq__`` method on these objects (which converts them to JSON),
302 we may encounter attribute errors. In this case we should replace the stored object.
303
304 From https://github.com/apache/airflow/pull/24356 we use our serializer to store
305 k8s objects, but there could still be raw pickled k8s objects in the database,
306 stored from earlier version, so we still compare them defensively here.
307 """
308 if self.comparator:
309 return self.comparator(x, y)
310 try:
311 return x == y
312 except AttributeError:
313 return False
314
315
316def nulls_first(col: ColumnElement, session: Session) -> ColumnElement:
317 """
318 Specify *NULLS FIRST* to the column ordering.
319
320 This is only done to Postgres, currently the only backend that supports it.
321 Other databases do not need it since NULL values are considered lower than
322 any other values, and appear first when the order is ASC (ascending).
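
    A minimal sketch (the model and column are illustrative):

    .. code:: python

        from sqlalchemy import select

        query = select(TaskInstance).order_by(nulls_first(TaskInstance.priority_weight, session))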
323 """
324 if get_dialect_name(session) == "postgresql":
325 return nullsfirst(col)
326 return col
327
328
329USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)
330
331
332def with_row_locks(
333 query: Select[Any],
334 session: Session,
335 *,
336 nowait: bool = False,
337 skip_locked: bool = False,
338 key_share: bool = True,
339 **kwargs,
340) -> Select[Any]:
341 """
342 Apply with_for_update to the SQLAlchemy query if row level locking is in use.
343
344 This wrapper is needed so we don't use the syntax on unsupported database
345 engines. In particular, MySQL (prior to 8.0) and MariaDB do not support
346 row locking, where we do not support nor recommend running HA scheduler. If
347 a user ignores this and tries anyway, everything will still work, just
348 slightly slower in some circumstances.
349
350 See https://jira.mariadb.org/browse/MDEV-13115
351
352 :param query: An SQLAlchemy Query object
353 :param session: ORM Session
354 :param nowait: If set to True, will pass NOWAIT to supported database backends.
355 :param skip_locked: If set to True, will pass SKIP LOCKED to supported database backends.
356 :param key_share: If true, will lock with FOR KEY SHARE UPDATE (at least on postgres).
357 :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
358 :return: updated query
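
    Example (sketch; the model and filter are illustrative):

    .. code:: python

        from sqlalchemy import select

        query = select(DagRun).where(DagRun.state == "running")
        runs = session.scalars(with_row_locks(query, session, skip_locked=True)).all()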
359 """
360 try:
361 dialect_name = get_dialect_name(session)
362 except ValueError:
363 return query
364 if not dialect_name:
365 return query
366
367 # Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8) does not support it.
368 if not USE_ROW_LEVEL_LOCKING:
369 return query
370 if dialect_name == "mysql" and not getattr(
371 session.bind.dialect if session.bind else None, "supports_for_update_of", False
372 ):
373 return query
374 if nowait:
375 kwargs["nowait"] = True
376 if skip_locked:
377 kwargs["skip_locked"] = True
378 if key_share:
379 kwargs["key_share"] = True
380 return query.with_for_update(**kwargs)
381
382
383@contextlib.contextmanager
384def lock_rows(query: Select, session: Session) -> Generator[None, None, None]:
385 """
386 Lock database rows during the context manager block.
387
388 This is a convenient method for ``with_row_locks`` when we don't need the
389 locked rows.
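
    Usage sketch (the query is illustrative; the locks are held until the
    enclosing transaction ends):

    .. code:: python

        with lock_rows(select(DagModel).where(DagModel.dag_id == dag_id), session):
            ...  # work while the rows are locked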

    :meta private:
    """
    # A Select is lazy: the rows are only locked once the statement is actually
    # executed, so run it here; we don't need the returned rows themselves.
    session.execute(with_row_locks(query, session))
    yield


class CommitProhibitorGuard:
    """Context manager class that powers :func:`prohibit_commit`."""

    expected_commit = False

    def __init__(self, session: Session):
        self.session = session

    def _validate_commit(self, _):
        if self.expected_commit:
            self.expected_commit = False
            return
        raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!")

    def __enter__(self) -> Self:
        event.listen(self.session, "before_commit", self._validate_commit)
        return self

    def __exit__(self, *exc_info):
        event.remove(self.session, "before_commit", self._validate_commit)

    def commit(self):
        """
        Commit the session.

        This is the required way to commit when the guard is in scope.
        """
        self.expected_commit = True
        self.session.commit()


def prohibit_commit(session):
    """
    Return a context manager that disallows any commit not made via the context manager itself.

    The aim of this is to ensure that the transaction lifetime is strictly controlled, which is
    especially important in the core scheduler loop. Any commit on the session that is _not_ made
    via this context manager will result in a ``RuntimeError``.

    Example usage:

    .. code:: python

        with prohibit_commit(session) as guard:
            # ... do something with session
            guard.commit()

            # This would throw an error
            # session.commit()
    """
    return CommitProhibitorGuard(session)


def is_lock_not_available_error(error: OperationalError) -> bool:
    """
    Check if the error is about not being able to acquire a lock.
    # DB-specific error codes:
    # Postgres: 55P03
    # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT
    #   is set.'
    # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction'
    #   (when NOWAIT isn't available)
    db_err_code = getattr(error.orig, "pgcode", None) or (
        error.orig.args[0] if error.orig and error.orig.args else None
    )

    # We could test whether error.orig is an instance of
    # psycopg2.errors.LockNotAvailable / _mysql_exceptions.OperationalError, but that
    # would require importing those drivers; checking the error code doesn't.
    if db_err_code in ("55P03", 1205, 3572):
        return True
    return False


def get_orm_mapper():
    """Get the correct ORM mapper for the installed SQLAlchemy version."""
    import sqlalchemy.orm.mapper

    return sqlalchemy.orm.mapper if is_sqlalchemy_v1() else sqlalchemy.orm.Mapper


def is_sqlalchemy_v1() -> bool:
    return version.parse(metadata.version("sqlalchemy")).major == 1


def make_dialect_kwarg(dialect: str) -> dict[str, str | Iterable[str]]:
    """
    Create an SQLAlchemy-version-aware dialect keyword argument.
    return {"dialect_name": dialect} if is_sqlalchemy_v1() else {"dialect_names": (dialect,)}