#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import contextlib
import copy
import datetime
import logging
from collections.abc import Generator
from typing import TYPE_CHECKING

from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import JSON, Text, TypeDecorator

from airflow._shared.timezones.timezone import make_naive, utc
from airflow.configuration import conf
from airflow.serialization.enums import Encoding

if TYPE_CHECKING:
    from collections.abc import Iterable

    from kubernetes.client.models.v1_pod import V1Pod
    from sqlalchemy.exc import OperationalError
    from sqlalchemy.orm import Session
    from sqlalchemy.sql import Select
    from sqlalchemy.sql.elements import ColumnElement
    from sqlalchemy.types import TypeEngine

    from airflow.typing_compat import Self


log = logging.getLogger(__name__)


def get_dialect_name(session: Session) -> str | None:
    """Safely get the name of the dialect associated with the given session."""
    if (bind := session.get_bind()) is None:
        raise ValueError("No bind/engine is associated with the provided Session")
    return getattr(bind.dialect, "name", None)


class UtcDateTime(TypeDecorator):
    """
    Similar to :class:`~sqlalchemy.types.TIMESTAMP` with the ``timezone=True`` option, with some differences.

    - Never silently accepts a naive :class:`~datetime.datetime`; it always
      raises :exc:`ValueError` unless the value is time zone aware.
    - A :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
      is always converted to UTC.
    - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`,
      it never returns a naive :class:`~datetime.datetime`; it always returns
      a time zone aware value, even with SQLite or MySQL.
    - Always returns TIMESTAMP in UTC.
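
    A minimal usage sketch; the model and column names below are illustrative,
    not taken from this module:

    .. code:: python

        from sqlalchemy import Column, Integer
        from sqlalchemy.orm import declarative_base

        Base = declarative_base()


        class ExampleEvent(Base):
            __tablename__ = "example_event"

            id = Column(Integer, primary_key=True)
            # Stored as a timezone-aware TIMESTAMP; binding a naive datetime raises ValueError.
            created_at = Column(UtcDateTime)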
71 """
72
73 impl = TIMESTAMP(timezone=True)
74
75 cache_ok = True
76
77 def process_bind_param(self, value, dialect):
78 if not isinstance(value, datetime.datetime):
79 if value is None:
80 return None
81 raise TypeError(f"expected datetime.datetime, not {value!r}")
82 if value.tzinfo is None:
83 raise ValueError("naive datetime is disallowed")
84 if dialect.name == "mysql":
85 # For mysql versions prior 8.0.19 we should send timestamps as naive values in UTC
86 # see: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-literals.html
87 return make_naive(value, timezone=utc)
88 return value.astimezone(utc)
89
90 def process_result_value(self, value, dialect):
91 """
92 Process DateTimes from the DB making sure to always return UTC.
93
94 Not using timezone.convert_to_utc as that converts to configured TIMEZONE
95 while the DB might be running with some other setting. We assume UTC
96 datetimes in the database.
97 """
98 if value is not None:
99 if value.tzinfo is None:
100 value = value.replace(tzinfo=utc)
101 else:
102 value = value.astimezone(utc)
103
104 return value
105
106 def load_dialect_impl(self, dialect):
107 if dialect.name == "mysql":
108 return mysql.TIMESTAMP(fsp=6)
109 return super().load_dialect_impl(dialect)


class ExtendedJSON(TypeDecorator):
    """
    A version of the JSON column that uses the Airflow extended JSON serialization.

    See airflow.serialization.
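
    A minimal declaration sketch; the table and column names below are illustrative only:

    .. code:: python

        from sqlalchemy import Column, Integer, MetaData, Table

        metadata = MetaData()
        example = Table(
            "example_extended_json",
            metadata,
            Column("id", Integer, primary_key=True),
            # Stored as JSONB on Postgres and JSON elsewhere; values are run through
            # Airflow's BaseSerialization on bind and result.
            Column("data", ExtendedJSON),
        )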
117 """
118
119 impl = Text
120
121 cache_ok = True
122
123 should_evaluate_none = True
124
125 def load_dialect_impl(self, dialect) -> TypeEngine:
126 if dialect.name == "postgresql":
127 return dialect.type_descriptor(JSONB)
128 return dialect.type_descriptor(JSON)
129
130 def process_bind_param(self, value, dialect):
131 from airflow.serialization.serialized_objects import BaseSerialization
132
133 if value is None:
134 return None
135
136 return BaseSerialization.serialize(value)
137
138 def process_result_value(self, value, dialect):
139 from airflow.serialization.serialized_objects import BaseSerialization
140
141 if value is None:
142 return None
143
144 return BaseSerialization.deserialize(value)
145
146
def sanitize_for_serialization(obj: V1Pod):
    """
    Convert pod to dict.... but *safely*.

    When pod objects created with one k8s version are unpickled in a python
    env with a more recent k8s version (in which the object attrs may have
    changed), the unpickled obj may throw an error because an attr
    expected on the new obj may not be there on the unpickled obj.

    This function still converts the pod to a dict; the only difference is
    that it populates missing attrs with None. You may compare with
    https://github.com/kubernetes-client/python/blob/5a96bbcbe21a552cc1f9cda13e0522fafb0dbac8/kubernetes/client/api_client.py#L202

    If obj is None, return None.
    If obj is str, int, long, float, bool, return directly.
    If obj is datetime.datetime or datetime.date, convert to string in iso8601 format.
    If obj is list, sanitize each element in the list.
    If obj is dict, return the dict.
    If obj is OpenAPI model, return the properties dict.

    :param obj: The data to serialize.
    :return: The serialized form of data.
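
    A small illustration on plain (non-Kubernetes) values, assuming an interactive session:

    .. code:: python

        import datetime

        sanitize_for_serialization(None)  # -> None
        sanitize_for_serialization([1, "a", datetime.date(2021, 1, 1)])  # -> [1, "a", "2021-01-01"]
        sanitize_for_serialization({"flag": True})  # -> {"flag": True}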

    :meta private:
    """
    if obj is None:
        return None
    if isinstance(obj, (float, bool, bytes, str, int)):
        return obj
    if isinstance(obj, list):
        return [sanitize_for_serialization(sub_obj) for sub_obj in obj]
    if isinstance(obj, tuple):
        return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj)
    if isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()

    if isinstance(obj, dict):
        obj_dict = obj
    else:
        obj_dict = {
            obj.attribute_map[attr]: getattr(obj, attr)
            for attr, _ in obj.openapi_types.items()
            # below is the only line we change, and we just add default=None for getattr
            if getattr(obj, attr, None) is not None
        }

    return {key: sanitize_for_serialization(val) for key, val in obj_dict.items()}


def ensure_pod_is_valid_after_unpickling(pod: V1Pod) -> V1Pod | None:
    """
    Convert pod to json and back so that pod is safe.

    The pod_override in executor_config is a V1Pod object.
    Such objects created with one k8s version, when unpickled in
    an env with an upgraded k8s version, may blow up when
    `to_dict` is called, because openapi client code gen calls
    getattr on all attrs in openapi_types for each object, and when
    new attrs are added to that list, getattr will fail.

    Here we re-serialize it to ensure it is not going to blow up.

    :meta private:
    """
    try:
        # if to_dict works, the pod is fine
        pod.to_dict()
        return pod
    except AttributeError:
        pass
    try:
        from kubernetes.client.models.v1_pod import V1Pod
    except ImportError:
        return None
    if not isinstance(pod, V1Pod):
        return None
    try:
        from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator

        # now we actually reserialize / deserialize the pod
        pod_dict = sanitize_for_serialization(pod)
        return PodGenerator.deserialize_model_dict(pod_dict)
    except Exception:
        return None


class ExecutorConfigType(PickleType):
    """
    Adds special handling for K8s executor config.

    If we unpickle a k8s object that was pickled under an earlier k8s library version, then
    the unpickled object may throw an error when to_dict is called. To be more tolerant of
    version changes we convert to JSON using Airflow's serializer before pickling.
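
    A rough sketch of how a column of this type behaves; the column declaration below is
    illustrative, not a reference to Airflow's actual models:

    .. code:: python

        from sqlalchemy import Column

        executor_config = Column(ExecutorConfigType())

        # On the way in:  a dict like {"pod_override": <V1Pod>} has the pod converted with
        # BaseSerialization.serialize() before the dict is pickled.
        # On the way out: the dict is unpickled and "pod_override" is deserialized back into
        # a V1Pod, or repaired via ensure_pod_is_valid_after_unpickling for configs that were
        # pickled by older Airflow versions.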
241 """
242
243 cache_ok = True
244
245 def bind_processor(self, dialect):
246 from airflow.serialization.serialized_objects import BaseSerialization
247
248 super_process = super().bind_processor(dialect)
249
250 def process(value):
251 val_copy = copy.copy(value)
252 if isinstance(val_copy, dict) and "pod_override" in val_copy:
253 val_copy["pod_override"] = BaseSerialization.serialize(val_copy["pod_override"])
254 return super_process(val_copy)
255
256 return process
257
258 def result_processor(self, dialect, coltype):
259 from airflow.serialization.serialized_objects import BaseSerialization
260
261 super_process = super().result_processor(dialect, coltype)
262
263 def process(value):
264 value = super_process(value) # unpickle
265
266 if isinstance(value, dict) and "pod_override" in value:
267 pod_override = value["pod_override"]
268
269 if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
270 # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
271 value["pod_override"] = BaseSerialization.deserialize(pod_override)
272 else:
273 # backcompat path
274 # we no longer pickle raw pods but this code may be reached
275 # when accessing executor configs created in a prior version
276 new_pod = ensure_pod_is_valid_after_unpickling(pod_override)
277 if new_pod:
278 value["pod_override"] = new_pod
279 return value
280
281 return process
282
283 def compare_values(self, x, y):
284 """
285 Compare x and y using self.comparator if available. Else, use __eq__.
286
287 The TaskInstance.executor_config attribute is a pickled object that may contain kubernetes objects.
288
289 If the installed library version has changed since the object was originally pickled,
290 due to the underlying ``__eq__`` method on these objects (which converts them to JSON),
291 we may encounter attribute errors. In this case we should replace the stored object.
292
293 From https://github.com/apache/airflow/pull/24356 we use our serializer to store
294 k8s objects, but there could still be raw pickled k8s objects in the database,
295 stored from earlier version, so we still compare them defensively here.
296 """
297 if self.comparator:
298 return self.comparator(x, y)
299 try:
300 return x == y
301 except AttributeError:
302 return False
303
304
def nulls_first(col: ColumnElement, session: Session) -> ColumnElement:
    """
    Specify *NULLS FIRST* for the column ordering.

    This is only done for Postgres, currently the only backend that supports it.
    Other databases do not need it since NULL values are considered lower than
    any other values, and appear first when the order is ASC (ascending).
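
    A usage sketch (``MyModel`` and its column are illustrative names):

    .. code:: python

        from sqlalchemy import select

        stmt = select(MyModel).order_by(nulls_first(MyModel.priority_weight, session))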
312 """
313 if get_dialect_name(session) == "postgresql":
314 return nullsfirst(col)
315 return col
316
317
318USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)


def with_row_locks(
    query: Select,
    session: Session,
    *,
    nowait: bool = False,
    skip_locked: bool = False,
    key_share: bool = True,
    **kwargs,
) -> Select:
    """
    Apply with_for_update to the SQLAlchemy query if row level locking is in use.

    This wrapper is needed so we don't use the syntax on unsupported database
    engines. In particular, MySQL (prior to 8.0) and MariaDB do not support
    row locking, and on those we neither support nor recommend running an HA
    scheduler. If a user ignores this and tries anyway, everything will still
    work, just slightly slower in some circumstances.

    See https://jira.mariadb.org/browse/MDEV-13115

    :param query: An SQLAlchemy Query object
    :param session: ORM Session
    :param nowait: If set to True, will pass NOWAIT to supported database backends.
    :param skip_locked: If set to True, will pass SKIP LOCKED to supported database backends.
    :param key_share: If true, will lock with FOR KEY SHARE UPDATE (at least on postgres).
    :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
    :return: updated query
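
    A usage sketch (``MyModel`` and the filter are illustrative only):

    .. code:: python

        from sqlalchemy import select

        stmt = with_row_locks(
            select(MyModel).where(MyModel.state == "queued"),
            session=session,
            skip_locked=True,
        )
        rows = session.scalars(stmt).all()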
348 """
349 try:
350 dialect_name = get_dialect_name(session)
351 except ValueError:
352 return query
353 if not dialect_name:
354 return query
355
356 # Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8) does not support it.
357 if not USE_ROW_LEVEL_LOCKING:
358 return query
359 if dialect_name == "mysql" and not getattr(
360 session.bind.dialect if session.bind else None, "supports_for_update_of", False
361 ):
362 return query
363 if nowait:
364 kwargs["nowait"] = True
365 if skip_locked:
366 kwargs["skip_locked"] = True
367 if key_share:
368 kwargs["key_share"] = True
369 return query.with_for_update(**kwargs)
370
371
@contextlib.contextmanager
def lock_rows(query: Select, session: Session) -> Generator[None, None, None]:
    """
    Lock database rows during the context manager block.

    This is a convenience wrapper around ``with_row_locks`` for when we don't
    need the locked rows themselves.

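    A usage sketch (``MyModel`` is an illustrative name):

    .. code:: python

        from sqlalchemy import select

        with lock_rows(select(MyModel).where(MyModel.id == 1), session):
            # work with the table while the row-locking query is in scope
            ...
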
    :meta private:
    """
    locked_rows = with_row_locks(query, session)
    yield
    del locked_rows


class CommitProhibitorGuard:
    """Context manager class that powers prohibit_commit."""

    expected_commit = False

    def __init__(self, session: Session):
        self.session = session

    def _validate_commit(self, _):
        if self.expected_commit:
            self.expected_commit = False
            return
        raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!")

    def __enter__(self) -> Self:
        event.listen(self.session, "before_commit", self._validate_commit)
        return self

    def __exit__(self, *exc_info):
        event.remove(self.session, "before_commit", self._validate_commit)

    def commit(self):
        """
        Commit the session.

        This is the required way to commit when the guard is in scope.
        """
        self.expected_commit = True
        self.session.commit()


def prohibit_commit(session):
    """
    Return a context manager that will disallow any commit that isn't done via the context manager.

    The aim of this is to ensure that transaction lifetime is strictly controlled, which is especially
    important in the core scheduler loop. Any commit on the session that is _not_ via this context manager
    will result in a RuntimeError.

    Example usage:

    .. code:: python

        with prohibit_commit(session) as guard:
            # ... do something with session
            guard.commit()

            # This would throw an error
            # session.commit()
    """
    return CommitProhibitorGuard(session)


def is_lock_not_available_error(error: OperationalError):
    """Check if the Error is about not being able to acquire lock."""
    # DB specific error codes:
    # Postgres: 55P03
    # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT
    #   is set.'
    # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction'
    #   (when NOWAIT isn't available)
    db_err_code = getattr(error.orig, "pgcode", None) or (
        error.orig.args[0] if error.orig and error.orig.args else None
    )

    # We could test if error.orig is an instance of
    # psycopg2.errors.LockNotAvailable / _mysql_exceptions.OperationalError, but that would require
    # importing those drivers; checking the error code does not.
    if db_err_code in ("55P03", 1205, 3572):
        return True
    return False


def make_dialect_kwarg(dialect: str) -> dict[str, str | Iterable[str]]:
    """Create an SQLAlchemy-version-aware dialect keyword argument."""
    return {"dialect_names": (dialect,)}