Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/serialization/pydantic/taskinstance.py: 57%
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any, Iterable, Optional

from typing_extensions import Annotated

from airflow.models import Operator
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.pydantic.dag import DagModelPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.pydantic import (
    BaseModel as BaseModelPydantic,
    ConfigDict,
    PlainSerializer,
    PlainValidator,
    is_pydantic_2_installed,
)
from airflow.utils.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
    import pendulum
    from sqlalchemy.orm import Session

    from airflow.models.dagrun import DagRun
    from airflow.utils.context import Context
    from airflow.utils.pydantic import ValidationInfo
    from airflow.utils.state import DagRunState


def serialize_operator(x: Operator | None) -> dict | None:
    if x:
        from airflow.serialization.serialized_objects import SerializedBaseOperator

        return SerializedBaseOperator.serialize_operator(x)
    return None


def validated_operator(x: dict[str, Any] | Operator | None, _info: ValidationInfo) -> Any:
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.mappedoperator import MappedOperator
    from airflow.serialization.serialized_objects import SerializedBaseOperator

    if x is None or isinstance(x, (BaseOperator, MappedOperator)):
        return x
    return SerializedBaseOperator.deserialize_operator(x)


PydanticOperator = Annotated[
    Operator,
    PlainValidator(validated_operator),
    PlainSerializer(serialize_operator, return_type=dict),
]
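
# Illustrative sketch (not executed here): with the Annotated hooks above, a
# pydantic-2 model field typed as PydanticOperator round-trips an operator
# through Airflow's serialized-DAG representation. Assuming a hypothetical
# model `M` with a field `task: PydanticOperator`:
#
#     data = M(task=my_operator).model_dump()   # serialize_operator -> dict
#     restored = M.model_validate(data)         # validated_operator -> Operator
#
# BaseOperator/MappedOperator instances pass through the validator unchanged;
# plain dicts are deserialized via SerializedBaseOperator.
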

class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
    """Serializable representation of the TaskInstance ORM SQLAlchemy model used by the internal API."""

    task_id: str
    dag_id: str
    run_id: str
    map_index: int
    start_date: Optional[datetime]
    end_date: Optional[datetime]
    execution_date: Optional[datetime]
    duration: Optional[float]
    state: Optional[str]
    try_number: int
    max_tries: int
    hostname: str
    unixname: str
    job_id: Optional[int]
    pool: str
    pool_slots: int
    queue: str
    priority_weight: Optional[int]
    operator: str
    custom_operator_name: Optional[str]
    queued_dttm: Optional[datetime]
    queued_by_job_id: Optional[int]
    pid: Optional[int]
    executor: Optional[str]
    executor_config: Any
    updated_at: Optional[datetime]
    rendered_map_index: Optional[str]
    external_executor_id: Optional[str]
    trigger_id: Optional[int]
    trigger_timeout: Optional[datetime]
    next_method: Optional[str]
    next_kwargs: Optional[dict]
    run_as_user: Optional[str]
    task: Optional[PydanticOperator]
    test_mode: bool
    dag_run: Optional[DagRunPydantic]
    dag_model: Optional[DagModelPydantic]
    raw: Optional[bool]
    is_trigger_log_context: Optional[bool]

    model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)
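
    # A minimal sketch (assumption, not part of this module): because
    # from_attributes=True, an instance can be built straight from the ORM
    # TaskInstance with the standard pydantic-2 API, e.g.:
    #
    #     ti_pydantic = TaskInstancePydantic.model_validate(orm_task_instance)
    #
    # arbitrary_types_allowed=True is what permits non-pydantic types such as
    # Operator to appear in the field annotations above.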

    @property
    def _logger_name(self):
        return "airflow.task"

    def clear_xcom_data(self, session: Session | None = None):
        TaskInstance._clear_xcom_data(ti=self, session=session)

    def set_state(self, state, session: Session | None = None) -> bool:
        return TaskInstance._set_state(ti=self, state=state, session=session)

    def _run_execute_callback(self, context, task):
        TaskInstance._run_execute_callback(self=self, context=context, task=task)  # type: ignore[arg-type]

    def render_templates(self, context: Context | None = None, jinja_env=None):
        return TaskInstance.render_templates(self=self, context=context, jinja_env=jinja_env)  # type: ignore[arg-type]

    def init_run_context(self, raw: bool = False) -> None:
        """Set the log context."""
        self.raw = raw
        self._set_context(self)

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = XCOM_RETURN_KEY,
        include_prior_dates: bool = False,
        *,
        map_indexes: int | Iterable[int] | None = None,
        default: Any = None,
    ) -> Any:
        """
        Pull an XCom value for this task instance.

        TODO: make it work for AIP-44

        :param task_ids: task id or list of task ids; if None, the task_id of the current task is used
        :param dag_id: dag id; if None, the dag_id of the current task is used
        :param key: the key to identify the XCom value
        :param include_prior_dates: whether to include prior execution dates
        :param map_indexes: map index or list of map indexes; if None, the map_index of the current task
            is used
        :param default: the default value to return if the XCom value does not exist
        :return: XCom value
        """
        return None

    def xcom_push(
        self,
        key: str,
        value: Any,
        execution_date: datetime | None = None,
        session: Session | None = None,
    ) -> None:
        """
        Push an XCom value for this task instance.

        TODO: make it work for AIP-44

        :param key: the key to identify the XCom value
        :param value: the value of the XCom
        :param execution_date: the execution date to push the XCom for
        """
        pass
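
    # NOTE: xcom_pull and xcom_push above are placeholders: per their TODOs they
    # do not yet work over AIP-44 (the internal API), so pulling always returns
    # None (the `default` argument is currently ignored) and pushing is a no-op.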

    def get_dagrun(self, session: Session | None = None) -> DagRunPydantic:
        """
        Return the DagRun for this TaskInstance.

        :param session: SQLAlchemy ORM Session
        :return: Pydantic serialized version of DagRun
        """
        return TaskInstance._get_dagrun(dag_id=self.dag_id, run_id=self.run_id, session=session)

    def _execute_task(self, context, task_orig):
        """
        Execute Task (optionally with a Timeout) and push XCom results.

        :param context: Jinja2 context
        :param task_orig: origin task
        """
        from airflow.models.taskinstance import _execute_task

        return _execute_task(task_instance=self, context=context, task_orig=task_orig)
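
    # Like _execute_task above, most methods below delegate to module-level
    # helpers in airflow.models.taskinstance, imported lazily inside each method
    # (presumably to avoid circular imports); the same helper then serves both
    # the ORM TaskInstance and this serialized stand-in.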

    def refresh_from_db(self, session: Session | None = None, lock_for_update: bool = False) -> None:
        """
        Refresh the task instance from the database based on the primary key.

        :param session: SQLAlchemy ORM Session
        :param lock_for_update: if True, indicates that the database should
            lock the TaskInstance (issuing a FOR UPDATE clause) until the
            session is committed.
        """
        from airflow.models.taskinstance import _refresh_from_db

        _refresh_from_db(task_instance=self, session=session, lock_for_update=lock_for_update)

    def set_duration(self) -> None:
        """Set task instance duration."""
        from airflow.models.taskinstance import _set_duration

        _set_duration(task_instance=self)

    @property
    def stats_tags(self) -> dict[str, str]:
        """Return task instance tags."""
        from airflow.models.taskinstance import _stats_tags

        return _stats_tags(task_instance=self)

    def clear_next_method_args(self) -> None:
        """Unset next_method and next_kwargs so that any retries do not reuse them."""
        from airflow.models.taskinstance import _clear_next_method_args

        _clear_next_method_args(task_instance=self)

    def get_template_context(
        self,
        session: Session | None = None,
        ignore_param_exceptions: bool = True,
    ) -> Context:
        """
        Return the task instance Context.

        :param session: SQLAlchemy ORM Session
        :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
        """
        from airflow.models.taskinstance import _get_template_context

        return _get_template_context(
            task_instance=self,
            session=session,
            ignore_param_exceptions=ignore_param_exceptions,
        )

    def is_eligible_to_retry(self):
        """Whether the task instance is eligible for retry."""
        from airflow.models.taskinstance import _is_eligible_to_retry

        return _is_eligible_to_retry(task_instance=self)

    def handle_failure(
        self,
        error: None | str | Exception | KeyboardInterrupt,
        test_mode: bool | None = None,
        context: Context | None = None,
        force_fail: bool = False,
        session: Session | None = None,
    ) -> None:
        """
        Handle failure for a task instance.

        :param error: if specified, log the specific exception if thrown
        :param session: SQLAlchemy ORM Session
        :param test_mode: doesn't record success or failure in the DB if True
        :param context: Jinja2 context
        :param force_fail: if True, task does not retry
        """
        from airflow.models.taskinstance import _handle_failure

        if TYPE_CHECKING:
            assert self.task
            assert self.task.dag
        # The DAG attached to a serialized task may not expose fail_stop;
        # fall back to False in that case.
        try:
            fail_stop = self.task.dag.fail_stop
        except Exception:
            fail_stop = False
        _handle_failure(
            task_instance=self,
            error=error,
            session=session,
            test_mode=test_mode,
            context=context,
            force_fail=force_fail,
            fail_stop=fail_stop,
        )

    def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
        """
        Copy common attributes from the given task.

        :param task: The task object to copy from
        :param pool_override: Use the pool_override instead of the task's pool
        """
        from airflow.models.taskinstance import _refresh_from_task

        _refresh_from_task(task_instance=self, task=task, pool_override=pool_override)

    def get_previous_dagrun(
        self,
        state: DagRunState | None = None,
        session: Session | None = None,
    ) -> DagRun | None:
        """
        Return the DagRun that ran before this task instance's DagRun.

        :param state: If passed, it only takes into account instances of a specific state.
        :param session: SQLAlchemy ORM Session.
        """
        from airflow.models.taskinstance import _get_previous_dagrun

        return _get_previous_dagrun(task_instance=self, state=state, session=session)

    def get_previous_execution_date(
        self,
        state: DagRunState | None = None,
        session: Session | None = None,
    ) -> pendulum.DateTime | None:
        """
        Return the execution date from the property previous_ti_success.

        :param state: If passed, it only takes into account instances of a specific state.
        :param session: SQLAlchemy ORM Session
        """
        from airflow.models.taskinstance import _get_previous_execution_date

        return _get_previous_execution_date(task_instance=self, state=state, session=session)

    def email_alert(self, exception, task: BaseOperator) -> None:
        """
        Send alert email with exception information.

        :param exception: the exception
        :param task: task related to the exception
        """
        from airflow.models.taskinstance import _email_alert

        _email_alert(task_instance=self, exception=exception, task=task)

    def get_email_subject_content(
        self, exception: BaseException, task: BaseOperator | None = None
    ) -> tuple[str, str, str]:
        """
        Get the email subject content for exceptions.

        :param exception: the exception sent in the email
        :param task: the task related to the exception
        """
        from airflow.models.taskinstance import _get_email_subject_content

        return _get_email_subject_content(task_instance=self, exception=exception, task=task)

    def get_previous_ti(
        self,
        state: DagRunState | None = None,
        session: Session | None = None,
    ) -> TaskInstance | TaskInstancePydantic | None:
        """
        Return the task instance for the task that ran before this task instance.

        :param session: SQLAlchemy ORM Session
        :param state: If passed, it only takes into account instances of a specific state.
        """
        from airflow.models.taskinstance import _get_previous_ti

        return _get_previous_ti(task_instance=self, state=state, session=session)

    def check_and_change_state_before_execution(
        self,
        verbose: bool = True,
        ignore_all_deps: bool = False,
        ignore_depends_on_past: bool = False,
        wait_for_past_depends_before_skipping: bool = False,
        ignore_task_deps: bool = False,
        ignore_ti_state: bool = False,
        mark_success: bool = False,
        test_mode: bool = False,
        job_id: str | None = None,
        pool: str | None = None,
        external_executor_id: str | None = None,
        session: Session | None = None,
    ) -> bool:
        return TaskInstance._check_and_change_state_before_execution(
            task_instance=self,
            verbose=verbose,
            ignore_all_deps=ignore_all_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
            ignore_task_deps=ignore_task_deps,
            ignore_ti_state=ignore_ti_state,
            mark_success=mark_success,
            test_mode=test_mode,
            hostname=get_hostname(),
            job_id=job_id,
            pool=pool,
            external_executor_id=external_executor_id,
            session=session,
        )

    def schedule_downstream_tasks(self, session: Session | None = None, max_tis_per_query: int | None = None):
        """
        Schedule downstream tasks of this task instance.

        :meta private:
        """
        return TaskInstance._schedule_downstream_tasks(
            ti=self, session=session, max_tis_per_query=max_tis_per_query
        )

    def command_as_list(
        self,
        mark_success: bool = False,
        ignore_all_deps: bool = False,
        ignore_task_deps: bool = False,
        ignore_depends_on_past: bool = False,
        wait_for_past_depends_before_skipping: bool = False,
        ignore_ti_state: bool = False,
        local: bool = False,
        pickle_id: int | None = None,
        raw: bool = False,
        job_id: str | None = None,
        pool: str | None = None,
        cfg_path: str | None = None,
    ) -> list[str]:
        """
        Return a command that can be executed anywhere Airflow is installed.

        This command is part of the message sent to executors by the orchestrator.
        """
        return TaskInstance._command_as_list(
            ti=self,
            mark_success=mark_success,
            ignore_all_deps=ignore_all_deps,
            ignore_task_deps=ignore_task_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
            ignore_ti_state=ignore_ti_state,
            local=local,
            pickle_id=pickle_id,
            raw=raw,
            job_id=job_id,
            pool=pool,
            cfg_path=cfg_path,
        )


if is_pydantic_2_installed():
    TaskInstancePydantic.model_rebuild()
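
# Because this module uses `from __future__ import annotations`, all field
# annotations are stored as strings; model_rebuild() resolves them once the full
# module namespace is available. model_rebuild() is a pydantic-2-only method,
# hence the is_pydantic_2_installed() guard.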