#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import contextlib
import copy
import datetime
import json
import logging
from importlib import metadata
from typing import TYPE_CHECKING, Any, Generator, Iterable, overload

from dateutil import relativedelta
from packaging import version
from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst, tuple_
from sqlalchemy.dialects import mysql
from sqlalchemy.types import JSON, Text, TypeDecorator

from airflow.configuration import conf
from airflow.serialization.enums import Encoding
from airflow.utils.timezone import make_naive, utc

if TYPE_CHECKING:
    from kubernetes.client.models.v1_pod import V1Pod
    from sqlalchemy.exc import OperationalError
    from sqlalchemy.orm import Query, Session
    from sqlalchemy.sql import ColumnElement, Select
    from sqlalchemy.sql.expression import ColumnOperators
    from sqlalchemy.types import TypeEngine

log = logging.getLogger(__name__)


class UtcDateTime(TypeDecorator):
    """
    Similar to :class:`~sqlalchemy.types.TIMESTAMP` with the ``timezone=True`` option, with some differences.

    - Never silently takes a naive :class:`~datetime.datetime`; it always
      raises :exc:`ValueError` unless the value is time zone aware.
    - A :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
      is always converted to UTC.
    - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`,
      it never returns a naive :class:`~datetime.datetime`, but always a
      time zone aware value, even with SQLite or MySQL.
    - Always returns TIMESTAMP in UTC.
    """

    impl = TIMESTAMP(timezone=True)

    cache_ok = True

    def process_bind_param(self, value, dialect):
        if not isinstance(value, datetime.datetime):
            if value is None:
                return None
            raise TypeError(f"expected datetime.datetime, not {value!r}")
        elif value.tzinfo is None:
            raise ValueError("naive datetime is disallowed")
        elif dialect.name == "mysql":
            # For MySQL versions prior to 8.0.19 we should send timestamps as naive values in UTC.
            # See: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-literals.html
            return make_naive(value, timezone=utc)
        return value.astimezone(utc)

    def process_result_value(self, value, dialect):
        """
        Process DateTimes from the DB, making sure to always return UTC.

        Not using timezone.convert_to_utc as that converts to the configured TIMEZONE
        while the DB might be running with some other setting. We assume UTC
        datetimes in the database.
        """
        if value is not None:
            if value.tzinfo is None:
                value = value.replace(tzinfo=utc)
            else:
                value = value.astimezone(utc)

        return value

    def load_dialect_impl(self, dialect):
        if dialect.name == "mysql":
            return mysql.TIMESTAMP(fsp=6)
        return super().load_dialect_impl(dialect)
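
# Illustrative usage sketch (not part of this module); the model and column names
# below are assumptions for the example:
#
#     from sqlalchemy import Column, Integer
#     from sqlalchemy.orm import declarative_base
#
#     Base = declarative_base()
#
#     class Event(Base):
#         __tablename__ = "event"
#         id = Column(Integer, primary_key=True)
#         created_at = Column(UtcDateTime)
#
#     # Binding a naive datetime raises ValueError; an aware value is converted to
#     # UTC on the way in and always comes back time zone aware, even on SQLite/MySQL.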


class ExtendedJSON(TypeDecorator):
    """
    A version of the JSON column that uses the Airflow extended JSON serialization.

    See airflow.serialization.
    """

    impl = Text

    cache_ok = True

    def load_dialect_impl(self, dialect) -> TypeEngine:
        return dialect.type_descriptor(JSON)

    def process_bind_param(self, value, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        if value is None:
            return None

        return BaseSerialization.serialize(value)

    def process_result_value(self, value, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        if value is None:
            return None

        return BaseSerialization.deserialize(value)
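
# A minimal sketch of what this column type does on the way in and out of the
# database: values are round-tripped through Airflow's extended serializer, so
# types that plain JSON cannot represent (e.g. timedelta) survive:
#
#     from airflow.serialization.serialized_objects import BaseSerialization
#
#     stored = BaseSerialization.serialize({"retry_delay": datetime.timedelta(minutes=5)})
#     restored = BaseSerialization.deserialize(stored)
#     # restored == {"retry_delay": datetime.timedelta(minutes=5)}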


def sanitize_for_serialization(obj: V1Pod):
    """
    Convert pod to dict.... but *safely*.

    When pod objects created with one k8s version are unpickled in a Python
    env with a more recent k8s version (in which the object attrs may have
    changed), the unpickled obj may throw an error because an attr
    expected on the new obj may not be there on the unpickled obj.

    This function still converts the pod to a dict; the only difference is
    it populates missing attrs with None. You may compare with
    https://github.com/kubernetes-client/python/blob/5a96bbcbe21a552cc1f9cda13e0522fafb0dbac8/kubernetes/client/api_client.py#L202

    If obj is None, return None.
    If obj is str, int, float, bool or bytes, return directly.
    If obj is datetime.datetime or datetime.date,
        convert to string in iso8601 format.
    If obj is list or tuple, sanitize each element.
    If obj is dict, sanitize each value and return the dict.
    If obj is OpenAPI model, return the properties dict.

    :param obj: The data to serialize.
    :return: The serialized form of data.

    :meta private:
    """
    if obj is None:
        return None
    elif isinstance(obj, (float, bool, bytes, str, int)):
        return obj
    elif isinstance(obj, list):
        return [sanitize_for_serialization(sub_obj) for sub_obj in obj]
    elif isinstance(obj, tuple):
        return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj)
    elif isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()

    if isinstance(obj, dict):
        obj_dict = obj
    else:
        obj_dict = {
            obj.attribute_map[attr]: getattr(obj, attr)
            for attr, _ in obj.openapi_types.items()
            # below is the only line we change, and we just add default=None for getattr
            if getattr(obj, attr, None) is not None
        }

    return {key: sanitize_for_serialization(val) for key, val in obj_dict.items()}
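
# Hedged usage sketch (the pod below is an assumption for the example only):
#
#     from kubernetes.client import models as k8s
#
#     pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="example"))
#     sanitize_for_serialization(pod)
#     # -> {"metadata": {"name": "example"}}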


def ensure_pod_is_valid_after_unpickling(pod: V1Pod) -> V1Pod | None:
    """
    Convert pod to json and back so that pod is safe.

    The pod_override in executor_config is a V1Pod object.
    Such objects created with one k8s version, when unpickled in
    an env with an upgraded k8s version, may blow up when
    `to_dict` is called, because openapi client code gen calls
    getattr on all attrs in openapi_types for each object, and when
    new attrs are added to that list, getattr will fail.

    Here we re-serialize it to ensure it is not going to blow up.

    :meta private:
    """
    try:
        # if to_dict works, the pod is fine
        pod.to_dict()
        return pod
    except AttributeError:
        pass
    try:
        from kubernetes.client.models.v1_pod import V1Pod
    except ImportError:
        return None
    if not isinstance(pod, V1Pod):
        return None
    try:
        try:
            from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
        except ImportError:
            from airflow.kubernetes.pre_7_4_0_compatibility.pod_generator import (  # type: ignore[assignment]
                PodGenerator,
            )
        # now we actually reserialize / deserialize the pod
        pod_dict = sanitize_for_serialization(pod)
        return PodGenerator.deserialize_model_dict(pod_dict)
    except Exception:
        return None


class ExecutorConfigType(PickleType):
    """
    Adds special handling for K8s executor config.

    If we unpickle a k8s object that was pickled under an earlier k8s library version, then
    the unpickled object may throw an error when to_dict is called. To be more tolerant of
    version changes we convert to JSON using Airflow's serializer before pickling.
    """

    cache_ok = True

    def bind_processor(self, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().bind_processor(dialect)

        def process(value):
            val_copy = copy.copy(value)
            if isinstance(val_copy, dict) and "pod_override" in val_copy:
                val_copy["pod_override"] = BaseSerialization.serialize(val_copy["pod_override"])
            return super_process(val_copy)

        return process

    def result_processor(self, dialect, coltype):
        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().result_processor(dialect, coltype)

        def process(value):
            value = super_process(value)  # unpickle

            if isinstance(value, dict) and "pod_override" in value:
                pod_override = value["pod_override"]

                if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
                    # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
                    value["pod_override"] = BaseSerialization.deserialize(pod_override)
                else:
                    # Backcompat path: we no longer pickle raw pods, but this code may be
                    # reached when accessing executor configs created in a prior version.
                    new_pod = ensure_pod_is_valid_after_unpickling(pod_override)
                    if new_pod:
                        value["pod_override"] = new_pod
            return value

        return process

    def compare_values(self, x, y):
        """
        Compare x and y using self.comparator if available. Else, use __eq__.

        The TaskInstance.executor_config attribute is a pickled object that may contain kubernetes objects.

        If the installed library version has changed since the object was originally pickled,
        due to the underlying ``__eq__`` method on these objects (which converts them to JSON),
        we may encounter attribute errors. In this case we should replace the stored object.

        From https://github.com/apache/airflow/pull/24356 we use our serializer to store
        k8s objects, but there could still be raw pickled k8s objects in the database,
        stored from an earlier version, so we still compare them defensively here.
        """
        if self.comparator:
            return self.comparator(x, y)
        else:
            try:
                return x == y
            except AttributeError:
                return False
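
# Illustrative sketch of the value that flows through this type (the pod spec is an
# assumption for the example): a pod_override entry is serialized with Airflow's
# BaseSerialization before being pickled, and deserialized (or repaired, for legacy
# raw-pickled pods) when read back.
#
#     from kubernetes.client import models as k8s
#
#     executor_config = {
#         "pod_override": k8s.V1Pod(
#             spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base", image="python:3.11")])
#         )
#     }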


class Interval(TypeDecorator):
    """Base class representing a time interval."""

    impl = Text

    cache_ok = True

    attr_keys = {
        datetime.timedelta: ("days", "seconds", "microseconds"),
        relativedelta.relativedelta: (
            "years",
            "months",
            "days",
            "leapdays",
            "hours",
            "minutes",
            "seconds",
            "microseconds",
            "year",
            "month",
            "day",
            "hour",
            "minute",
            "second",
            "microsecond",
        ),
    }

    def process_bind_param(self, value, dialect):
        if isinstance(value, tuple(self.attr_keys)):
            attrs = {key: getattr(value, key) for key in self.attr_keys[type(value)]}
            return json.dumps({"type": type(value).__name__, "attrs": attrs})
        return json.dumps(value)

    def process_result_value(self, value, dialect):
        if not value:
            return value
        data = json.loads(value)
        if isinstance(data, dict):
            type_map = {key.__name__: key for key in self.attr_keys}
            return type_map[data["type"]](**data["attrs"])
        return data
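
# A minimal sketch of the serialized form this type stores (the value is an example):
#
#     Interval().process_bind_param(datetime.timedelta(days=1, seconds=30), dialect=None)
#     # -> '{"type": "timedelta", "attrs": {"days": 1, "seconds": 30, "microseconds": 0}}'
#
#     Interval().process_result_value(
#         '{"type": "timedelta", "attrs": {"days": 1, "seconds": 30, "microseconds": 0}}', dialect=None
#     )
#     # -> datetime.timedelta(days=1, seconds=30)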


def nulls_first(col, session: Session) -> ColumnElement:
    """Specify *NULLS FIRST* to the column ordering.

    This is only done for PostgreSQL, currently the only backend that supports it.
    Other databases do not need it since NULL values are considered lower than
    any other values, and appear first when the order is ASC (ascending).
    """
    if session.bind.dialect.name == "postgresql":
        return nullsfirst(col)
    else:
        return col
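
# Hedged usage sketch (the query and column are assumptions): apply NULLS FIRST only
# where the backend understands it.
#
#     tis = session.query(TaskInstance).order_by(
#         nulls_first(TaskInstance.queued_dttm, session=session)
#     )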


USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)


def with_row_locks(
    query: Query,
    session: Session,
    *,
    nowait: bool = False,
    skip_locked: bool = False,
    **kwargs,
) -> Query:
    """
    Apply with_for_update to the SQLAlchemy query if row level locking is in use.

    This wrapper is needed so we don't use the syntax on unsupported database
    engines. In particular, MySQL (prior to 8.0) and MariaDB do not support
    row locking, and on those backends we neither support nor recommend running the
    HA scheduler. If a user ignores this and tries anyway, everything will still
    work, just slightly slower in some circumstances.

    See https://jira.mariadb.org/browse/MDEV-13115

    :param query: An SQLAlchemy Query object
    :param session: ORM Session
    :param nowait: If set to True, will pass NOWAIT to supported database backends.
    :param skip_locked: If set to True, will pass SKIP LOCKED to supported database backends.
    :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
    :return: updated query
    """
    dialect = session.bind.dialect

    # Don't use row-level locks if the MySQL dialect (MariaDB & MySQL < 8) does not support them.
    if not USE_ROW_LEVEL_LOCKING:
        return query
    if dialect.name == "mysql" and not dialect.supports_for_update_of:
        return query
    if nowait:
        kwargs["nowait"] = True
    if skip_locked:
        kwargs["skip_locked"] = True
    return query.with_for_update(**kwargs)
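
# A sketch of typical use (the query is an assumption): take SKIP LOCKED row locks
# where the backend supports them, and silently no-op where it does not.
#
#     query = session.query(DagRun).filter(DagRun.state == "running")
#     runs = with_row_locks(query, session=session, skip_locked=True).all()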


@contextlib.contextmanager
def lock_rows(query: Query, session: Session) -> Generator[None, None, None]:
    """Lock database rows during the context manager block.

    This is a convenience method for ``with_row_locks`` when we don't need the
    locked rows.

    :meta private:
    """
    locked_rows = with_row_locks(query, session)
    yield
    del locked_rows


class CommitProhibitorGuard:
    """Context manager class that powers prohibit_commit."""

    expected_commit = False

    def __init__(self, session: Session):
        self.session = session

    def _validate_commit(self, _):
        if self.expected_commit:
            self.expected_commit = False
            return
        raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!")

    def __enter__(self):
        event.listen(self.session, "before_commit", self._validate_commit)
        return self

    def __exit__(self, *exc_info):
        event.remove(self.session, "before_commit", self._validate_commit)

    def commit(self):
        """
        Commit the session.

        This is the required way to commit when the guard is in scope.
        """
        self.expected_commit = True
        self.session.commit()


def prohibit_commit(session):
    """
    Return a context manager that will disallow any commit that isn't done via the context manager.

    The aim of this is to ensure that transaction lifetime is strictly controlled, which is especially
    important in the core scheduler loop. Any commit on the session that is _not_ via this context manager
    will result in a ``RuntimeError``.

    Example usage:

    .. code:: python

        with prohibit_commit(session) as guard:
            # ... do something with session
            guard.commit()

            # This would throw an error
            # session.commit()
    """
    return CommitProhibitorGuard(session)


def is_lock_not_available_error(error: OperationalError):
    """Check if the error is about not being able to acquire a lock."""
    # DB-specific error codes:
    # Postgres: 55P03
    # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT
    #   is set.'
    # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction'
    #   (when NOWAIT isn't available)
    db_err_code = getattr(error.orig, "pgcode", None) or error.orig.args[0]

    # We could test if error.orig is an instance of
    # psycopg2.errors.LockNotAvailable / _mysql_exceptions.OperationalError, but that involves
    # importing them. This check doesn't require the import.
    if db_err_code in ("55P03", 1205, 3572):
        return True
    return False
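
# Hedged usage sketch: callers typically pair this with a NOWAIT/SKIP LOCKED query and
# swallow only the "lock not available" failures.
#
#     from sqlalchemy.exc import OperationalError
#
#     try:
#         rows = with_row_locks(query, session=session, nowait=True).all()
#     except OperationalError as e:
#         if not is_lock_not_available_error(e):
#             raise
#         session.rollback()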


@overload
def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators: ...


@overload
def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Select,
    *,
    session: Session,
) -> ColumnOperators: ...


def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any] | Select,
    *,
    session: Session | None = None,
) -> ColumnOperators:
    """
    Generate a tuple-in-collection operator to use in ``.where()``.

    For most SQL backends, this generates a simple ``([col, ...]) IN [condition]``
    clause.

    :meta private:
    """
    return tuple_(*columns).in_(collection)
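
# Minimal sketch (the model and key values are assumptions): filter on a composite
# key with a single tuple IN clause.
#
#     keys = [("dag_a", "task_1"), ("dag_a", "task_2")]
#     condition = tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), keys)
#     tis = session.query(TaskInstance).filter(condition).all()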


@overload
def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators: ...


@overload
def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Select,
    *,
    session: Session,
) -> ColumnOperators: ...


def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any] | Select,
    *,
    session: Session | None = None,
) -> ColumnOperators:
    """
    Generate a tuple-not-in-collection operator to use in ``.where()``.

    This is similar to ``tuple_in_condition`` except generating ``NOT IN``.

    :meta private:
    """
    return tuple_(*columns).not_in(collection)


def get_orm_mapper():
    """Get the correct ORM mapper for the installed SQLAlchemy version."""
    import sqlalchemy.orm.mapper

    return sqlalchemy.orm.mapper if is_sqlalchemy_v1() else sqlalchemy.orm.Mapper


def is_sqlalchemy_v1() -> bool:
    """Return True if the installed SQLAlchemy version is 1.x."""
    return version.parse(metadata.version("sqlalchemy")).major == 1