Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/_internal/contextmanager.py: 63%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

62 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import sys 

21from collections import deque 

22from types import ModuleType 

23from typing import TYPE_CHECKING, Any, Generic, TypeVar 

24 

25from airflow.sdk.definitions.dag import DAG 

26from airflow.sdk.definitions.taskgroup import TaskGroup 

27 

28if TYPE_CHECKING: 

29 from airflow.sdk.definitions.context import Context 

30 

# Public API of this module.
__all__ = ["DagContext", "TaskGroupContext"]

# Type variable for the kind of object a ContextStack subclass holds.
T = TypeVar("T")

# Global stack holding the current Task context. A Context dictionary is pushed onto it when a Task
# starts executing; the innermost one is what PythonOperator / the Taskflow API retrieve through the
# `get_current_context` function.
_CURRENT_CONTEXT: list[Context] = []

40 

41 

def _get_current_context() -> Context:
    """
    Return the most recently pushed task Context.

    :raises RuntimeError: if the context stack is empty, e.g. when called outside a running task.
    """
    if _CURRENT_CONTEXT:
        return _CURRENT_CONTEXT[-1]
    raise RuntimeError(
        "Current context was requested but no context was found! Are you running within an Airflow task?"
    )

48 

49 

# A `@classproperty`-like attribute can only be built as a property defined on a metaclass.
class ContextStackMeta(type):
    """
    Metaclass that gives the classes it creates a ``_context`` deque to use as a stack.

    Each new class receives its own empty deque, unless it is declared with
    ``share_parent_context=True``. In that case no deque is installed and the class keeps using the
    one it inherits from its base.
    """

    _context: deque

    # TODO: Task-SDK:
    # share_parent_context can go away once the Dag and TaskContext manager in airflow.models are removed and
    # everything uses sdk fully for definition/parsing
    def __new__(cls, name, bases, namespace, share_parent_context: bool = False, **kwargs: Any):
        # `share_parent_context` is a named parameter, so it is consumed here and never forwarded
        # to type.__new__ along with **kwargs.
        wants_own_stack = not share_parent_context
        if wants_own_stack:
            namespace["_context"] = deque()
        return super().__new__(cls, name, bases, namespace, **kwargs)

    @property
    def active(self) -> bool:
        """Whether any object is currently on this class's stack."""
        return len(self._context) > 0

69 

70 

class ContextStack(Generic[T], metaclass=ContextStackMeta):
    """
    Class-level stack of objects of type ``T``.

    The most recently pushed object sits at the left end of the ``_context`` deque (supplied by
    ``ContextStackMeta``), so index 0 always holds the current object.
    """

    _context: deque[T]

    @classmethod
    def push(cls, obj: T):
        """Make ``obj`` the current object by placing it on top of the stack."""
        stack = cls._context
        stack.appendleft(obj)

    @classmethod
    def pop(cls) -> T | None:
        """
        Remove and return the current object.

        :raises IndexError: if the stack is empty.
        """
        stack = cls._context
        return stack.popleft()

    @classmethod
    def get_current(cls) -> T | None:
        """Return the current object without removing it, or ``None`` when the stack is empty."""
        stack = cls._context
        return stack[0] if stack else None

88 

89 

class DagContext(ContextStack[DAG]):
    """
    Dag context is used to keep the current Dag when Dag is used as ContextManager.

    You can use Dag as context:

    .. code-block:: python

        with DAG(
            dag_id="example_dag",
            default_args=default_args,
            schedule="0 0 * * *",
            dagrun_timeout=timedelta(minutes=60),
        ) as dag:
            ...

    If you do this the context stores the Dag and whenever new task is created, it will use
    such stored Dag as the parent Dag.

    While ``current_autoregister_module_name`` is set, every Dag popped off the stack is recorded in
    ``autoregistered_dags`` together with that module, unless the Dag's ``auto_register`` attribute
    is falsy.
    """

    # (dag, module) pairs collected by ``pop`` while auto-registration is enabled. This is a
    # class-level set, so it is shared by every use of DagContext.
    autoregistered_dags: set[tuple[DAG, ModuleType]] = set()
    # Name of the module (looked up in ``sys.modules``) that popped Dags are associated with;
    # ``None`` disables auto-registration.
    current_autoregister_module_name: str | None = None

    @classmethod
    def pop(cls) -> DAG | None:
        """
        Remove and return the current Dag, recording it for auto-registration when enabled.

        :raises IndexError: if the stack is empty.
        """
        dag = super().pop()
        # In a few cases around serialization we explicitly push None in to the stack, so skip
        # those entries with an explicit None check instead of relying on the Dag's truthiness.
        if (
            cls.current_autoregister_module_name is not None
            and dag is not None
            and getattr(dag, "auto_register", True)
        ):
            mod = sys.modules[cls.current_autoregister_module_name]
            cls.autoregistered_dags.add((dag, mod))
        return dag

    @classmethod
    def get_current_dag(cls) -> DAG | None:
        """Return the current Dag, or ``None`` if the stack is empty."""
        return cls.get_current()

126 

127 

class TaskGroupContext(ContextStack[TaskGroup]):
    """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""

    @classmethod
    def get_current(cls, dag: DAG | None = None) -> TaskGroup | None:
        """
        Return the active TaskGroup, falling back to a Dag's root TaskGroup.

        :param dag: Dag whose root ``task_group`` is used when no TaskGroup is on the stack; if not
            given, the current ``DagContext`` Dag is used instead.
        :return: the innermost TaskGroup on the stack; otherwise the root ``task_group`` of ``dag``
            (or of the current Dag); otherwise ``None``.
        """
        # Test for presence explicitly: both lookups return ``X | None``, and an object's truthiness
        # is not a reliable stand-in for "not None".
        current = super().get_current()
        if current is not None:
            return current
        if dag is None:
            dag = DagContext.get_current()
        if dag is not None:
            # If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag.
            return dag.task_group
        return None