1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import sys
21from collections import deque
22from types import ModuleType
23from typing import TYPE_CHECKING, Any, Generic, TypeVar
24
25from airflow.sdk.definitions.dag import DAG
26from airflow.sdk.definitions.taskgroup import TaskGroup
27
28if TYPE_CHECKING:
29 from airflow.sdk.definitions.context import Context
30
# Type parameter for the kind of object held on a ContextStack (see ``ContextStack(Generic[T])``).
T = TypeVar("T")

# Only the two concrete context stacks are part of this module's public API.
__all__ = ["DagContext", "TaskGroupContext"]
34
# This is a global variable that stores the current Task context.
# It is used to push the Context dictionary when a task starts execution,
# and it is used to retrieve the current context in PythonOperator or the Taskflow API via
# the `get_current_context` function.
# It is used as a stack: the last element is the current context (see `_get_current_context`).
_CURRENT_CONTEXT: list[Context] = []
40
41
def _get_current_context() -> Context:
    """
    Return the innermost task Context, i.e. the last entry of ``_CURRENT_CONTEXT``.

    :raises RuntimeError: if no context has been pushed onto ``_CURRENT_CONTEXT``.
    """
    if _CURRENT_CONTEXT:
        # Most recently pushed context wins.
        return _CURRENT_CONTEXT[-1]
    raise RuntimeError(
        "Current context was requested but no context was found! Are you running within an Airflow task?"
    )
48
49
# A ``@classproperty``-like attribute has to be defined as a property on the metaclass.
class ContextStackMeta(type):
    """
    Metaclass for context-stack classes.

    Every class it creates gets its own, empty ``_context`` deque, unless the class
    statement passes ``share_parent_context=True``. In that case no ``_context`` is set
    on the new class, so attribute lookup resolves it from the base classes. It also
    exposes ``active`` as a property readable on the class itself.
    """

    _context: deque

    # TODO: Task-SDK:
    # share_parent_context can go away once the Dag and TaskContext manager in airflow.models are removed and
    # everything uses sdk fully for definition/parsing
    def __new__(cls, name, bases, namespace, share_parent_context: bool = False, **kwargs: Any):
        # ``share_parent_context`` is consumed here; any other class keywords are
        # forwarded to ``type.__new__`` (and from there to ``__init_subclass__``).
        if not share_parent_context:
            namespace["_context"] = deque()
        return super().__new__(cls, name, bases, namespace, **kwargs)

    @property
    def active(self) -> bool:
        """Whether the class's context stack currently holds at least one object."""
        return len(self._context) > 0
69
70
class ContextStack(Generic[T], metaclass=ContextStackMeta):
    """
    Class-level LIFO stack of context objects.

    The most recently pushed object sits at the left end (index 0) of ``_context``.
    :class:`ContextStackMeta` gives each subclass its own deque unless the subclass
    is declared with ``share_parent_context=True``.
    """

    _context: deque[T]

    @classmethod
    def push(cls, obj: T):
        """Make *obj* the current (innermost) object."""
        stack = cls._context
        stack.appendleft(obj)

    @classmethod
    def pop(cls) -> T | None:
        """
        Remove and return the current object.

        :raises IndexError: if the stack is empty.
        """
        stack = cls._context
        return stack.popleft()

    @classmethod
    def get_current(cls) -> T | None:
        """Return the current object without removing it, or ``None`` if the stack is empty."""
        stack = cls._context
        if not stack:
            return None
        return stack[0]
88
89
class DagContext(ContextStack[DAG]):
    """
    Dag context is used to keep the current Dag when Dag is used as ContextManager.

    You can use Dag as context:

    .. code-block:: python

        with DAG(
            dag_id="example_dag",
            default_args=default_args,
            schedule="0 0 * * *",
            dagrun_timeout=timedelta(minutes=60),
        ) as dag:
            ...

    If you do this the context stores the Dag and whenever a new task is created, it will use
    such stored Dag as the parent Dag.
    """

    # (Dag, module) pairs recorded by ``pop`` while ``current_autoregister_module_name`` is set.
    autoregistered_dags: set[tuple[DAG, ModuleType]] = set()
    # Name of the module (looked up in ``sys.modules``) that popped Dags are attributed to;
    # ``None`` disables auto-registration in ``pop``.
    current_autoregister_module_name: str | None = None

    @classmethod
    def pop(cls) -> DAG | None:
        """
        Pop the current Dag and record it for auto-registration when enabled.

        The popped Dag is added to ``autoregistered_dags`` paired with the module named by
        ``current_autoregister_module_name`` when that name is set, the popped value is not
        ``None``, and the Dag's ``auto_register`` attribute is true (or absent).

        :return: The popped Dag, or ``None`` if ``None`` had been pushed.
        :raises IndexError: propagated from ``ContextStack.pop`` when the stack is empty.
        :raises KeyError: if ``current_autoregister_module_name`` is not in ``sys.modules``.
        """
        dag = super().pop()
        module_name = cls.current_autoregister_module_name
        # In a few cases around serialization we explicitly push None in to the stack;
        # compare against None rather than relying on the DAG object's truthiness.
        if module_name is not None and dag is not None and getattr(dag, "auto_register", True):
            cls.autoregistered_dags.add((dag, sys.modules[module_name]))
        return dag

    @classmethod
    def get_current_dag(cls) -> DAG | None:
        """Return the current Dag (alias of :meth:`get_current`)."""
        return cls.get_current()
126
127
class TaskGroupContext(ContextStack[TaskGroup]):
    """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""

    @classmethod
    def get_current(cls, dag: DAG | None = None) -> TaskGroup | None:
        """
        Return the innermost TaskGroup currently in scope.

        If no TaskGroup is on the stack, fall back to the root TaskGroup (``task_group``)
        of *dag*, or of the current Dag from :class:`DagContext` when *dag* is not given.

        :param dag: Dag whose root TaskGroup to fall back to; defaults to the current Dag.
        :return: The current TaskGroup, the fallback root TaskGroup, or ``None`` when
            neither a TaskGroup nor a Dag is available.
        """
        current = super().get_current()
        # Explicit None checks: the truthiness of a TaskGroup/DAG object must not decide
        # whether it is used (an object defining __bool__/__len__ could be falsy).
        if current is not None:
            return current
        if dag is None:
            dag = DagContext.get_current()
        if dag is not None:
            # If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag.
            return dag.task_group
        return None