Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/serialization/pydantic/taskinstance.py: 57%

152 statements  

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any, Iterable, Optional

from typing_extensions import Annotated

from airflow.models import Operator
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.pydantic.dag import DagModelPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.pydantic import (
    BaseModel as BaseModelPydantic,
    ConfigDict,
    PlainSerializer,
    PlainValidator,
    is_pydantic_2_installed,
)
from airflow.utils.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
    import pendulum
    from sqlalchemy.orm import Session

    from airflow.models.dagrun import DagRun
    from airflow.utils.context import Context
    from airflow.utils.pydantic import ValidationInfo
    from airflow.utils.state import DagRunState


def serialize_operator(x: Operator | None) -> dict | None:
    if x:
        from airflow.serialization.serialized_objects import SerializedBaseOperator

        return SerializedBaseOperator.serialize_operator(x)
    return None


def validated_operator(x: dict[str, Any] | Operator | None, _info: ValidationInfo) -> Any:
    # Imported locally to avoid circular imports at module load.
    from airflow.models.mappedoperator import MappedOperator
    from airflow.serialization.serialized_objects import SerializedBaseOperator

    if x is None or isinstance(x, (BaseOperator, MappedOperator)):
        return x
    return SerializedBaseOperator.deserialize_operator(x)


PydanticOperator = Annotated[
    Operator,
    PlainValidator(validated_operator),
    PlainSerializer(serialize_operator, return_type=dict),
]
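# A minimal round-trip sketch (assuming pydantic 2 is installed; ``my_op`` is a
# hypothetical BaseOperator instance):
#
#   from pydantic import TypeAdapter
#
#   adapter = TypeAdapter(PydanticOperator)
#   as_dict = adapter.dump_python(my_op)         # goes through serialize_operator
#   restored = adapter.validate_python(as_dict)  # goes through validated_operator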

class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
    """Serializable representation of the TaskInstance ORM SQLAlchemy model, used by the internal API."""

    task_id: str
    dag_id: str
    run_id: str
    map_index: int
    start_date: Optional[datetime]
    end_date: Optional[datetime]
    execution_date: Optional[datetime]
    duration: Optional[float]
    state: Optional[str]
    try_number: int
    max_tries: int
    hostname: str
    unixname: str
    job_id: Optional[int]
    pool: str
    pool_slots: int
    queue: str
    priority_weight: Optional[int]
    operator: str
    custom_operator_name: Optional[str]
    queued_dttm: Optional[datetime]
    queued_by_job_id: Optional[int]
    pid: Optional[int]
    executor: Optional[str]
    executor_config: Any
    updated_at: Optional[datetime]
    rendered_map_index: Optional[str]
    external_executor_id: Optional[str]
    trigger_id: Optional[int]
    trigger_timeout: Optional[datetime]
    next_method: Optional[str]
    next_kwargs: Optional[dict]
    run_as_user: Optional[str]
    task: Optional[PydanticOperator]
    test_mode: bool
    dag_run: Optional[DagRunPydantic]
    dag_model: Optional[DagModelPydantic]
    raw: Optional[bool]
    is_trigger_log_context: Optional[bool]

    model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)
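    # ``from_attributes=True`` lets the model be built directly from the ORM
    # object; a minimal sketch (``orm_ti`` is a hypothetical TaskInstance row
    # loaded through a SQLAlchemy session):
    #
    #   ti_pydantic = TaskInstancePydantic.model_validate(orm_ti)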

    @property
    def _logger_name(self):
        return "airflow.task"

    def clear_xcom_data(self, session: Session | None = None):
        TaskInstance._clear_xcom_data(ti=self, session=session)

    def set_state(self, state, session: Session | None = None) -> bool:
        return TaskInstance._set_state(ti=self, state=state, session=session)

    def _run_execute_callback(self, context, task):
        TaskInstance._run_execute_callback(self=self, context=context, task=task)  # type: ignore[arg-type]

    def render_templates(self, context: Context | None = None, jinja_env=None):
        return TaskInstance.render_templates(self=self, context=context, jinja_env=jinja_env)  # type: ignore[arg-type]

    def init_run_context(self, raw: bool = False) -> None:
        """Set the log context."""
        self.raw = raw
        self._set_context(self)

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = XCOM_RETURN_KEY,
        include_prior_dates: bool = False,
        *,
        map_indexes: int | Iterable[int] | None = None,
        default: Any = None,
    ) -> Any:
        """
        Pull an XCom value for this task instance.

        TODO: make it work for AIP-44

        :param task_ids: task id or list of task ids, if None, the task_id of the current task is used
        :param dag_id: dag id, if None, the dag_id of the current task is used
        :param key: the key to identify the XCom value
        :param include_prior_dates: whether to include prior execution dates
        :param map_indexes: map index or list of map indexes, if None, the map_index of the current task
            is used
        :param default: the default value to return if the XCom value does not exist
        :return: the XCom value
        """
        return None

    def xcom_push(
        self,
        key: str,
        value: Any,
        execution_date: datetime | None = None,
        session: Session | None = None,
    ) -> None:
        """
        Push an XCom value for this task instance.

        TODO: make it work for AIP-44

        :param key: the key to identify the XCom value
        :param value: the value of the XCom
        :param execution_date: the execution date to push the XCom for
        :param session: SQLAlchemy ORM Session
        """
        pass

    def get_dagrun(self, session: Session | None = None) -> DagRunPydantic:
        """
        Return the DagRun for this TaskInstance.

        :param session: SQLAlchemy ORM Session
        :return: Pydantic serialized version of DagRun
        """
        return TaskInstance._get_dagrun(dag_id=self.dag_id, run_id=self.run_id, session=session)

    def _execute_task(self, context, task_orig):
        """
        Execute the task (optionally with a timeout) and push XCom results.

        :param context: Jinja2 context
        :param task_orig: origin task
        """
        from airflow.models.taskinstance import _execute_task

        return _execute_task(task_instance=self, context=context, task_orig=task_orig)

    def refresh_from_db(self, session: Session | None = None, lock_for_update: bool = False) -> None:
        """
        Refresh the task instance from the database based on the primary key.

        :param session: SQLAlchemy ORM Session
        :param lock_for_update: if True, indicates that the database should
            lock the TaskInstance (issuing a FOR UPDATE clause) until the
            session is committed.
        """
        from airflow.models.taskinstance import _refresh_from_db

        _refresh_from_db(task_instance=self, session=session, lock_for_update=lock_for_update)

    def set_duration(self) -> None:
        """Set task instance duration."""
        from airflow.models.taskinstance import _set_duration

        _set_duration(task_instance=self)

    @property
    def stats_tags(self) -> dict[str, str]:
        """Return task instance tags."""
        from airflow.models.taskinstance import _stats_tags

        return _stats_tags(task_instance=self)

    def clear_next_method_args(self) -> None:
        """Unset next_method and next_kwargs so that any retries don't reuse them."""
        from airflow.models.taskinstance import _clear_next_method_args

        _clear_next_method_args(task_instance=self)

    def get_template_context(
        self,
        session: Session | None = None,
        ignore_param_exceptions: bool = True,
    ) -> Context:
        """
        Return TI Context.

        :param session: SQLAlchemy ORM Session
        :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
        """
        from airflow.models.taskinstance import _get_template_context

        return _get_template_context(
            task_instance=self,
            session=session,
            ignore_param_exceptions=ignore_param_exceptions,
        )

    def is_eligible_to_retry(self):
        """Whether the task instance is eligible for retry."""
        from airflow.models.taskinstance import _is_eligible_to_retry

        return _is_eligible_to_retry(task_instance=self)

    def handle_failure(
        self,
        error: None | str | Exception | KeyboardInterrupt,
        test_mode: bool | None = None,
        context: Context | None = None,
        force_fail: bool = False,
        session: Session | None = None,
    ) -> None:
        """
        Handle failure for a task instance.

        :param error: if specified, log the specific exception if thrown
        :param session: SQLAlchemy ORM Session
        :param test_mode: doesn't record success or failure in the DB if True
        :param context: Jinja2 context
        :param force_fail: if True, task does not retry
        """
        from airflow.models.taskinstance import _handle_failure

        if TYPE_CHECKING:
            assert self.task
            assert self.task.dag
        try:
            # Not every DAG representation carries ``fail_stop``; default to False.
            fail_stop = self.task.dag.fail_stop
        except Exception:
            fail_stop = False
        _handle_failure(
            task_instance=self,
            error=error,
            session=session,
            test_mode=test_mode,
            context=context,
            force_fail=force_fail,
            fail_stop=fail_stop,
        )

    def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
        """
        Copy common attributes from the given task.

        :param task: The task object to copy from
        :param pool_override: Use the pool_override instead of the task's pool
        """
        from airflow.models.taskinstance import _refresh_from_task

        _refresh_from_task(task_instance=self, task=task, pool_override=pool_override)

    def get_previous_dagrun(
        self,
        state: DagRunState | None = None,
        session: Session | None = None,
    ) -> DagRun | None:
        """
        Return the DagRun that ran before this task instance's DagRun.

        :param state: If passed, it only takes into account instances of a specific state.
        :param session: SQLAlchemy ORM Session.
        """
        from airflow.models.taskinstance import _get_previous_dagrun

        return _get_previous_dagrun(task_instance=self, state=state, session=session)

    def get_previous_execution_date(
        self,
        state: DagRunState | None = None,
        session: Session | None = None,
    ) -> pendulum.DateTime | None:
        """
        Return the execution date from the property previous_ti_success.

        :param state: If passed, it only takes into account instances of a specific state.
        :param session: SQLAlchemy ORM Session
        """
        from airflow.models.taskinstance import _get_previous_execution_date

        return _get_previous_execution_date(task_instance=self, state=state, session=session)

    def email_alert(self, exception, task: BaseOperator) -> None:
        """
        Send alert email with exception information.

        :param exception: the exception
        :param task: task related to the exception
        """
        from airflow.models.taskinstance import _email_alert

        _email_alert(task_instance=self, exception=exception, task=task)

    def get_email_subject_content(
        self, exception: BaseException, task: BaseOperator | None = None
    ) -> tuple[str, str, str]:
        """
        Get the email subject content for exceptions.

        :param exception: the exception sent in the email
        :param task: the task related to the exception
        """
        from airflow.models.taskinstance import _get_email_subject_content

        return _get_email_subject_content(task_instance=self, exception=exception, task=task)

    def get_previous_ti(
        self,
        state: DagRunState | None = None,
        session: Session | None = None,
    ) -> TaskInstance | TaskInstancePydantic | None:
        """
        Return the task instance for the task that ran before this task instance.

        :param session: SQLAlchemy ORM Session
        :param state: If passed, it only takes into account instances of a specific state.
        """
        from airflow.models.taskinstance import _get_previous_ti

        return _get_previous_ti(task_instance=self, state=state, session=session)

    def check_and_change_state_before_execution(
        self,
        verbose: bool = True,
        ignore_all_deps: bool = False,
        ignore_depends_on_past: bool = False,
        wait_for_past_depends_before_skipping: bool = False,
        ignore_task_deps: bool = False,
        ignore_ti_state: bool = False,
        mark_success: bool = False,
        test_mode: bool = False,
        job_id: str | None = None,
        pool: str | None = None,
        external_executor_id: str | None = None,
        session: Session | None = None,
    ) -> bool:
        return TaskInstance._check_and_change_state_before_execution(
            task_instance=self,
            verbose=verbose,
            ignore_all_deps=ignore_all_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
            ignore_task_deps=ignore_task_deps,
            ignore_ti_state=ignore_ti_state,
            mark_success=mark_success,
            test_mode=test_mode,
            hostname=get_hostname(),
            job_id=job_id,
            pool=pool,
            external_executor_id=external_executor_id,
            session=session,
        )

    def schedule_downstream_tasks(self, session: Session | None = None, max_tis_per_query: int | None = None):
        """
        Schedule downstream tasks of this task instance.

        :meta private:
        """
        return TaskInstance._schedule_downstream_tasks(
            ti=self, session=session, max_tis_per_query=max_tis_per_query
        )

    def command_as_list(
        self,
        mark_success: bool = False,
        ignore_all_deps: bool = False,
        ignore_task_deps: bool = False,
        ignore_depends_on_past: bool = False,
        wait_for_past_depends_before_skipping: bool = False,
        ignore_ti_state: bool = False,
        local: bool = False,
        pickle_id: int | None = None,
        raw: bool = False,
        job_id: str | None = None,
        pool: str | None = None,
        cfg_path: str | None = None,
    ) -> list[str]:
        """
        Return a command that can be executed anywhere Airflow is installed.

        This command is part of the message sent to executors by the orchestrator.
        """
        return TaskInstance._command_as_list(
            ti=self,
            mark_success=mark_success,
            ignore_all_deps=ignore_all_deps,
            ignore_task_deps=ignore_task_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
            ignore_ti_state=ignore_ti_state,
            local=local,
            pickle_id=pickle_id,
            raw=raw,
            job_id=job_id,
            pool=pool,
            cfg_path=cfg_path,
        )

if is_pydantic_2_installed():
    TaskInstancePydantic.model_rebuild()
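# Note: ``model_rebuild`` forces pydantic 2 to re-evaluate the deferred string
# annotations (from ``from __future__ import annotations``) once every referenced
# model is defined; the guard skips this pydantic-2-only API when pydantic 1 is
# installed.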