Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/triggers/base.py: 68%


# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import abc
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from datetime import timedelta
from typing import Annotated, Any

import structlog
from pydantic import (
    BaseModel,
    Discriminator,
    JsonValue,
    Tag,
    model_serializer,
)

from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import TaskInstanceState

log = structlog.get_logger(logger_name=__name__)


@dataclass
class StartTriggerArgs:
    """Arguments required for start task execution from triggerer."""

    trigger_cls: str
    next_method: str
    trigger_kwargs: dict[str, Any] | None = None
    next_kwargs: dict[str, Any] | None = None
    timeout: timedelta | None = None
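

# Illustrative sketch (not part of the original module): how start-from-triggerer
# arguments might be filled in. The classpath, method name, and kwargs below are
# hypothetical.
_example_start_trigger_args = StartTriggerArgs(
    trigger_cls="example.triggers.ExampleDelayTrigger",
    next_method="execute_complete",
    trigger_kwargs={"delay": 300.0},
    next_kwargs=None,
    timeout=timedelta(hours=1),
)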


class BaseTrigger(abc.ABC, LoggingMixin):
    """
    Base class for all triggers.

    A trigger has two contexts it can exist in:

    - Inside an Operator, when it's passed to TaskDeferred
    - Actively running in a trigger worker

    We use the same class for both situations, and rely on all Trigger classes
    to be able to return the arguments (encodable with Airflow-JSON) that will
    let them be re-instantiated elsewhere.
    """

    def __init__(self, **kwargs):
        # These values are set by the triggerer when preparing to run the instance;
        # when the trigger runs, they are injected into the logger record.
        self.task_instance = None
        self.trigger_id = None

    def _set_context(self, context):
        """Part of LoggingMixin and used mainly for configuration of task logging; not used for triggers."""
        raise NotImplementedError

    @abc.abstractmethod
    def serialize(self) -> tuple[str, dict[str, Any]]:
        """
        Return the information needed to reconstruct this Trigger.

        :return: Tuple of (class path, keyword arguments needed to re-instantiate).
        """
        raise NotImplementedError("Triggers must implement serialize()")

    @abc.abstractmethod
    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Run the trigger in an asynchronous context.

        The trigger should yield an Event whenever it wants to fire off
        an event, and return None if it is finished. Single-event triggers
        should thus yield and then immediately return.

        If it yields, it is likely that it will be resumed very quickly,
        but it may not be (e.g. if the workload is being moved to another
        triggerer process, or a multi-event trigger was being used for a
        single-event task defer).

        In either case, Trigger classes should assume they will be persisted,
        and then rely on cleanup() being called when they are no longer needed.
        """
        raise NotImplementedError("Triggers must implement run()")
        yield  # To convince Mypy this is an async iterator.

    async def cleanup(self) -> None:
        """
        Clean up the trigger.

        Called when the trigger is no longer needed and is being removed
        from the active triggerer process.

        This method follows the async/await pattern so that the cleanup can run
        in the triggerer's main event loop. Exceptions raised by the cleanup
        method are ignored, so if you would like to debug them or be notified
        that the cleanup method failed, wrap your code in a try/except block
        and handle it appropriately (in an async-compatible way).
        """

    @staticmethod
    def repr(classpath: str, kwargs: dict[str, Any]):
        kwargs_str = ", ".join(f"{k}={v}" for k, v in kwargs.items())
        return f"<{classpath} {kwargs_str}>"

    def __repr__(self) -> str:
        classpath, kwargs = self.serialize()
        return self.repr(classpath, kwargs)
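

# Illustrative sketch (not part of the original module): a minimal concrete
# single-event trigger showing the serialize()/run() contract documented above.
# The classpath string and the ``delay`` kwarg are hypothetical.
class _ExampleDelayTrigger(BaseTrigger):
    """Fire a single event after ``delay`` seconds (sketch only)."""

    def __init__(self, delay: float, **kwargs):
        super().__init__(**kwargs)
        self.delay = delay

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # Everything returned here must survive an Airflow-JSON round trip,
        # since the triggerer re-instantiates the class from these kwargs.
        return ("example.triggers.ExampleDelayTrigger", {"delay": self.delay})

    async def run(self) -> AsyncIterator[TriggerEvent]:
        import asyncio

        # Sleep cooperatively so many triggers can share one event loop.
        await asyncio.sleep(self.delay)
        # Single-event trigger: yield once, then return.
        yield TriggerEvent({"delay": self.delay})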


class BaseEventTrigger(BaseTrigger):
    """
    Base class for triggers used to schedule DAGs based on external events.

    ``BaseEventTrigger`` is a subclass of ``BaseTrigger`` designed to identify triggers compatible with
    event-driven scheduling.
    """

    @staticmethod
    def hash(classpath: str, kwargs: dict[str, Any]) -> int:
        """
        Return the hash of the trigger classpath and kwargs. This is used to uniquely identify a trigger.

        We do not want to have this logic in ``BaseTrigger`` because, when used to defer tasks, two triggers
        can have the same classpath and kwargs. This is not true for event-driven scheduling.
        """
        from airflow.serialization.serialized_objects import BaseSerialization

        return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
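

# Illustrative sketch (hypothetical classpath): two event triggers built from
# the same classpath and kwargs share one hash, which is what lets event-driven
# scheduling treat them as the same event source.
def _example_event_trigger_identity() -> bool:
    first = BaseEventTrigger.hash("example.triggers.QueueTrigger", {"queue": "inbound"})
    second = BaseEventTrigger.hash("example.triggers.QueueTrigger", {"queue": "inbound"})
    return first == second  # True within a single process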


class TriggerEvent(BaseModel):
    """
    Something that a trigger can fire when its conditions are met.

    Events must have a uniquely identifying value that would be the same
    wherever the trigger is run; this is to ensure that if the same trigger
    is being run in two locations (for HA reasons) we can deduplicate its
    events.
    """

    payload: Any = None
    """
    The payload for the event to send back to the task.

    Must be natively JSON-serializable, or registered with the Airflow serialization code.
    """

    def __init__(self, payload, **kwargs):
        super().__init__(payload=payload, **kwargs)

    def __repr__(self) -> str:
        return f"TriggerEvent<{self.payload!r}>"
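

# Illustrative sketch: TriggerEvent is a pydantic model, so two events carrying
# the same payload compare equal. This is the identity property the docstring
# above describes for deduplicating events from HA triggerer copies (assumption:
# deduplication keys off the payload).
def _example_event_equality() -> bool:
    return TriggerEvent({"job_id": 123}) == TriggerEvent({"job_id": 123})  # True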


class BaseTaskEndEvent(TriggerEvent):
    """
    Base event class to end the task without resuming on the worker.

    :meta private:
    """

    task_instance_state: TaskInstanceState
    xcoms: dict[str, JsonValue] | None = None

    def __init__(self, *, xcoms: dict[str, JsonValue] | None = None, **kwargs) -> None:
        """
        Initialize the class with the specified parameters.

        :param xcoms: A dictionary of XComs or None.
        :param kwargs: Additional keyword arguments.
        """
        if "payload" in kwargs:
            raise ValueError("Param 'payload' not supported for this class.")
        # Yes, this is _odd_. It's to support both construction by users via
        # `TaskSuccessEvent(some_xcom_value)` and deserialization by pydantic.
        state = kwargs.pop("task_instance_state", self.__pydantic_fields__["task_instance_state"].default)
        super().__init__(payload=str(state), task_instance_state=state, **kwargs)
        self.xcoms = xcoms

    @model_serializer
    def ser_model(self) -> dict[str, Any]:
        # We need to customize the serialized schema so it works with the custom
        # constructor we have, keeping the interface to this class "nice".
        return {"task_instance_state": self.task_instance_state, "xcoms": self.xcoms}


class TaskSuccessEvent(BaseTaskEndEvent):
    """Yield this event in order to end the task successfully."""

    task_instance_state: TaskInstanceState = TaskInstanceState.SUCCESS


class TaskFailedEvent(BaseTaskEndEvent):
    """Yield this event in order to end the task with failure."""

    task_instance_state: TaskInstanceState = TaskInstanceState.FAILED


class TaskSkippedEvent(BaseTaskEndEvent):
    """Yield this event in order to end the task with status 'skipped'."""

    task_instance_state: TaskInstanceState = TaskInstanceState.SKIPPED
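

# Illustrative sketch (hypothetical class): a trigger can end its task directly
# from the triggerer by yielding a task-end event, so the worker never resumes
# the deferred operator; ``xcoms`` still lets it push values.
class _ExampleEndTaskTrigger(BaseTrigger):
    def serialize(self) -> tuple[str, dict[str, Any]]:
        return ("example.triggers.ExampleEndTaskTrigger", {})

    async def run(self) -> AsyncIterator[TriggerEvent]:
        yield TaskSuccessEvent(xcoms={"return_value": "done"})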


def trigger_event_discriminator(v):
    if isinstance(v, dict):
        return v.get("task_instance_state", "_event_")
    if isinstance(v, TriggerEvent):
        return getattr(v, "task_instance_state", "_event_")


DiscrimatedTriggerEvent = Annotated[
    Annotated[TriggerEvent, Tag("_event_")]
    | Annotated[TaskSuccessEvent, Tag(TaskInstanceState.SUCCESS)]
    | Annotated[TaskFailedEvent, Tag(TaskInstanceState.FAILED)]
    | Annotated[TaskSkippedEvent, Tag(TaskInstanceState.SKIPPED)],
    Discriminator(trigger_event_discriminator),
]
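

# Illustrative sketch: deserializing with the discriminated union. The
# discriminator reads ``task_instance_state`` from a dict (or the attribute
# from an instance) and falls back to the plain "_event_" tag, letting pydantic
# pick the right event subclass. The helper below is hypothetical.
def _example_decode_event(data: dict) -> TriggerEvent:
    from pydantic import TypeAdapter

    return TypeAdapter(DiscrimatedTriggerEvent).validate_python(data)


# _example_decode_event({"task_instance_state": "success"})  -> TaskSuccessEvent
# _example_decode_event({"payload": {"job_id": 123}})        -> TriggerEvent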