# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any, Iterable, Optional

from typing_extensions import Annotated

from airflow.models import Operator
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.pydantic.dag import DagModelPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.pydantic import (
    BaseModel as BaseModelPydantic,
    ConfigDict,
    PlainSerializer,
    PlainValidator,
    is_pydantic_2_installed,
)
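# Note: airflow.utils.pydantic appears to provide compatibility shims so this
# module can be imported even when Pydantic 2 is not installed; the
# model_rebuild() call at the bottom of this file only runs when it is.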
from airflow.utils.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
    import pendulum
    from sqlalchemy.orm import Session

    from airflow.models.dagrun import DagRun
    from airflow.utils.context import Context
    from airflow.utils.pydantic import ValidationInfo
    from airflow.utils.state import DagRunState


def serialize_operator(x: Operator | None) -> dict | None:
    if x:
        # Imported locally to avoid a circular import at module load time.
        from airflow.serialization.serialized_objects import SerializedBaseOperator

        return SerializedBaseOperator.serialize_operator(x)
    return None


def validated_operator(x: dict[str, Any] | Operator | None, _info: ValidationInfo) -> Any:
    from airflow.models.mappedoperator import MappedOperator
    from airflow.serialization.serialized_objects import SerializedBaseOperator

    # Live operators (and None) pass through untouched; serialized dicts are
    # rebuilt into operator instances.
    if x is None or isinstance(x, (BaseOperator, MappedOperator)):
        return x
    return SerializedBaseOperator.deserialize_operator(x)


PydanticOperator = Annotated[
    Operator,
    PlainValidator(validated_operator),
    PlainSerializer(serialize_operator, return_type=dict),
]
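# How the annotation above behaves (a sketch, per Pydantic 2 semantics):
# PlainValidator routes every inbound value for a field of this type through
# validated_operator (live operators pass through, serialized dicts are
# rebuilt), and PlainSerializer dumps the field back to a dict via
# serialize_operator on model_dump()/model_dump_json().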


class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
    """Serializable representation of the TaskInstance ORM SQLAlchemy model, used by the internal API."""

    task_id: str
    dag_id: str
    run_id: str
    map_index: int
    start_date: Optional[datetime]
    end_date: Optional[datetime]
    execution_date: Optional[datetime]
    duration: Optional[float]
    state: Optional[str]
    try_number: int
    max_tries: int
    hostname: str
    unixname: str
    job_id: Optional[int]
    pool: str
    pool_slots: int
    queue: str
    priority_weight: Optional[int]
    operator: str
    custom_operator_name: Optional[str]
    queued_dttm: Optional[datetime]
    queued_by_job_id: Optional[int]
    pid: Optional[int]
    executor: Optional[str]
    executor_config: Any
    updated_at: Optional[datetime]
    rendered_map_index: Optional[str]
    external_executor_id: Optional[str]
    trigger_id: Optional[int]
    trigger_timeout: Optional[datetime]
    next_method: Optional[str]
    next_kwargs: Optional[dict]
    run_as_user: Optional[str]
    task: Optional[PydanticOperator]
    test_mode: bool
    dag_run: Optional[DagRunPydantic]
    dag_model: Optional[DagModelPydantic]
    raw: Optional[bool]
    is_trigger_log_context: Optional[bool]
    model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)
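    # from_attributes=True lets model_validate() read fields directly from the
    # ORM TaskInstance; arbitrary_types_allowed=True is needed for fields whose
    # types Pydantic cannot introspect (e.g. executor_config payloads).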

    @property
    def _logger_name(self):
        return "airflow.task"

    def clear_xcom_data(self, session: Session | None = None):
        TaskInstance._clear_xcom_data(ti=self, session=session)

    def set_state(self, state, session: Session | None = None) -> bool:
        return TaskInstance._set_state(ti=self, state=state, session=session)

    def _run_execute_callback(self, context, task):
        TaskInstance._run_execute_callback(self=self, context=context, task=task)  # type: ignore[arg-type]

    def render_templates(self, context: Context | None = None, jinja_env=None):
        return TaskInstance.render_templates(self=self, context=context, jinja_env=jinja_env)  # type: ignore[arg-type]

    def init_run_context(self, raw: bool = False) -> None:
        """Set the log context."""
        self.raw = raw
        self._set_context(self)

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = XCOM_RETURN_KEY,
        include_prior_dates: bool = False,
        *,
        map_indexes: int | Iterable[int] | None = None,
        default: Any = None,
    ) -> Any:
        """
        Pull an XCom value for this task instance.

        TODO: make it work for AIP-44.

        :param task_ids: task id or list of task ids; if None, the task_id of the current task is used
        :param dag_id: dag id; if None, the dag_id of the current task is used
        :param key: the key to identify the XCom value
        :param include_prior_dates: whether to include prior execution dates
        :param map_indexes: map index or list of map indexes; if None, the map_index of the current task
            is used
        :param default: the default value to return if the XCom value does not exist
        :return: XCom value
        """
        return None

    def xcom_push(
        self,
        key: str,
        value: Any,
        execution_date: datetime | None = None,
        session: Session | None = None,
    ) -> None:
        """
        Push an XCom value for this task instance.

        TODO: make it work for AIP-44.

        :param key: the key to identify the XCom value
        :param value: the value of the XCom
        :param execution_date: the execution date to push the XCom for
        """

    def get_dagrun(self, session: Session | None = None) -> DagRunPydantic:
        """
        Return the DagRun for this TaskInstance.

        :param session: SQLAlchemy ORM Session

        :return: Pydantic serialized version of DagRun
        """
        return TaskInstance._get_dagrun(dag_id=self.dag_id, run_id=self.run_id, session=session)

    def _execute_task(self, context, task_orig):
        """
        Execute the task (optionally with a timeout) and push XCom results.

        :param context: Jinja2 context
        :param task_orig: origin task
        """
        from airflow.models.taskinstance import _execute_task

        return _execute_task(task_instance=self, context=context, task_orig=task_orig)

    def refresh_from_db(self, session: Session | None = None, lock_for_update: bool = False) -> None:
        """
        Refresh the task instance from the database based on the primary key.

        :param session: SQLAlchemy ORM Session
        :param lock_for_update: if True, indicates that the database should
            lock the TaskInstance (issuing a FOR UPDATE clause) until the
            session is committed.
        """
        from airflow.models.taskinstance import _refresh_from_db

        _refresh_from_db(task_instance=self, session=session, lock_for_update=lock_for_update)

    def set_duration(self) -> None:
        """Set the task instance duration."""
        from airflow.models.taskinstance import _set_duration

        _set_duration(task_instance=self)

    @property
    def stats_tags(self) -> dict[str, str]:
        """Return the task instance tags."""
        from airflow.models.taskinstance import _stats_tags

        return _stats_tags(task_instance=self)

    def clear_next_method_args(self) -> None:
        """Unset next_method and next_kwargs so that any retries do not reuse them."""
        from airflow.models.taskinstance import _clear_next_method_args

        _clear_next_method_args(task_instance=self)

    def get_template_context(
        self,
        session: Session | None = None,
        ignore_param_exceptions: bool = True,
    ) -> Context:
        """
        Return the TI Context.

        :param session: SQLAlchemy ORM Session
        :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
        """
        from airflow.models.taskinstance import _get_template_context

        return _get_template_context(
            task_instance=self,
            session=session,
            ignore_param_exceptions=ignore_param_exceptions,
        )

    def is_eligible_to_retry(self):
        """Return whether the task instance is eligible for retry."""
        from airflow.models.taskinstance import _is_eligible_to_retry

        return _is_eligible_to_retry(task_instance=self)

    def handle_failure(
        self,
        error: None | str | Exception | KeyboardInterrupt,
        test_mode: bool | None = None,
        context: Context | None = None,
        force_fail: bool = False,
        session: Session | None = None,
    ) -> None:
        """
        Handle failure for a task instance.

        :param error: if specified, log the specific exception if thrown
        :param session: SQLAlchemy ORM Session
        :param test_mode: doesn't record success or failure in the DB if True
        :param context: Jinja2 context
        :param force_fail: if True, the task does not retry
        """
        from airflow.models.taskinstance import _handle_failure

        if TYPE_CHECKING:
            assert self.task
            assert self.task.dag
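        # The deserialized task may not have a usable DAG reference (or the DAG
        # may predate the fail_stop attribute), so fall back to False on error.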
        try:
            fail_stop = self.task.dag.fail_stop
        except Exception:
            fail_stop = False
        _handle_failure(
            task_instance=self,
            error=error,
            session=session,
            test_mode=test_mode,
            context=context,
            force_fail=force_fail,
            fail_stop=fail_stop,
        )

    def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
        """
        Copy common attributes from the given task.

        :param task: the task object to copy from
        :param pool_override: use the pool_override instead of the task's pool
        """
        from airflow.models.taskinstance import _refresh_from_task

        _refresh_from_task(task_instance=self, task=task, pool_override=pool_override)

    def get_previous_dagrun(
        self,
        state: DagRunState | None = None,
        session: Session | None = None,
    ) -> DagRun | None:
        """
        Return the DagRun that ran before this task instance's DagRun.

        :param state: if passed, only take into account instances of a specific state
        :param session: SQLAlchemy ORM Session
        """
        from airflow.models.taskinstance import _get_previous_dagrun

        return _get_previous_dagrun(task_instance=self, state=state, session=session)

    def get_previous_execution_date(
        self,
        state: DagRunState | None = None,
        session: Session | None = None,
    ) -> pendulum.DateTime | None:
        """
        Return the execution date from the previous_ti_success property.

        :param state: if passed, only take into account instances of a specific state
        :param session: SQLAlchemy ORM Session
        """
        from airflow.models.taskinstance import _get_previous_execution_date

        return _get_previous_execution_date(task_instance=self, state=state, session=session)

    def email_alert(self, exception, task: BaseOperator) -> None:
        """
        Send an alert email with exception information.

        :param exception: the exception
        :param task: the task related to the exception
        """
        from airflow.models.taskinstance import _email_alert

        _email_alert(task_instance=self, exception=exception, task=task)

    def get_email_subject_content(
        self, exception: BaseException, task: BaseOperator | None = None
    ) -> tuple[str, str, str]:
        """
        Get the email subject content for exceptions.

        :param exception: the exception sent in the email
        :param task: the task related to the exception, if any
        """
        from airflow.models.taskinstance import _get_email_subject_content

        return _get_email_subject_content(task_instance=self, exception=exception, task=task)

    def get_previous_ti(
        self,
        state: DagRunState | None = None,
        session: Session | None = None,
    ) -> TaskInstance | TaskInstancePydantic | None:
        """
        Return the task instance for the task that ran before this task instance.

        :param session: SQLAlchemy ORM Session
        :param state: if passed, only take into account instances of a specific state
        """
        from airflow.models.taskinstance import _get_previous_ti

        return _get_previous_ti(task_instance=self, state=state, session=session)

    def check_and_change_state_before_execution(
        self,
        verbose: bool = True,
        ignore_all_deps: bool = False,
        ignore_depends_on_past: bool = False,
        wait_for_past_depends_before_skipping: bool = False,
        ignore_task_deps: bool = False,
        ignore_ti_state: bool = False,
        mark_success: bool = False,
        test_mode: bool = False,
        job_id: str | None = None,
        pool: str | None = None,
        external_executor_id: str | None = None,
        session: Session | None = None,
    ) -> bool:
        return TaskInstance._check_and_change_state_before_execution(
            task_instance=self,
            verbose=verbose,
            ignore_all_deps=ignore_all_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
            ignore_task_deps=ignore_task_deps,
            ignore_ti_state=ignore_ti_state,
            mark_success=mark_success,
            test_mode=test_mode,
            hostname=get_hostname(),
            job_id=job_id,
            pool=pool,
            external_executor_id=external_executor_id,
            session=session,
        )

    def schedule_downstream_tasks(self, session: Session | None = None, max_tis_per_query: int | None = None):
        """
        Schedule downstream tasks of this task instance.

        :meta private:
        """
        return TaskInstance._schedule_downstream_tasks(
            ti=self, session=session, max_tis_per_query=max_tis_per_query
        )

    def command_as_list(
        self,
        mark_success: bool = False,
        ignore_all_deps: bool = False,
        ignore_task_deps: bool = False,
        ignore_depends_on_past: bool = False,
        wait_for_past_depends_before_skipping: bool = False,
        ignore_ti_state: bool = False,
        local: bool = False,
        pickle_id: int | None = None,
        raw: bool = False,
        job_id: str | None = None,
        pool: str | None = None,
        cfg_path: str | None = None,
    ) -> list[str]:
        """
        Return a command that can be executed anywhere Airflow is installed.

        This command is part of the message sent to executors by the orchestrator.
        """
        return TaskInstance._command_as_list(
            ti=self,
            mark_success=mark_success,
            ignore_all_deps=ignore_all_deps,
            ignore_task_deps=ignore_task_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
            ignore_ti_state=ignore_ti_state,
            local=local,
            pickle_id=pickle_id,
            raw=raw,
            job_id=job_id,
            pool=pool,
            cfg_path=cfg_path,
        )


if is_pydantic_2_installed():
    TaskInstancePydantic.model_rebuild()
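
# A minimal usage sketch (assumes Pydantic 2 is installed; `orm_ti` is a
# hypothetical TaskInstance ORM instance):
#
#     ti = TaskInstancePydantic.model_validate(orm_ti)   # via from_attributes
#     payload = ti.model_dump_json()                     # wire format
#     restored = TaskInstancePydantic.model_validate_json(payload)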