Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/serialization/serialized_objects.py: 24%
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Serialized DAG and BaseOperator."""
19from __future__ import annotations
21import collections.abc
22import datetime
23import enum
24import inspect
25import logging
26import warnings
27import weakref
28from dataclasses import dataclass
29from inspect import signature
30from textwrap import dedent
31from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, NamedTuple, Union
33import attrs
34import lazy_object_proxy
35from dateutil import relativedelta
36from pendulum.tz.timezone import FixedTimezone, Timezone
38from airflow.compat.functools import cache
39from airflow.configuration import conf
40from airflow.datasets import Dataset, DatasetAll, DatasetAny
41from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError, TaskDeferred
42from airflow.jobs.job import Job
43from airflow.models.baseoperator import BaseOperator
44from airflow.models.connection import Connection
45from airflow.models.dag import DAG, DagModel, create_timetable
46from airflow.models.dagrun import DagRun
47from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key
48from airflow.models.mappedoperator import MappedOperator
49from airflow.models.param import Param, ParamsDict
50from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
51from airflow.models.tasklog import LogTemplate
52from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
53from airflow.providers_manager import ProvidersManager
54from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
55from airflow.serialization.helpers import serialize_template_field
56from airflow.serialization.json_schema import load_dag_schema
57from airflow.serialization.pydantic.dag import DagModelPydantic
58from airflow.serialization.pydantic.dag_run import DagRunPydantic
59from airflow.serialization.pydantic.dataset import DatasetPydantic
60from airflow.serialization.pydantic.job import JobPydantic
61from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
62from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
63from airflow.settings import _ENABLE_AIP_44, DAGS_FOLDER, json
64from airflow.task.priority_strategy import (
65 PriorityWeightStrategy,
66 airflow_priority_weight_strategies,
67 airflow_priority_weight_strategies_classes,
68)
69from airflow.triggers.base import BaseTrigger
70from airflow.utils.code_utils import get_python_source
71from airflow.utils.context import Context, OutletEventAccessor, OutletEventAccessors
72from airflow.utils.docs import get_docs_url
73from airflow.utils.helpers import exactly_one
74from airflow.utils.module_loading import import_string, qualname
75from airflow.utils.operator_resources import Resources
76from airflow.utils.task_group import MappedTaskGroup, TaskGroup
77from airflow.utils.timezone import from_timestamp, parse_timezone
78from airflow.utils.types import NOTSET, ArgNotSet
80if TYPE_CHECKING:
81 from inspect import Parameter
83 from airflow.models.baseoperatorlink import BaseOperatorLink
84 from airflow.models.expandinput import ExpandInput
85 from airflow.models.operator import Operator
86 from airflow.models.taskmixin import DAGNode
87 from airflow.serialization.json_schema import Validator
88 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
89 from airflow.timetables.base import Timetable
90 from airflow.utils.pydantic import BaseModel
92 HAS_KUBERNETES: bool
93 try:
94 from kubernetes.client import models as k8s # noqa: TCH004
96 from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TCH004
97 except ImportError:
98 pass
100log = logging.getLogger(__name__)
102_OPERATOR_EXTRA_LINKS: set[str] = {
103 "airflow.operators.trigger_dagrun.TriggerDagRunLink",
104 "airflow.sensors.external_task.ExternalDagLink",
105 # Deprecated names, so that existing serialized dags load straight away.
106 "airflow.sensors.external_task.ExternalTaskSensorLink",
107 "airflow.operators.dagrun_operator.TriggerDagRunLink",
108 "airflow.sensors.external_task_sensor.ExternalTaskSensorLink",
109}
112@cache
113def get_operator_extra_links() -> set[str]:
114 """
115 Get the operator extra links.
117 This includes both the built-in ones and those that come from providers.
118 """
119 _OPERATOR_EXTRA_LINKS.update(ProvidersManager().extra_links_class_names)
120 return _OPERATOR_EXTRA_LINKS
123@cache
124def _get_default_mapped_partial() -> dict[str, Any]:
125 """
126 Get default partial kwargs in a mapped operator.
128 This is used to simplify a serialized mapped operator by excluding default
129 values supplied in the implementation from the serialized dict. Since those
130 are defaults, they are automatically supplied on de-serialization, so we
131 don't need to store them.
132 """
133 # Use the private _expand() method to avoid the empty kwargs check.
134 default = BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
135 return BaseSerialization.serialize(default)[Encoding.VAR]
138def encode_relativedelta(var: relativedelta.relativedelta) -> dict[str, Any]:
139 """Encode a relativedelta object."""
140 encoded = {k: v for k, v in var.__dict__.items() if not k.startswith("_") and v}
141 if var.weekday and var.weekday.n:
142 # Every n'th Friday for example
143 encoded["weekday"] = [var.weekday.weekday, var.weekday.n]
144 elif var.weekday:
145 encoded["weekday"] = [var.weekday.weekday]
146 return encoded
149def decode_relativedelta(var: dict[str, Any]) -> relativedelta.relativedelta:
150 """Dencode a relativedelta object."""
151 if "weekday" in var:
152 var["weekday"] = relativedelta.weekday(*var["weekday"]) # type: ignore
153 return relativedelta.relativedelta(**var)
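# Illustrative round trip (a minimal sketch; the exact keys depend on which
# relativedelta fields are non-zero, so the values below are an assumption):
#
#     >>> from dateutil import relativedelta
#     >>> rd = relativedelta.relativedelta(months=1, weekday=relativedelta.FR(2))
#     >>> encode_relativedelta(rd)
#     {'months': 1, 'weekday': [4, 2]}
#     >>> decode_relativedelta({'months': 1, 'weekday': [4, 2]}) == rd
#     True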
156def encode_timezone(var: Timezone | FixedTimezone) -> str | int:
157 """
158 Encode a Pendulum Timezone for serialization.
160 Airflow only supports timezone objects that implement Pendulum's Timezone
161 interface. We try to keep as much information as possible to make conversion
162 round-tripping possible (see ``decode_timezone``). We need to special-case
163 UTC; Pendulum implements it as a FixedTimezone (i.e. it gets encoded as
164 0 without the special case), but passing 0 into ``pendulum.timezone`` does
165 not give us UTC (but ``+00:00``).
166 """
167 if isinstance(var, FixedTimezone):
168 if var.offset == 0:
169 return "UTC"
170 return var.offset
171 if isinstance(var, Timezone):
172 return var.name
173 raise ValueError(
174 f"DAG timezone should be a pendulum.tz.Timezone, not {var!r}. "
175 f"See {get_docs_url('timezone.html#time-zone-aware-dags')}"
176 )
179def decode_timezone(var: str | int) -> Timezone | FixedTimezone:
180 """Decode a previously serialized Pendulum Timezone."""
181 return parse_timezone(var)
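# Illustrative behaviour (a minimal sketch of the encode/decode pair above):
#
#     >>> encode_timezone(parse_timezone("Europe/Paris"))
#     'Europe/Paris'
#     >>> encode_timezone(FixedTimezone(0))       # UTC is special-cased to a name
#     'UTC'
#     >>> encode_timezone(FixedTimezone(19800))   # other fixed offsets stay numeric (seconds)
#     19800
#     >>> decode_timezone("Europe/Paris").name
#     'Europe/Paris'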
184def _get_registered_timetable(importable_string: str) -> type[Timetable] | None:
185 from airflow import plugins_manager
187 if importable_string.startswith("airflow.timetables."):
188 return import_string(importable_string)
189 plugins_manager.initialize_timetables_plugins()
190 if plugins_manager.timetable_classes:
191 return plugins_manager.timetable_classes.get(importable_string)
192 else:
193 return None
196def _get_registered_priority_weight_strategy(importable_string: str) -> type[PriorityWeightStrategy] | None:
197 from airflow import plugins_manager
199 if importable_string in airflow_priority_weight_strategies:
200 return airflow_priority_weight_strategies[importable_string]
201 plugins_manager.initialize_priority_weight_strategy_plugins()
202 if plugins_manager.priority_weight_strategy_classes:
203 return plugins_manager.priority_weight_strategy_classes.get(importable_string)
204 else:
205 return None
208class _TimetableNotRegistered(ValueError):
209 def __init__(self, type_string: str) -> None:
210 self.type_string = type_string
212 def __str__(self) -> str:
213 return (
214 f"Timetable class {self.type_string!r} is not registered or "
215 "you have a top level database access that disrupted the session. "
216 "Please check the airflow best practices documentation."
217 )
220class _PriorityWeightStrategyNotRegistered(AirflowException):
221 def __init__(self, type_string: str) -> None:
222 self.type_string = type_string
224 def __str__(self) -> str:
225 return (
226 f"Priority weight strategy class {self.type_string!r} is not registered or "
227 "you have a top level database access that disrupted the session. "
228 "Please check the airflow best practices documentation."
229 )
232def encode_timetable(var: Timetable) -> dict[str, Any]:
233 """
234 Encode a timetable instance.
236 This delegates most of the serialization work to the type, so the behavior
237 can be completely controlled by a custom subclass.
239 :meta private:
240 """
241 timetable_class = type(var)
242 importable_string = qualname(timetable_class)
243 if _get_registered_timetable(importable_string) is None:
244 raise _TimetableNotRegistered(importable_string)
245 return {Encoding.TYPE: importable_string, Encoding.VAR: var.serialize()}
248def decode_timetable(var: dict[str, Any]) -> Timetable:
249 """
250 Decode a previously serialized timetable.
252 Most of the deserialization logic is delegated to the actual type, which
253 we import from string.
255 :meta private:
256 """
257 importable_string = var[Encoding.TYPE]
258 timetable_class = _get_registered_timetable(importable_string)
259 if timetable_class is None:
260 raise _TimetableNotRegistered(importable_string)
261 return timetable_class.deserialize(var[Encoding.VAR])
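# Illustrative encoded form (a sketch; the ``__var`` payload is whatever the
# timetable's own ``serialize()`` returns, so the inner keys are an assumption):
#
#     {
#         "__type": "airflow.timetables.interval.CronDataIntervalTimetable",
#         "__var": {"expression": "0 0 * * *", "timezone": "UTC"},
#     }
#
# ``decode_timetable`` reverses this by resolving ``__type`` through the registry
# above and calling ``deserialize()`` on ``__var``.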
264def encode_priority_weight_strategy(var: PriorityWeightStrategy) -> str:
265 """
266 Encode a priority weight strategy instance.
268 In this version, we only store the importable string, so the class should not
269 require any parameters to be passed to it. If you need to store the parameters, you
270 should store them in the class itself.
271 """
272 priority_weight_strategy_class = type(var)
273 if priority_weight_strategy_class in airflow_priority_weight_strategies_classes:
274 return airflow_priority_weight_strategies_classes[priority_weight_strategy_class]
275 importable_string = qualname(priority_weight_strategy_class)
276 if _get_registered_priority_weight_strategy(importable_string) is None:
277 raise _PriorityWeightStrategyNotRegistered(importable_string)
278 return importable_string
281def decode_priority_weight_strategy(var: str) -> PriorityWeightStrategy:
282 """
283 Decode a previously serialized priority weight strategy.
285 In this version, we only store the importable string, so we just need to get the class
286 from the dictionary of registered classes and instantiate it with no parameters.
287 """
288 priority_weight_strategy_class = _get_registered_priority_weight_strategy(var)
289 if priority_weight_strategy_class is None:
290 raise _PriorityWeightStrategyNotRegistered(var)
291 return priority_weight_strategy_class()
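# Illustrative usage (a sketch; "downstream" is assumed to be one of the keys in
# ``airflow_priority_weight_strategies``):
#
#     >>> strategy = decode_priority_weight_strategy("downstream")
#     >>> encode_priority_weight_strategy(strategy)
#     'downstream'
#
# A custom strategy registered through plugins round-trips via its full import path
# (e.g. "my_company.weights.MyStrategy") instead of a short built-in name.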
294class _XComRef(NamedTuple):
295 """
296 Store info needed to create XComArg.
298 We can't turn it into an XComArg until we've loaded _all_ the tasks, so when
299 deserializing an operator, we need to create something in its place, and
300 post-process it in ``deserialize_dag``.
301 """
303 data: dict
305 def deref(self, dag: DAG) -> XComArg:
306 return deserialize_xcom_arg(self.data, dag)
309# These two should be kept in sync. Note that these are intentionally not using
310# the type declarations in expandinput.py so we always remember to update
311# serialization logic when adding new ExpandInput variants. If you add things to
312# the unions, be sure to update _ExpandInputRef to match.
313_ExpandInputOriginalValue = Union[
314 # For .expand(**kwargs).
315 Mapping[str, Any],
316 # For expand_kwargs(arg).
317 XComArg,
318 Collection[Union[XComArg, Mapping[str, Any]]],
319]
320_ExpandInputSerializedValue = Union[
321 # For .expand(**kwargs).
322 Mapping[str, Any],
323 # For expand_kwargs(arg).
324 _XComRef,
325 Collection[Union[_XComRef, Mapping[str, Any]]],
326]
329class _ExpandInputRef(NamedTuple):
330 """
331 Store info needed to create a mapped operator's expand input.
333 This references a ``ExpandInput`` type, but replaces ``XComArg`` objects
334 with ``_XComRef`` (see documentation on the latter type for reasoning).
335 """
337 key: str
338 value: _ExpandInputSerializedValue
340 @classmethod
341 def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
342 """
343 Validate we've covered all ``ExpandInput.value`` types.
345 This function does not actually do anything, but is called during
346 serialization so Mypy will *statically* check we have handled all
347 possible ExpandInput cases.
348 """
350 def deref(self, dag: DAG) -> ExpandInput:
351 """
352 De-reference into a concrete ExpandInput object.
354 If you add more cases here, be sure to update _ExpandInputOriginalValue
355 and _ExpandInputSerializedValue to match the logic.
356 """
357 if isinstance(self.value, _XComRef):
358 value: Any = self.value.deref(dag)
359 elif isinstance(self.value, collections.abc.Mapping):
360 value = {k: v.deref(dag) if isinstance(v, _XComRef) else v for k, v in self.value.items()}
361 else:
362 value = [v.deref(dag) if isinstance(v, _XComRef) else v for v in self.value]
363 return create_expand_input(self.key, value)
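# Illustrative serialized shape (a sketch; "dict-of-lists" is assumed to be the
# ``get_map_type_key`` value for ``.expand(**kwargs)`` style input):
#
#     {"type": "dict-of-lists", "value": {"x": [1, 2, 3]}}
#
# On deserialization this becomes ``_ExpandInputRef("dict-of-lists", {...})``;
# ``deref(dag)`` then turns any nested ``_XComRef`` back into real ``XComArg``
# objects before ``create_expand_input`` builds the concrete ExpandInput.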
366_orm_to_model = {
367 Job: JobPydantic,
368 TaskInstance: TaskInstancePydantic,
369 DagRun: DagRunPydantic,
370 DagModel: DagModelPydantic,
371 LogTemplate: LogTemplatePydantic,
372}
373_type_to_class = {
374 DAT.BASE_JOB: [JobPydantic, Job],
375 DAT.TASK_INSTANCE: [TaskInstancePydantic, TaskInstance],
376 DAT.DAG_RUN: [DagRunPydantic, DagRun],
377 DAT.DAG_MODEL: [DagModelPydantic, DagModel],
378 DAT.LOG_TEMPLATE: [LogTemplatePydantic, LogTemplate],
379}
380_class_to_type = {cls_: type_ for type_, classes in _type_to_class.items() for cls_ in classes}
383class BaseSerialization:
384 """BaseSerialization provides utils for serialization."""
386 # JSON primitive types.
387 _primitive_types = (int, bool, float, str)
389 # Time types.
390 # datetime.date and datetime.time are converted to strings.
391 _datetime_types = (datetime.datetime,)
393 # Object types that are always excluded in serialization.
394 _excluded_types = (logging.Logger, Connection, type, property)
396 _json_schema: Validator | None = None
398 # Should the extra operator link be loaded via plugins when
399 # de-serializing the DAG? This flag is set to False in Scheduler so that Extra Operator links
400 # are not loaded, ensuring user code never runs in the Scheduler.
401 _load_operator_extra_links = True
403 _CONSTRUCTOR_PARAMS: dict[str, Parameter] = {}
405 SERIALIZER_VERSION = 1
407 @classmethod
408 def to_json(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> str:
409 """Stringify DAGs and operators contained by var and returns a JSON string of var."""
410 return json.dumps(cls.to_dict(var), ensure_ascii=True)
412 @classmethod
413 def to_dict(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> dict:
414 """Stringify DAGs and operators contained by var and returns a dict of var."""
415 # Don't call on this class directly - only SerializedDAG or
416 # SerializedBaseOperator should be used as the "entrypoint"
417 raise NotImplementedError()
419 @classmethod
420 def from_json(cls, serialized_obj: str) -> BaseSerialization | dict | list | set | tuple:
421 """Deserialize json_str and reconstructs all DAGs and operators it contains."""
422 return cls.from_dict(json.loads(serialized_obj))
424 @classmethod
425 def from_dict(cls, serialized_obj: dict[Encoding, Any]) -> BaseSerialization | dict | list | set | tuple:
426 """Deserialize a dict of type decorators and reconstructs all DAGs and operators it contains."""
427 return cls.deserialize(serialized_obj)
429 @classmethod
430 def validate_schema(cls, serialized_obj: str | dict) -> None:
431 """Validate serialized_obj satisfies JSON schema."""
432 if cls._json_schema is None:
433 raise AirflowException(f"JSON schema of {cls.__name__:s} is not set.")
435 if isinstance(serialized_obj, dict):
436 cls._json_schema.validate(serialized_obj)
437 elif isinstance(serialized_obj, str):
438 cls._json_schema.validate(json.loads(serialized_obj))
439 else:
440 raise TypeError("Invalid type: Only dict and str are supported.")
442 @staticmethod
443 def _encode(x: Any, type_: Any) -> dict[Encoding, Any]:
444 """Encode data by a JSON dict."""
445 return {Encoding.VAR: x, Encoding.TYPE: type_}
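# Illustrative envelope (a sketch, assuming Encoding.VAR / Encoding.TYPE and
# DAT.TIMEDELTA serialize to "__var" / "__type" / "timedelta" once dumped to JSON):
#
#     >>> json.dumps(BaseSerialization._encode(300.0, type_=DAT.TIMEDELTA))
#     '{"__var": 300.0, "__type": "timedelta"}'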
447 @classmethod
448 def _is_primitive(cls, var: Any) -> bool:
449 """Primitive types."""
450 return var is None or isinstance(var, cls._primitive_types)
452 @classmethod
453 def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool:
454 """Check if type is excluded from serialization."""
455 if var is None:
456 if not cls._is_constructor_param(attrname, instance):
457 # For any instance attribute that is not a constructor argument, exclude None as the default
458 return True
460 return cls._value_is_hardcoded_default(attrname, var, instance)
461 return isinstance(var, cls._excluded_types) or cls._value_is_hardcoded_default(
462 attrname, var, instance
463 )
465 @classmethod
466 def serialize_to_json(
467 cls, object_to_serialize: BaseOperator | MappedOperator | DAG, decorated_fields: set
468 ) -> dict[str, Any]:
469 """Serialize an object to JSON."""
470 serialized_object: dict[str, Any] = {}
471 keys_to_serialize = object_to_serialize.get_serialized_fields()
472 for key in keys_to_serialize:
473 # None is ignored in serialized form and is added back in deserialization.
474 value = getattr(object_to_serialize, key, None)
475 if cls._is_excluded(value, key, object_to_serialize):
476 continue
478 if key == "_operator_name":
479 # when operator_name matches task_type, we can remove
480 # it to reduce the JSON payload
481 task_type = getattr(object_to_serialize, "_task_type", None)
482 if value != task_type:
483 serialized_object[key] = cls.serialize(value)
484 elif key in decorated_fields:
485 serialized_object[key] = cls.serialize(value)
486 elif key == "timetable" and value is not None:
487 serialized_object[key] = encode_timetable(value)
488 elif key == "weight_rule" and value is not None:
489 serialized_object[key] = encode_priority_weight_strategy(value)
490 elif key == "dataset_triggers":
491 serialized_object[key] = cls.serialize(value)
492 else:
493 value = cls.serialize(value)
494 if isinstance(value, dict) and Encoding.TYPE in value:
495 value = value[Encoding.VAR]
496 serialized_object[key] = value
497 return serialized_object
499 @classmethod
500 def serialize(
501 cls, var: Any, *, strict: bool = False, use_pydantic_models: bool = False
502 ) -> Any: # Unfortunately there is no support for recursive types in mypy
503 """
504 Serialize an object; helper function of depth first search for serialization.
506 The serialization protocol is:
508 (1) keeping JSON supported types: primitives, dict, list;
509 (2) encoding other types as ``{TYPE: 'foo', VAR: 'bar'}``, the deserialization
510 step decodes VAR according to TYPE;
511 (3) Operator has a special field CLASS to record the original class
512 name for displaying in UI.
514 :meta private:
515 """
516 if use_pydantic_models and not _ENABLE_AIP_44:
517 raise RuntimeError(
518 "Setting use_pydantic_models = True requires AIP-44 (in progress) feature flag to be true. "
519 "This parameter will be removed eventually when new serialization is used by AIP-44"
520 )
521 if cls._is_primitive(var):
522 # enum.IntEnum is an int instance; it causes a json.dumps error, so we use its value instead.
523 if isinstance(var, enum.Enum):
524 return var.value
525 return var
526 elif isinstance(var, dict):
527 return cls._encode(
528 {
529 str(k): cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models)
530 for k, v in var.items()
531 },
532 type_=DAT.DICT,
533 )
534 elif isinstance(var, list):
535 return [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var]
536 elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and isinstance(var, k8s.V1Pod):
537 json_pod = PodGenerator.serialize_pod(var)
538 return cls._encode(json_pod, type_=DAT.POD)
539 elif isinstance(var, OutletEventAccessors):
540 return cls._encode(
541 cls.serialize(var._dict, strict=strict, use_pydantic_models=use_pydantic_models), # type: ignore[attr-defined]
542 type_=DAT.DATASET_EVENT_ACCESSORS,
543 )
544 elif isinstance(var, OutletEventAccessor):
545 return cls._encode(
546 cls.serialize(var.extra, strict=strict, use_pydantic_models=use_pydantic_models),
547 type_=DAT.DATASET_EVENT_ACCESSOR,
548 )
549 elif isinstance(var, DAG):
550 return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
551 elif isinstance(var, Resources):
552 return var.to_dict()
553 elif isinstance(var, MappedOperator):
554 return cls._encode(SerializedBaseOperator.serialize_mapped_operator(var), type_=DAT.OP)
555 elif isinstance(var, BaseOperator):
556 var._needs_expansion = var.get_needs_expansion()
557 return cls._encode(SerializedBaseOperator.serialize_operator(var), type_=DAT.OP)
558 elif isinstance(var, cls._datetime_types):
559 return cls._encode(var.timestamp(), type_=DAT.DATETIME)
560 elif isinstance(var, datetime.timedelta):
561 return cls._encode(var.total_seconds(), type_=DAT.TIMEDELTA)
562 elif isinstance(var, (Timezone, FixedTimezone)):
563 return cls._encode(encode_timezone(var), type_=DAT.TIMEZONE)
564 elif isinstance(var, relativedelta.relativedelta):
565 return cls._encode(encode_relativedelta(var), type_=DAT.RELATIVEDELTA)
566 elif isinstance(var, (AirflowException, TaskDeferred)) and hasattr(var, "serialize"):
567 exc_cls_name, args, kwargs = var.serialize()
568 return cls._encode(
569 cls.serialize(
570 {"exc_cls_name": exc_cls_name, "args": args, "kwargs": kwargs},
571 use_pydantic_models=use_pydantic_models,
572 strict=strict,
573 ),
574 type_=DAT.AIRFLOW_EXC_SER,
575 )
576 elif isinstance(var, BaseTrigger):
577 return cls._encode(
578 cls.serialize(var.serialize(), use_pydantic_models=use_pydantic_models, strict=strict),
579 type_=DAT.BASE_TRIGGER,
580 )
581 elif callable(var):
582 return str(get_python_source(var))
583 elif isinstance(var, set):
584 # FIXME: cast set to list in customized serialization in the future.
585 try:
586 return cls._encode(
587 sorted(
588 cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var
589 ),
590 type_=DAT.SET,
591 )
592 except TypeError:
593 return cls._encode(
594 [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var],
595 type_=DAT.SET,
596 )
597 elif isinstance(var, tuple):
598 # FIXME: cast tuple to list in customized serialization in the future.
599 return cls._encode(
600 [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var],
601 type_=DAT.TUPLE,
602 )
603 elif isinstance(var, TaskGroup):
604 return TaskGroupSerialization.serialize_task_group(var)
605 elif isinstance(var, Param):
606 return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
607 elif isinstance(var, XComArg):
608 return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
609 elif isinstance(var, Dataset):
610 return cls._encode({"uri": var.uri, "extra": var.extra}, type_=DAT.DATASET)
611 elif isinstance(var, DatasetAll):
612 return cls._encode(
613 [
614 cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
615 for x in var.objects
616 ],
617 type_=DAT.DATASET_ALL,
618 )
619 elif isinstance(var, DatasetAny):
620 return cls._encode(
621 [
622 cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
623 for x in var.objects
624 ],
625 type_=DAT.DATASET_ANY,
626 )
627 elif isinstance(var, SimpleTaskInstance):
628 return cls._encode(
629 cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
630 type_=DAT.SIMPLE_TASK_INSTANCE,
631 )
632 elif isinstance(var, Connection):
633 return cls._encode(var.to_dict(validate=True), type_=DAT.CONNECTION)
634 elif var.__class__ == Context:
635 d = {}
636 for k, v in var._context.items():
637 obj = cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models)
638 d[str(k)] = obj
639 return cls._encode(d, type_=DAT.TASK_CONTEXT)
640 elif use_pydantic_models and _ENABLE_AIP_44:
642 def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]:
643 return model_cls.model_validate(var).model_dump(mode="json") # type: ignore[attr-defined]
645 if var.__class__ in _class_to_type:
646 pyd_mod = _orm_to_model.get(var.__class__, var)
647 mod = _pydantic_model_dump(pyd_mod, var)
648 type_ = _class_to_type[var.__class__]
649 return cls._encode(mod, type_=type_)
650 else:
651 return cls.default_serialization(strict, var)
652 elif isinstance(var, ArgNotSet):
653 return cls._encode(None, type_=DAT.ARG_NOT_SET)
654 else:
655 return cls.default_serialization(strict, var)
657 @classmethod
658 def default_serialization(cls, strict, var) -> str:
659 log.debug("Cast type %s to str in serialization.", type(var))
660 if strict:
661 raise SerializationError("Encountered unexpected type")
662 return str(var)
664 @classmethod
665 def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
666 """
667 Deserialize an object; helper function of depth first search for deserialization.
669 :meta private:
670 """
671 # JSON primitives (except for dict) are not encoded.
672 if use_pydantic_models and not _ENABLE_AIP_44:
673 raise RuntimeError(
674 "Setting use_pydantic_models = True requires AIP-44 (in progress) feature flag to be true. "
675 "This parameter will be removed eventually when new serialization is used by AIP-44"
676 )
677 if cls._is_primitive(encoded_var):
678 return encoded_var
679 elif isinstance(encoded_var, list):
680 return [cls.deserialize(v, use_pydantic_models) for v in encoded_var]
682 if not isinstance(encoded_var, dict):
683 raise ValueError(f"The encoded_var should be dict and is {type(encoded_var)}")
684 var = encoded_var[Encoding.VAR]
685 type_ = encoded_var[Encoding.TYPE]
686 if type_ == DAT.TASK_CONTEXT:
687 d = {}
688 for k, v in var.items():
689 if k == "task": # todo: add `_encode` of Operator so we don't need this
690 continue
691 d[k] = cls.deserialize(v, use_pydantic_models=True)
692 d["task"] = d["task_instance"].task # todo: add `_encode` of Operator so we don't need this
693 return Context(**d)
694 elif type_ == DAT.DICT:
695 return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
696 elif type_ == DAT.DATASET_EVENT_ACCESSORS:
697 d = OutletEventAccessors() # type: ignore[assignment]
698 d._dict = cls.deserialize(var) # type: ignore[attr-defined]
699 return d
700 elif type_ == DAT.DATASET_EVENT_ACCESSOR:
701 return OutletEventAccessor(extra=cls.deserialize(var))
702 elif type_ == DAT.DAG:
703 return SerializedDAG.deserialize_dag(var)
704 elif type_ == DAT.OP:
705 return SerializedBaseOperator.deserialize_operator(var)
706 elif type_ == DAT.DATETIME:
707 return from_timestamp(var)
708 elif type_ == DAT.POD:
709 if not _has_kubernetes():
710 raise RuntimeError("Cannot deserialize POD objects without kubernetes libraries installed!")
711 pod = PodGenerator.deserialize_model_dict(var)
712 return pod
713 elif type_ == DAT.TIMEDELTA:
714 return datetime.timedelta(seconds=var)
715 elif type_ == DAT.TIMEZONE:
716 return decode_timezone(var)
717 elif type_ == DAT.RELATIVEDELTA:
718 return decode_relativedelta(var)
719 elif type_ == DAT.AIRFLOW_EXC_SER:
720 deser = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
721 exc_cls_name = deser["exc_cls_name"]
722 args = deser["args"]
723 kwargs = deser["kwargs"]
724 del deser
725 exc_cls = import_string(f"airflow.exceptions.{exc_cls_name}")
726 return exc_cls(*args, **kwargs)
727 elif type_ == DAT.BASE_TRIGGER:
728 tr_cls_name, kwargs = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
729 tr_cls = import_string(tr_cls_name)
730 return tr_cls(**kwargs)
731 elif type_ == DAT.SET:
732 return {cls.deserialize(v, use_pydantic_models) for v in var}
733 elif type_ == DAT.TUPLE:
734 return tuple(cls.deserialize(v, use_pydantic_models) for v in var)
735 elif type_ == DAT.PARAM:
736 return cls._deserialize_param(var)
737 elif type_ == DAT.XCOM_REF:
738 return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG.
739 elif type_ == DAT.DATASET:
740 return Dataset(**var)
741 elif type_ == DAT.DATASET_ANY:
742 return DatasetAny(*(cls.deserialize(x) for x in var))
743 elif type_ == DAT.DATASET_ALL:
744 return DatasetAll(*(cls.deserialize(x) for x in var))
745 elif type_ == DAT.SIMPLE_TASK_INSTANCE:
746 return SimpleTaskInstance(**cls.deserialize(var))
747 elif type_ == DAT.CONNECTION:
748 return Connection(**var)
749 elif use_pydantic_models and _ENABLE_AIP_44:
750 if type_ == DAT.BASE_JOB:
751 return JobPydantic.model_validate(var)
752 elif type_ == DAT.TASK_INSTANCE:
753 return TaskInstancePydantic.model_validate(var)
754 elif type_ == DAT.DAG_RUN:
755 return DagRunPydantic.model_validate(var)
756 elif type_ == DAT.DAG_MODEL:
757 return DagModelPydantic.model_validate(var)
758 elif type_ == DAT.DATA_SET:
759 return DatasetPydantic.model_validate(var)
760 elif type_ == DAT.LOG_TEMPLATE:
761 return LogTemplatePydantic.model_validate(var)
762 elif type_ == DAT.ARG_NOT_SET:
763 return NOTSET
764 else:
765 raise TypeError(f"Invalid type {type_!s} in deserialization.")
767 _deserialize_datetime = from_timestamp
768 _deserialize_timezone = parse_timezone
770 @classmethod
771 def _deserialize_timedelta(cls, seconds: int) -> datetime.timedelta:
772 return datetime.timedelta(seconds=seconds)
774 @classmethod
775 def _is_constructor_param(cls, attrname: str, instance: Any) -> bool:
776 return attrname in cls._CONSTRUCTOR_PARAMS
778 @classmethod
779 def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) -> bool:
780 """
781 Return true if ``value`` is the hard-coded default for the given attribute.
783 This takes into account cases where the ``max_active_tasks`` parameter is
784 stored in the ``_max_active_tasks`` attribute.
786 By using ``is`` here rather than ``==``, a value the user explicitly supplies is
787 still serialized even when it compares equal to the default, because an equal
788 but distinct object fails the identity check.
791 Also returns True if the value is an empty list or empty dict. This is done
792 to account for the case where the default value of the field is None but has the
793 ``field = field or {}`` set.
794 """
795 if attrname in cls._CONSTRUCTOR_PARAMS and (
796 cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])
797 ):
798 return True
799 return False
801 @classmethod
802 def _serialize_param(cls, param: Param):
803 return {
804 "__class": f"{param.__module__}.{param.__class__.__name__}",
805 "default": cls.serialize(param.value),
806 "description": cls.serialize(param.description),
807 "schema": cls.serialize(param.schema),
808 }
810 @classmethod
811 def _deserialize_param(cls, param_dict: dict):
812 """
813 Workaround to deserialize Params serialized by older Airflow versions.
815 In 2.2.0, Param attrs were assumed to be json-serializable and were not run through
816 this class's ``serialize`` method. So before running through ``deserialize``,
817 we first verify that it is necessary to do so.
818 """
819 class_name = param_dict["__class"]
820 class_: type[Param] = import_string(class_name)
821 attrs = ("default", "description", "schema")
822 kwargs = {}
824 def is_serialized(val):
825 if isinstance(val, dict):
826 return Encoding.TYPE in val
827 if isinstance(val, list):
828 return all(isinstance(item, dict) and Encoding.TYPE in item for item in val)
829 return False
831 for attr in attrs:
832 if attr in param_dict:
833 val = param_dict[attr]
834 if is_serialized(val):
835 val = cls.deserialize(val)
836 kwargs[attr] = val
837 return class_(**kwargs)
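# Illustrative round trip (a sketch; only ``schema`` ends up wrapped in the
# ``__var``/``__type`` envelope here because the other attrs are primitives):
#
#     >>> p = Param(5, description="retries", type="integer")
#     >>> d = BaseSerialization._serialize_param(p)
#     >>> d["__class"]
#     'airflow.models.param.Param'
#     >>> BaseSerialization._deserialize_param(d).value
#     5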
839 @classmethod
840 def _serialize_params_dict(cls, params: ParamsDict | dict):
841 """Serialize Params dict for a DAG or task."""
842 serialized_params = {}
843 for k, v in params.items():
844 # TODO: As of now, we only allow serialization of params that are of type Param.
845 try:
846 class_identity = f"{v.__module__}.{v.__class__.__name__}"
847 except AttributeError:
848 class_identity = ""
849 if class_identity == "airflow.models.param.Param":
850 serialized_params[k] = cls._serialize_param(v)
851 else:
852 raise ValueError(
853 f"Params to a DAG or a Task can be only of type airflow.models.param.Param, "
854 f"but param {k!r} is {v.__class__}"
855 )
856 return serialized_params
858 @classmethod
859 def _deserialize_params_dict(cls, encoded_params: dict) -> ParamsDict:
860 """Deserialize a DAG's Params dict."""
861 op_params = {}
862 for k, v in encoded_params.items():
863 if isinstance(v, dict) and "__class" in v:
864 op_params[k] = cls._deserialize_param(v)
865 else:
866 # Old style params, convert it
867 op_params[k] = Param(v)
869 return ParamsDict(op_params)
872class DependencyDetector:
873 """
874 Detects dependencies between DAGs.
876 :meta private:
877 """
879 @staticmethod
880 def detect_task_dependencies(task: Operator) -> list[DagDependency]:
881 """Detect dependencies caused by tasks."""
882 from airflow.operators.trigger_dagrun import TriggerDagRunOperator
883 from airflow.sensors.external_task import ExternalTaskSensor
885 deps = []
886 if isinstance(task, TriggerDagRunOperator):
887 deps.append(
888 DagDependency(
889 source=task.dag_id,
890 target=getattr(task, "trigger_dag_id"),
891 dependency_type="trigger",
892 dependency_id=task.task_id,
893 )
894 )
895 elif isinstance(task, ExternalTaskSensor):
896 deps.append(
897 DagDependency(
898 source=getattr(task, "external_dag_id"),
899 target=task.dag_id,
900 dependency_type="sensor",
901 dependency_id=task.task_id,
902 )
903 )
904 for obj in task.outlets or []:
905 if isinstance(obj, Dataset):
906 deps.append(
907 DagDependency(
908 source=task.dag_id,
909 target="dataset",
910 dependency_type="dataset",
911 dependency_id=obj.uri,
912 )
913 )
914 return deps
916 @staticmethod
917 def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
918 """Detect dependencies set directly on the DAG object."""
919 if not dag:
920 return
921 if not dag.dataset_triggers:
922 return
923 for uri, _ in dag.dataset_triggers.iter_datasets():
924 yield DagDependency(
925 source="dataset",
926 target=dag.dag_id,
927 dependency_type="dataset",
928 dependency_id=uri,
929 )
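# Illustrative output (a sketch; ``DagDependency`` is the lightweight record defined
# further down in this module): a TriggerDagRunOperator in dag "upstream_dag" that
# triggers "downstream_dag" yields
#
#     DagDependency(source="upstream_dag", target="downstream_dag",
#                   dependency_type="trigger", dependency_id="<task_id>")
#
# while a DAG scheduled on a Dataset yields source="dataset" and dependency_id set
# to the dataset URI.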
932class SerializedBaseOperator(BaseOperator, BaseSerialization):
933 """A JSON serializable representation of operator.
935 All operators are casted to SerializedBaseOperator after deserialization.
936 Class specific attributes used by UI are move to object attributes.
938 Creating a SerializedBaseOperator is a three-step process:
940 1. Instantiate a :class:`SerializedBaseOperator` object.
941 2. Populate attributes with :func:`SerializedBaseOperator.populate_operator`.
942 3. When the task's containing DAG is available, fix references to the DAG
943 with :func:`SerializedBaseOperator.set_task_dag_references`.
944 """
946 _decorated_fields = {"executor_config"}
948 _CONSTRUCTOR_PARAMS = {
949 k: v.default
950 for k, v in signature(BaseOperator.__init__).parameters.items()
951 if v.default is not v.empty
952 }
954 def __init__(self, *args, **kwargs):
955 super().__init__(*args, **kwargs)
956 # task_type is used by UI to display the correct class type, because UI only
957 # receives BaseOperator from deserialized DAGs.
958 self._task_type = "BaseOperator"
959 # Move class attributes into object attributes.
960 self.ui_color = BaseOperator.ui_color
961 self.ui_fgcolor = BaseOperator.ui_fgcolor
962 self.template_ext = BaseOperator.template_ext
963 self.template_fields = BaseOperator.template_fields
964 self.operator_extra_links = BaseOperator.operator_extra_links
966 @property
967 def task_type(self) -> str:
968 # Overwrites task_type of BaseOperator to use _task_type instead of
969 # __class__.__name__.
971 return self._task_type
973 @task_type.setter
974 def task_type(self, task_type: str):
975 self._task_type = task_type
977 @property
978 def operator_name(self) -> str:
979 # Overwrites operator_name of BaseOperator to use _operator_name instead of
980 # __class__.operator_name.
981 return self._operator_name
983 @operator_name.setter
984 def operator_name(self, operator_name: str):
985 self._operator_name = operator_name
987 @classmethod
988 def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]:
989 serialized_op = cls._serialize_node(op, include_deps=op.deps != MappedOperator.deps_for(BaseOperator))
990 # Handle expand_input and op_kwargs_expand_input.
991 expansion_kwargs = op._get_specified_expand_input()
992 if TYPE_CHECKING: # Let Mypy check the input type for us!
993 _ExpandInputRef.validate_expand_input_value(expansion_kwargs.value)
994 serialized_op[op._expand_input_attr] = {
995 "type": get_map_type_key(expansion_kwargs),
996 "value": cls.serialize(expansion_kwargs.value),
997 }
999 # Simplify partial_kwargs by comparing it to the most barebone object.
1000 # Remove all entries that are simply default values.
1001 serialized_partial = serialized_op["partial_kwargs"]
1002 for k, default in _get_default_mapped_partial().items():
1003 try:
1004 v = serialized_partial[k]
1005 except KeyError:
1006 continue
1007 if v == default:
1008 del serialized_partial[k]
1010 serialized_op["_is_mapped"] = True
1011 return serialized_op
1013 @classmethod
1014 def serialize_operator(cls, op: BaseOperator | MappedOperator) -> dict[str, Any]:
1015 return cls._serialize_node(op, include_deps=op.deps is not BaseOperator.deps)
1017 @classmethod
1018 def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) -> dict[str, Any]:
1019 """Serialize operator into a JSON object."""
1020 serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
1022 serialize_op["_task_type"] = getattr(op, "_task_type", type(op).__name__)
1023 serialize_op["_task_module"] = getattr(op, "_task_module", type(op).__module__)
1024 if op.operator_name != serialize_op["_task_type"]:
1025 serialize_op["_operator_name"] = op.operator_name
1027 # Used to determine if an Operator is inherited from EmptyOperator
1028 serialize_op["_is_empty"] = op.inherits_from_empty_operator
1030 if exactly_one(op.start_trigger is not None, op.next_method is not None):
1031 raise AirflowException("start_trigger and next_method should both be set.")
1033 serialize_op["start_trigger"] = op.start_trigger.serialize() if op.start_trigger else None
1034 serialize_op["next_method"] = op.next_method
1036 if op.operator_extra_links:
1037 serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links(
1038 op.operator_extra_links.__get__(op)
1039 if isinstance(op.operator_extra_links, property)
1040 else op.operator_extra_links
1041 )
1043 if include_deps:
1044 serialize_op["deps"] = cls._serialize_deps(op.deps)
1046 # Store all template_fields as they are if they are JSON-serializable;
1047 # if not, store them as strings,
1048 # and raise an exception if the field is not templateable.
1049 forbidden_fields = set(inspect.signature(BaseOperator.__init__).parameters.keys())
1050 # Though allow some of the BaseOperator fields to be templated anyway
1051 forbidden_fields.difference_update({"email"})
1052 if op.template_fields:
1053 for template_field in op.template_fields:
1054 if template_field in forbidden_fields:
1055 raise AirflowException(
1056 dedent(
1057 f"""Cannot template BaseOperator field:
1058 {template_field!r} {op.__class__.__name__=} {op.template_fields=}"""
1059 )
1060 )
1061 value = getattr(op, template_field, None)
1062 if not cls._is_excluded(value, template_field, op):
1063 serialize_op[template_field] = serialize_template_field(value, template_field)
1065 if op.params:
1066 serialize_op["params"] = cls._serialize_params_dict(op.params)
1068 return serialize_op
1070 @classmethod
1071 def _serialize_deps(cls, op_deps: Iterable[BaseTIDep]) -> list[str]:
1072 from airflow import plugins_manager
1074 plugins_manager.initialize_ti_deps_plugins()
1075 if plugins_manager.registered_ti_dep_classes is None:
1076 raise AirflowException("Can not load plugins")
1078 deps = []
1079 for dep in op_deps:
1080 klass = type(dep)
1081 module_name = klass.__module__
1082 qualname = f"{module_name}.{klass.__name__}"
1083 if (
1084 not qualname.startswith("airflow.ti_deps.deps.")
1085 and qualname not in plugins_manager.registered_ti_dep_classes
1086 ):
1087 raise SerializationError(
1088 f"Custom dep class {qualname} not serialized, please register it through plugins."
1089 )
1090 deps.append(qualname)
1091 # deps needs to be sorted here because op_deps is a set, whose iteration order is unstable,
1092 # so the same call may return different results.
1093 # Without sorting, json.dumps(self.data, sort_keys=True) would produce inconsistent dag_hash values.
1094 return sorted(deps)
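# Illustrative result (a sketch): when an operator's deps differ from the defaults
# and are therefore serialized, the output is a sorted list of import paths such as
#
#     ["airflow.ti_deps.deps.not_in_retry_period_dep.NotInRetryPeriodDep",
#      "airflow.ti_deps.deps.trigger_rule_dep.TriggerRuleDep"]
#
# Custom dep classes must be registered through plugins or serialization fails above.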
1096 @classmethod
1097 def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
1098 """Populate operator attributes with serialized values.
1100 This covers simple attributes that don't reference other things in the
1101 DAG. Setting references (such as ``op.dag`` and task dependencies) is
1102 done in ``set_task_dag_references`` instead, which is called after the
1103 DAG is hydrated.
1104 """
1105 if "label" not in encoded_op:
1106 # Handle deserialization of old data before the introduction of TaskGroup
1107 encoded_op["label"] = encoded_op["task_id"]
1109 # Extra Operator Links defined in Plugins
1110 op_extra_links_from_plugin = {}
1112 if "_operator_name" not in encoded_op:
1113 encoded_op["_operator_name"] = encoded_op["_task_type"]
1115 # We don't want to load Extra Operator links in Scheduler
1116 if cls._load_operator_extra_links:
1117 from airflow import plugins_manager
1119 plugins_manager.initialize_extra_operators_links_plugins()
1121 if plugins_manager.operator_extra_links is None:
1122 raise AirflowException("Can not load plugins")
1124 for ope in plugins_manager.operator_extra_links:
1125 for operator in ope.operators:
1126 if (
1127 operator.__name__ == encoded_op["_task_type"]
1128 and operator.__module__ == encoded_op["_task_module"]
1129 ):
1130 op_extra_links_from_plugin.update({ope.name: ope})
1132 # If OperatorLinks are defined in Plugins but not in the Operator that is being Serialized
1133 # set the Operator links attribute
1134 # The case for "If OperatorLinks are defined in the operator that is being Serialized"
1135 # is handled in the deserialization loop where it matches k == "_operator_extra_links"
1136 if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
1137 setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))
1139 for k, v in encoded_op.items():
1140 # TODO: Remove in Airflow 3.0 when dummy operator is removed
1141 if k == "_is_dummy":
1142 k = "_is_empty"
1144 if k in ("_outlets", "_inlets"):
1145 # `_outlets` -> `outlets`
1146 k = k[1:]
1147 if k == "_downstream_task_ids":
1148 # Upgrade from old format/name
1149 k = "downstream_task_ids"
1150 if k == "label":
1151 # Label shouldn't be set anymore -- it's computed from task_id now
1152 continue
1153 elif k == "downstream_task_ids":
1154 v = set(v)
1155 elif k == "subdag":
1156 v = SerializedDAG.deserialize_dag(v)
1157 elif k in {"retry_delay", "execution_timeout", "sla", "max_retry_delay"}:
1158 v = cls._deserialize_timedelta(v)
1159 elif k in encoded_op["template_fields"]:
1160 pass
1161 elif k == "resources":
1162 v = Resources.from_dict(v)
1163 elif k.endswith("_date"):
1164 v = cls._deserialize_datetime(v)
1165 elif k == "_operator_extra_links":
1166 if cls._load_operator_extra_links:
1167 op_predefined_extra_links = cls._deserialize_operator_extra_links(v)
1169 # If OperatorLinks with the same name exists, Links via Plugin have higher precedence
1170 op_predefined_extra_links.update(op_extra_links_from_plugin)
1171 else:
1172 op_predefined_extra_links = {}
1174 v = list(op_predefined_extra_links.values())
1175 k = "operator_extra_links"
1177 elif k == "deps":
1178 v = cls._deserialize_deps(v)
1179 elif k == "params":
1180 v = cls._deserialize_params_dict(v)
1181 if op.params: # Merge existing params if needed.
1182 v, new = op.params, v
1183 v.update(new)
1184 elif k == "partial_kwargs":
1185 v = {arg: cls.deserialize(value) for arg, value in v.items()}
1186 elif k in {"expand_input", "op_kwargs_expand_input"}:
1187 v = _ExpandInputRef(v["type"], cls.deserialize(v["value"]))
1188 elif (
1189 k in cls._decorated_fields
1190 or k not in op.get_serialized_fields()
1191 or k in ("outlets", "inlets")
1192 ):
1193 v = cls.deserialize(v)
1194 elif k == "on_failure_fail_dagrun":
1195 k = "_on_failure_fail_dagrun"
1196 elif k == "weight_rule":
1197 v = decode_priority_weight_strategy(v)
1198 # else use v as it is
1200 setattr(op, k, v)
1202 for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
1203 # TODO: refactor deserialization of BaseOperator and MappedOperator (split it out), then this check
1204 # could go away.
1205 if not hasattr(op, k):
1206 setattr(op, k, None)
1208 # Set to None any template_fields that were not present in the serialized JSON
1209 for field in op.template_fields:
1210 if not hasattr(op, field):
1211 setattr(op, field, None)
1213 # Used to determine if an Operator is inherited from EmptyOperator
1214 setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))
1216 # Deserialize start_trigger
1217 serialized_start_trigger = encoded_op.get("start_trigger")
1218 if serialized_start_trigger:
1219 trigger_cls_name, trigger_kwargs = serialized_start_trigger
1220 trigger_cls = import_string(trigger_cls_name)
1221 start_trigger = trigger_cls(**trigger_kwargs)
1222 setattr(op, "start_trigger", start_trigger)
1223 else:
1224 setattr(op, "start_trigger", None)
1225 setattr(op, "next_method", encoded_op.get("next_method", None))
1227 @staticmethod
1228 def set_task_dag_references(task: Operator, dag: DAG) -> None:
1229 """Handle DAG references on an operator.
1231 The operator should have been mostly populated earlier by calling
1232 ``populate_operator``. This function further fixes object references
1233 that were not possible before the task's containing DAG is hydrated.
1234 """
1235 task.dag = dag
1237 for date_attr in ("start_date", "end_date"):
1238 if getattr(task, date_attr, None) is None:
1239 setattr(task, date_attr, getattr(dag, date_attr, None))
1241 if task.subdag is not None:
1242 task.subdag.parent_dag = dag
1244 # Dereference expand_input and op_kwargs_expand_input.
1245 for k in ("expand_input", "op_kwargs_expand_input"):
1246 if isinstance(kwargs_ref := getattr(task, k, None), _ExpandInputRef):
1247 setattr(task, k, kwargs_ref.deref(dag))
1249 for task_id in task.downstream_task_ids:
1250 # Bypass set_upstream etc here - it does more than we want
1251 dag.task_dict[task_id].upstream_task_ids.add(task.task_id)
1253 @classmethod
1254 def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
1255 """Deserializes an operator from a JSON object."""
1256 op: Operator
1257 if encoded_op.get("_is_mapped", False):
1258 # Most of these will be loaded later, these are just some stand-ins.
1259 op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()}
1260 try:
1261 operator_name = encoded_op["_operator_name"]
1262 except KeyError:
1263 operator_name = encoded_op["_task_type"]
1265 op = MappedOperator(
1266 operator_class=op_data,
1267 expand_input=EXPAND_INPUT_EMPTY,
1268 partial_kwargs={},
1269 task_id=encoded_op["task_id"],
1270 params={},
1271 deps=MappedOperator.deps_for(BaseOperator),
1272 operator_extra_links=BaseOperator.operator_extra_links,
1273 template_ext=BaseOperator.template_ext,
1274 template_fields=BaseOperator.template_fields,
1275 template_fields_renderers=BaseOperator.template_fields_renderers,
1276 ui_color=BaseOperator.ui_color,
1277 ui_fgcolor=BaseOperator.ui_fgcolor,
1278 is_empty=False,
1279 task_module=encoded_op["_task_module"],
1280 task_type=encoded_op["_task_type"],
1281 operator_name=operator_name,
1282 dag=None,
1283 task_group=None,
1284 start_date=None,
1285 end_date=None,
1286 disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
1287 expand_input_attr=encoded_op["_expand_input_attr"],
1288 start_trigger=None,
1289 next_method=None,
1290 )
1291 else:
1292 op = SerializedBaseOperator(task_id=encoded_op["task_id"])
1294 cls.populate_operator(op, encoded_op)
1295 return op
1297 @classmethod
1298 def detect_dependencies(cls, op: Operator) -> set[DagDependency]:
1299 """Detect between DAG dependencies for the operator."""
1301 def get_custom_dep() -> list[DagDependency]:
1302 """
1303 If custom dependency detector is configured, use it.
1305 TODO: Remove this logic in 3.0.
1306 """
1307 custom_dependency_detector_cls = conf.getimport("scheduler", "dependency_detector", fallback=None)
1308 if not (
1309 custom_dependency_detector_cls is None or custom_dependency_detector_cls is DependencyDetector
1310 ):
1311 warnings.warn(
1312 "Use of a custom dependency detector is deprecated. "
1313 "Support will be removed in a future release.",
1314 RemovedInAirflow3Warning,
1315 stacklevel=1,
1316 )
1317 dep = custom_dependency_detector_cls().detect_task_dependencies(op)
1318 if type(dep) is DagDependency:
1319 return [dep]
1320 return []
1322 dependency_detector = DependencyDetector()
1323 deps = set(dependency_detector.detect_task_dependencies(op))
1324 deps.update(get_custom_dep()) # todo: remove in 3.0
1325 return deps
1327 @classmethod
1328 def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
1329 if var is not None and op.has_dag() and attrname.endswith("_date"):
1330 # If this date is the same as the matching field in the dag, then
1331 # don't store it again at the task level.
1332 dag_date = getattr(op.dag, attrname, None)
1333 if var is dag_date or var == dag_date:
1334 return True
1335 return super()._is_excluded(var, attrname, op)
1337 @classmethod
1338 def _deserialize_deps(cls, deps: list[str]) -> set[BaseTIDep]:
1339 from airflow import plugins_manager
1341 plugins_manager.initialize_ti_deps_plugins()
1342 if plugins_manager.registered_ti_dep_classes is None:
1343 raise AirflowException("Can not load plugins")
1345 instances = set()
1346 for qn in set(deps):
1347 if (
1348 not qn.startswith("airflow.ti_deps.deps.")
1349 and qn not in plugins_manager.registered_ti_dep_classes
1350 ):
1351 raise SerializationError(
1352 f"Custom dep class {qn} not deserialized, please register it through plugins."
1353 )
1355 try:
1356 instances.add(import_string(qn)())
1357 except ImportError:
1358 log.warning("Error importing dep %r", qn, exc_info=True)
1359 return instances
1361 @classmethod
1362 def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> dict[str, BaseOperatorLink]:
1363 """
1364 Deserialize Operator Links if the Classes are registered in Airflow Plugins.
1366 An error is raised if the OperatorLink is not found in Plugins either.
1368 :param encoded_op_links: Serialized Operator Link
1369 :return: De-Serialized Operator Link
1370 """
1371 from airflow import plugins_manager
1373 plugins_manager.initialize_extra_operators_links_plugins()
1375 if plugins_manager.registered_operator_link_classes is None:
1376 raise AirflowException("Can't load plugins")
1377 op_predefined_extra_links = {}
1379 for _operator_links_source in encoded_op_links:
1380 # Get the key, value pair as Tuple where key is OperatorLink ClassName
1381 # and value is the dictionary containing the arguments passed to the OperatorLink
1382 #
1383 # Example of a single iteration:
1384 #
1385 # _operator_links_source =
1386 # {
1387 # 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {
1388 # 'index': 0
1389 # }
1390 # },
1391 #
1392 # list(_operator_links_source.items()) =
1393 # [
1394 # (
1395 # 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
1396 # {'index': 0}
1397 # )
1398 # ]
1399 #
1400 # list(_operator_links_source.items())[0] =
1401 # (
1402 # 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
1403 # {
1404 # 'index': 0
1405 # }
1406 # )
1408 _operator_link_class_path, data = next(iter(_operator_links_source.items()))
1409 if _operator_link_class_path in get_operator_extra_links():
1410 single_op_link_class = import_string(_operator_link_class_path)
1411 elif _operator_link_class_path in plugins_manager.registered_operator_link_classes:
1412 single_op_link_class = plugins_manager.registered_operator_link_classes[
1413 _operator_link_class_path
1414 ]
1415 else:
1416 log.error("Operator Link class %r not registered", _operator_link_class_path)
1417 return {}
1419 op_link_parameters = {param: cls.deserialize(value) for param, value in data.items()}
1420 op_predefined_extra_link: BaseOperatorLink = single_op_link_class(**op_link_parameters)
1422 op_predefined_extra_links.update({op_predefined_extra_link.name: op_predefined_extra_link})
1424 return op_predefined_extra_links
1426 @classmethod
1427 def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOperatorLink]):
1428 """
1429 Serialize Operator Links.
1431 Store the import path of the OperatorLink and the arguments passed to it.
1432 For example:
1433 ``[{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}]``
1435 :param operator_extra_links: Operator Link
1436 :return: Serialized Operator Link
1437 """
1438 serialize_operator_extra_links = []
1439 for operator_extra_link in operator_extra_links:
1440 op_link_arguments = {
1441 param: cls.serialize(value) for param, value in attrs.asdict(operator_extra_link).items()
1442 }
1444 module_path = (
1445 f"{operator_extra_link.__class__.__module__}.{operator_extra_link.__class__.__name__}"
1446 )
1447 serialize_operator_extra_links.append({module_path: op_link_arguments})
1449 return serialize_operator_extra_links
1451 @classmethod
1452 def serialize(cls, var: Any, *, strict: bool = False, use_pydantic_models: bool = False) -> Any:
1453 # The wonders of multiple inheritance: BaseOperator defines an instance method, so delegate explicitly to BaseSerialization.
1454 return BaseSerialization.serialize(var=var, strict=strict, use_pydantic_models=use_pydantic_models)
1456 @classmethod
1457 def deserialize(cls, encoded_var: Any, use_pydantic_models: bool = False) -> Any:
1458 return BaseSerialization.deserialize(encoded_var=encoded_var, use_pydantic_models=use_pydantic_models)
1461class SerializedDAG(DAG, BaseSerialization):
1462 """
1463 A JSON serializable representation of a DAG.
1465 A stringified DAG can only be used in the scope of the scheduler and webserver, because fields
1466 that are not serializable, such as functions and custom-defined classes, are cast to
1467 strings.
1469 Compared with SimpleDAG: SerializedDAG contains all information for the webserver.
1470 Compared with DagPickle: DagPickle contains all information for the worker, but some DAGs are
1471 not pickle-able. SerializedDAG works for all DAGs.
1472 """
1474 _decorated_fields = {"schedule_interval", "default_args", "_access_control"}
1476 @staticmethod
1477 def __get_constructor_defaults():
1478 param_to_attr = {
1479 "max_active_tasks": "_max_active_tasks",
1480 "dag_display_name": "_dag_display_property_value",
1481 "description": "_description",
1482 "default_view": "_default_view",
1483 "access_control": "_access_control",
1484 }
1485 return {
1486 param_to_attr.get(k, k): v.default
1487 for k, v in signature(DAG.__init__).parameters.items()
1488 if v.default is not v.empty
1489 }
1491 _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__() # type: ignore
1492 del __get_constructor_defaults
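    # _CONSTRUCTOR_PARAMS maps the renamed attribute names above, plus every other DAG.__init__
    # parameter that has a default, to that signature default; its keys are excluded when
    # deserialize_dag nulls out serialized fields missing from the encoded blob.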
1494 _json_schema = lazy_object_proxy.Proxy(load_dag_schema)
1496 @classmethod
1497 def serialize_dag(cls, dag: DAG) -> dict:
1498 """Serialize a DAG into a JSON object."""
1499 try:
1500 serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields)
1502 serialized_dag["_processor_dags_folder"] = DAGS_FOLDER
1504 # If schedule_interval is backed by a timetable, serialize only the
1505 # timetable; vice versa for a timetable backed by schedule_interval.
1506 if dag.timetable.summary == dag.schedule_interval:
1507 del serialized_dag["schedule_interval"]
1508 else:
1509 del serialized_dag["timetable"]
1511 serialized_dag["tasks"] = [cls.serialize(task) for _, task in dag.task_dict.items()]
1513 dag_deps = {
1514 dep
1515 for task in dag.task_dict.values()
1516 for dep in SerializedBaseOperator.detect_dependencies(task)
1517 }
1518 dag_deps.update(DependencyDetector.detect_dag_dependencies(dag))
1519 serialized_dag["dag_dependencies"] = [x.__dict__ for x in sorted(dag_deps)]
1520 serialized_dag["_task_group"] = TaskGroupSerialization.serialize_task_group(dag.task_group)
1522 # Edge info in the JSON exactly matches our internal structure
1523 serialized_dag["edge_info"] = dag.edge_info
1524 serialized_dag["params"] = cls._serialize_params_dict(dag.params)
1526 # has_on_*_callback flags are only stored if the value is True, as the default is False
1527 if dag.has_on_success_callback:
1528 serialized_dag["has_on_success_callback"] = True
1529 if dag.has_on_failure_callback:
1530 serialized_dag["has_on_failure_callback"] = True
1531 return serialized_dag
1532 except SerializationError:
1533 raise
1534 except Exception as e:
1535 raise SerializationError(f"Failed to serialize DAG {dag.dag_id!r}: {e}")
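    # The dict built above carries the DAG's JSON-safe attributes plus the keys set explicitly
    # in this method: "tasks", "dag_dependencies", "_task_group", "edge_info", "params",
    # "_processor_dags_folder", exactly one of "schedule_interval"/"timetable", and the
    # has_on_*_callback flags only when they are True.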
1537 @classmethod
1538 def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG:
1539 """Deserialize a DAG from a JSON object."""
1540 dag = SerializedDAG(dag_id=encoded_dag["_dag_id"])
1542 for k, v in encoded_dag.items():
1543 if k == "_downstream_task_ids":
1544 v = set(v)
1545 elif k == "tasks":
1546 SerializedBaseOperator._load_operator_extra_links = cls._load_operator_extra_links
1547 tasks = {}
1548 for obj in v:
1549 if obj.get(Encoding.TYPE) == DAT.OP:
1550 deser = SerializedBaseOperator.deserialize_operator(obj[Encoding.VAR])
1551 tasks[deser.task_id] = deser
1552 else: # todo: remove in Airflow 3.0 (backcompat for pre-2.10)
1553 tasks[obj["task_id"]] = SerializedBaseOperator.deserialize_operator(obj)
1554 k = "task_dict"
1555 v = tasks
1556 elif k == "timezone":
1557 v = cls._deserialize_timezone(v)
1558 elif k == "dagrun_timeout":
1559 v = cls._deserialize_timedelta(v)
1560 elif k.endswith("_date"):
1561 v = cls._deserialize_datetime(v)
1562 elif k == "edge_info":
1563 # Value structure matches exactly
1564 pass
1565 elif k == "timetable":
1566 v = decode_timetable(v)
1567 elif k == "weight_rule":
1568 v = decode_priority_weight_strategy(v)
1569 elif k in cls._decorated_fields:
1570 v = cls.deserialize(v)
1571 elif k == "params":
1572 v = cls._deserialize_params_dict(v)
1573 elif k == "dataset_triggers":
1574 v = cls.deserialize(v)
1575 # else use v as it is
1577 setattr(dag, k, v)
1579 # A DAG is always serialized with only one of schedule_interval and
1580 # timetable. This back-populates the other to ensure the two attributes
1581 # line up correctly on the DAG instance.
1582 if "timetable" in encoded_dag:
1583 dag.schedule_interval = dag.timetable.summary
1584 else:
1585 dag.timetable = create_timetable(dag.schedule_interval, dag.timezone)
1587 # Set _task_group
1588 if "_task_group" in encoded_dag:
1589 dag._task_group = TaskGroupSerialization.deserialize_task_group(
1590 encoded_dag["_task_group"],
1591 None,
1592 dag.task_dict,
1593 dag,
1594 )
1595 else:
1596 # This must be old data that had no task_group. Create a root TaskGroup and add
1597 # all tasks to it.
1598 dag._task_group = TaskGroup.create_root(dag)
1599 for task in dag.tasks:
1600 dag.task_group.add(task)
1602 # Set has_on_*_callback to True if present in the serialized blob, since False is the default
1603 if "has_on_success_callback" in encoded_dag:
1604 dag.has_on_success_callback = True
1605 if "has_on_failure_callback" in encoded_dag:
1606 dag.has_on_failure_callback = True
1608 keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys()
1609 for k in keys_to_set_none:
1610 setattr(dag, k, None)
1612 for task in dag.task_dict.values():
1613 SerializedBaseOperator.set_task_dag_references(task, dag)
1615 return dag
1617 @classmethod
1618 def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
1619 # {} is explicitly different from None in the case of DAG-level access control
1620 # and as a result we need to preserve empty dicts through serialization for this field
1621 if attrname == "_access_control" and var is not None:
1622 return False
1623 return super()._is_excluded(var, attrname, op)
1625 @classmethod
1626 def to_dict(cls, var: Any) -> dict:
1627 """Stringify DAGs and operators contained in ``var`` and return a dict of ``var``."""
1628 json_dict = {"__version": cls.SERIALIZER_VERSION, "dag": cls.serialize_dag(var)}
1630 # Validate the serialized DAG against the JSON schema; raises an error on mismatch
1631 cls.validate_schema(json_dict)
1632 return json_dict
1634 @classmethod
1635 def from_dict(cls, serialized_obj: dict) -> SerializedDAG:
1636 """Deserialize a Python dict into the DAG and operators it contains."""
1637 ver = serialized_obj.get("__version", "<not present>")
1638 if ver != cls.SERIALIZER_VERSION:
1639 raise ValueError(f"Unsure how to deserialize version {ver!r}")
1640 return cls.deserialize_dag(serialized_obj["dag"])
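
# A minimal, hedged sketch (not part of the original module) of the to_dict/from_dict round
# trip above. The DAG id and task id are hypothetical, and a configured Airflow environment
# (settings, providers) is assumed.
def _example_serialized_dag_roundtrip() -> SerializedDAG:
    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG(dag_id="example_roundtrip", schedule=None) as dag:
        EmptyOperator(task_id="noop")

    # payload == {"__version": SerializedDAG.SERIALIZER_VERSION, "dag": {...}}, schema-validated
    payload = SerializedDAG.to_dict(dag)
    restored = SerializedDAG.from_dict(payload)
    assert restored.dag_id == dag.dag_id
    return restored
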
1643class TaskGroupSerialization(BaseSerialization):
1644 """JSON serializable representation of a task group."""
1646 @classmethod
1647 def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
1648 """Serialize TaskGroup into a JSON object."""
1649 if not task_group:
1650 return None
1652 # task_group.xxx_ids need to be sorted here because they are sets, and converting a set to a
1653 # list yields a non-deterministic order. Without sorting, json.dumps(self.data, sort_keys=True)
1654 # would produce inconsistent dag_hash values for otherwise identical DAGs.
1655 encoded = {
1656 "_group_id": task_group._group_id,
1657 "prefix_group_id": task_group.prefix_group_id,
1658 "tooltip": task_group.tooltip,
1659 "ui_color": task_group.ui_color,
1660 "ui_fgcolor": task_group.ui_fgcolor,
1661 "children": {
1662 label: child.serialize_for_task_group() for label, child in task_group.children.items()
1663 },
1664 "upstream_group_ids": cls.serialize(sorted(task_group.upstream_group_ids)),
1665 "downstream_group_ids": cls.serialize(sorted(task_group.downstream_group_ids)),
1666 "upstream_task_ids": cls.serialize(sorted(task_group.upstream_task_ids)),
1667 "downstream_task_ids": cls.serialize(sorted(task_group.downstream_task_ids)),
1668 }
1670 if isinstance(task_group, MappedTaskGroup):
1671 expand_input = task_group._expand_input
1672 encoded["expand_input"] = {
1673 "type": get_map_type_key(expand_input),
1674 "value": cls.serialize(expand_input.value),
1675 }
1676 encoded["is_mapped"] = True
1678 return encoded
1680 @classmethod
1681 def deserialize_task_group(
1682 cls,
1683 encoded_group: dict[str, Any],
1684 parent_group: TaskGroup | None,
1685 task_dict: dict[str, Operator],
1686 dag: SerializedDAG,
1687 ) -> TaskGroup:
1688 """Deserialize a TaskGroup from a JSON object."""
1689 group_id = cls.deserialize(encoded_group["_group_id"])
1690 kwargs = {
1691 key: cls.deserialize(encoded_group[key])
1692 for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
1693 }
1695 if not encoded_group.get("is_mapped"):
1696 group = TaskGroup(group_id=group_id, parent_group=parent_group, dag=dag, **kwargs)
1697 else:
1698 xi = encoded_group["expand_input"]
1699 group = MappedTaskGroup(
1700 group_id=group_id,
1701 parent_group=parent_group,
1702 dag=dag,
1703 expand_input=_ExpandInputRef(xi["type"], cls.deserialize(xi["value"])).deref(dag),
1704 **kwargs,
1705 )
1707 def set_ref(task: Operator) -> Operator:
1708 task.task_group = weakref.proxy(group)
1709 return task
1711 group.children = {
1712 label: (
1713 set_ref(task_dict[val])
1714 if _type == DAT.OP
1715 else cls.deserialize_task_group(val, group, task_dict, dag=dag)
1716 )
1717 for label, (_type, val) in encoded_group["children"].items()
1718 }
1719 group.upstream_group_ids.update(cls.deserialize(encoded_group["upstream_group_ids"]))
1720 group.downstream_group_ids.update(cls.deserialize(encoded_group["downstream_group_ids"]))
1721 group.upstream_task_ids.update(cls.deserialize(encoded_group["upstream_task_ids"]))
1722 group.downstream_task_ids.update(cls.deserialize(encoded_group["downstream_task_ids"]))
1723 return group
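
# A small, hedged sketch (not part of the original module): dumping the root task group of an
# already-built DAG. The returned dict mirrors the structure assembled in serialize_task_group
# above ("_group_id", "children", the sorted *_ids lists, and the mapped-group extras).
def _example_task_group_dump(dag: DAG) -> dict[str, Any] | None:
    return TaskGroupSerialization.serialize_task_group(dag.task_group)
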
1726@dataclass(frozen=True, order=True)
1727class DagDependency:
1728 """
1729 Dataclass for representing dependencies between DAGs.
1731 These are calculated during serialization and attached to serialized DAGs.
1732 """
1734 source: str
1735 target: str
1736 dependency_type: str
1737 dependency_id: str | None = None
1739 @property
1740 def node_id(self):
1741 """Node ID for graph rendering."""
1742 val = f"{self.dependency_type}"
1743 if self.dependency_type != "dataset":
1744 val += f":{self.source}:{self.target}"
1745 if self.dependency_id:
1746 val += f":{self.dependency_id}"
1747 return val
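
# A hedged sketch (not part of the original module) of how node_id composes for the two
# dependency flavours handled above; the dag ids, task id and dataset URI are hypothetical.
def _example_dag_dependency_node_ids() -> tuple[str, str]:
    trigger_dep = DagDependency(
        source="upstream_dag", target="downstream_dag", dependency_type="trigger", dependency_id="trigger_task"
    )
    dataset_dep = DagDependency(
        source="upstream_dag", target="downstream_dag", dependency_type="dataset", dependency_id="s3://bucket/key"
    )
    # -> ("trigger:upstream_dag:downstream_dag:trigger_task", "dataset:s3://bucket/key")
    return trigger_dep.node_id, dataset_dep.node_id
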
1750def _has_kubernetes() -> bool:
1751 global HAS_KUBERNETES
1752 if "HAS_KUBERNETES" in globals():
1753 return HAS_KUBERNETES
1755 # Loading kube modules is expensive, so delay it until the last moment
1757 try:
1758 from kubernetes.client import models as k8s
1760 try:
1761 from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
1762 except ImportError:
1763 from airflow.kubernetes.pre_7_4_0_compatibility.pod_generator import ( # type: ignore[assignment]
1764 PodGenerator,
1765 )
1767 globals()["k8s"] = k8s
1768 globals()["PodGenerator"] = PodGenerator
1770 # isort: on
1771 HAS_KUBERNETES = True
1772 except ImportError:
1773 HAS_KUBERNETES = False
1774 return HAS_KUBERNETES
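
# A hedged sketch (not part of the original module): _has_kubernetes caches its result in the
# module-level HAS_KUBERNETES global, so call sites can cheaply guard kubernetes-specific
# handling, mirroring how this module guards V1Pod serialization elsewhere.
def _example_is_serializable_pod(var: Any) -> bool:
    # The short-circuit keeps ``k8s`` from being referenced unless the import above succeeded.
    return _has_kubernetes() and isinstance(var, k8s.V1Pod)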