Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/abstractoperator.py: 31%

264 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import datetime 

21import inspect 

22from functools import cached_property 

23from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence 

24 

25from airflow.compat.functools import cache 

26from airflow.configuration import conf 

27from airflow.exceptions import AirflowException 

28from airflow.models.expandinput import NotFullyPopulated 

29from airflow.models.taskmixin import DAGNode 

30from airflow.template.templater import Templater 

31from airflow.utils.context import Context 

32from airflow.utils.log.secrets_masker import redact 

33from airflow.utils.session import NEW_SESSION, provide_session 

34from airflow.utils.sqlalchemy import skip_locked, with_row_locks 

35from airflow.utils.state import State, TaskInstanceState 

36from airflow.utils.task_group import MappedTaskGroup 

37from airflow.utils.trigger_rule import TriggerRule 

38from airflow.utils.weight_rule import WeightRule 

39 

# Type alias for state-change callbacks: any callable that accepts a
# ``Context`` as its single argument and returns ``None``.
TaskStateChangeCallback = Callable[[Context], None]

41 

42if TYPE_CHECKING: 

43 import jinja2 # Slow import. 

44 from sqlalchemy.orm import Session 

45 

46 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink 

47 from airflow.models.dag import DAG 

48 from airflow.models.mappedoperator import MappedOperator 

49 from airflow.models.operator import Operator 

50 from airflow.models.taskinstance import TaskInstance 

51 

# Operator-level defaults. Most of these are read from the Airflow
# configuration exactly once, when this module is imported.

# [operators] section (mandatory values: missing keys raise at import time).
DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_PRIORITY_WEIGHT: int = 1
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")

# Depends-on-past behaviour.
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = conf.getboolean(
    "scheduler",
    "ignore_first_depends_on_past_by_default",
)
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False

# Retry behaviour.
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300),
)
MAX_RETRY_DELAY: int = conf.getint(
    "core",
    "max_task_retry_delay",
    fallback=24 * 60 * 60,  # 86400 -- one day, presumably in seconds
)

# Scheduling rules and limits.
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM),
)
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core",
    "default_task_execution_timeout",
)

73 

74 

class NotMapped(Exception):
    """Signal that a task is not mapped and has no mapped parent task group either."""

77 

78 

class AbstractOperator(Templater, DAGNode):
    """Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use this as the common base class
    for typing due to BaseOperator carrying too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    operator_class: type[BaseOperator] | dict[str, Any]

    weight_rule: str
    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink]

    owner: str
    task_id: str

    outlets: list
    inlets: list

    # Attribute names the UI should not display (reasons given per entry).
    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    def get_dag(self) -> DAG | None:
        """Return the DAG this operator is attached to, if any. Subclasses must override."""
        raise NotImplementedError()

    @property
    def task_type(self) -> str:
        """Subclass hook; not implemented on the abstract base."""
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        """Subclass hook; not implemented on the abstract base."""
        raise NotImplementedError()

    @property
    def inherits_from_empty_operator(self) -> bool:
        """Subclass hook; not implemented on the abstract base."""
        raise NotImplementedError()

    @property
    def dag_id(self) -> str:
        """Return the attached DAG's id, or ``adhoc_<owner>`` when not attached to a DAG."""
        dag = self.get_dag()
        if dag:
            return dag.dag_id
        return f"adhoc_{self.owner}"

    @property
    def node_id(self) -> str:
        """Return the node id, which for an operator is its ``task_id``."""
        return self.task_id

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get direct relative IDs to the current task, upstream or downstream.

        Note that the underlying id set is returned, not a copy; callers must
        not mutate it.
        """
        if upstream:
            return self.upstream_task_ids
        return self.downstream_task_ids

    def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]:
        """
        Get a flat set of relative IDs, upstream or downstream.

        Will recurse each relative found in the direction specified.

        :param upstream: Whether to look for upstream or downstream relatives.
        :return: IDs of all transitive relatives; empty if not attached to a DAG.
        """
        dag = self.get_dag()
        if not dag:
            return set()

        relatives: set[str] = set()

        # Breadth-first walk; ``relatives`` doubles as the visited set so that
        # tasks reachable through several paths are only expanded once.
        task_ids_to_trace = self.get_direct_relative_ids(upstream)
        while task_ids_to_trace:
            task_ids_to_trace_next: set[str] = set()
            for task_id in task_ids_to_trace:
                if task_id in relatives:
                    continue
                task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
                relatives.add(task_id)
            task_ids_to_trace = task_ids_to_trace_next

        return relatives

    def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
        """Get a flat list of relatives, either upstream or downstream.

        :param upstream: Whether to look for upstream or downstream relatives.
        :return: The relative operators; an empty list if not attached to a DAG.
        """
        dag = self.get_dag()
        if not dag:
            # Return the same container type as the normal path below.
            return []
        return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)]

    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """Return mapped nodes that are direct downstreams of the current task.

        For now, this walks the entire DAG to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a DAG node's downstream nodes instead.

        Note that this does not guarantee the returned tasks actually use the
        current task for task mapping, but only checks those tasks are mapped
        operators (or mapped task groups), and are downstreams of the current task.

        To get a list of tasks that use the current task for task mapping, use
        :meth:`iter_mapped_dependants` instead.

        :raise RuntimeError: If the task is not attached to a DAG.
        """
        from airflow.models.mappedoperator import MappedOperator
        from airflow.utils.task_group import TaskGroup

        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
            """Recursively walk children in a task group.

            This yields all direct children (including both tasks and task
            groups), and all children of any task groups.
            """
            for key, child in group.children.items():
                yield key, child
                if isinstance(child, TaskGroup):
                    yield from _walk_group(child)

        dag = self.get_dag()
        if not dag:
            raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG")
        for key, child in _walk_group(dag.task_group):
            if key == self.node_id:
                continue
            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
                continue
            if self.node_id in child.upstream_task_ids:
                yield child

    def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """Return mapped nodes that depend on the current task for their expansion.

        For now, this walks the entire DAG to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a DAG node's downstream nodes instead.
        """
        return (
            downstream
            for downstream in self._iter_all_mapped_downstreams()
            if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
        )

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """Return mapped task groups this task belongs to.

        Groups are returned from the innermost to the outermost.

        :meta private:
        """
        parent = self.task_group
        while parent is not None:
            if isinstance(parent, MappedTaskGroup):
                yield parent
            parent = parent.task_group

    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
        """Get the mapped task group "closest" to this task in the DAG.

        :return: The innermost enclosing mapped task group, or *None* if there is none.

        :meta private:
        """
        return next(self.iter_mapped_task_groups(), None)

    def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
        """Get the "normal" operator from current abstract operator.

        MappedOperator uses this to unmap itself based on the map index. A non-
        mapped operator (i.e. BaseOperator subclass) simply returns itself.

        :meta private:
        """
        raise NotImplementedError()

    @property
    def priority_weight_total(self) -> int:
        """
        Total priority weight for the task. It might include all upstream or downstream tasks.

        Depending on the weight rule:

        - WeightRule.ABSOLUTE - only own weight
        - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks
        - WeightRule.UPSTREAM - adds priority weight of all upstream tasks

        Any other weight rule value is treated like WeightRule.DOWNSTREAM.
        """
        if self.weight_rule == WeightRule.ABSOLUTE:
            return self.priority_weight
        elif self.weight_rule == WeightRule.DOWNSTREAM:
            upstream = False
        elif self.weight_rule == WeightRule.UPSTREAM:
            upstream = True
        else:
            upstream = False
        dag = self.get_dag()
        if dag is None:
            return self.priority_weight
        return self.priority_weight + sum(
            dag.task_dict[task_id].priority_weight
            for task_id in self.get_flat_relative_ids(upstream=upstream)
        )

    @cached_property
    def operator_extra_link_dict(self) -> dict[str, Any]:
        """Return a mapping of link name to link for this operator's extra links.

        Links registered by plugins for this operator class override links of
        the same name defined on the operator itself.

        :raise AirflowException: If the plugin manager failed to load operator links.
        """
        op_extra_links_from_plugin: dict[str, Any] = {}
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.operator_extra_links is None:
            raise AirflowException("Can't load operators")
        for ope in plugins_manager.operator_extra_links:
            if ope.operators and self.operator_class in ope.operators:
                op_extra_links_from_plugin.update({ope.name: ope})

        operator_extra_links_all = {link.name: link for link in self.operator_extra_links}
        # Extra links defined in Plugins overrides operator links defined in operator
        operator_extra_links_all.update(op_extra_links_from_plugin)

        return operator_extra_links_all

    @cached_property
    def global_operator_extra_link_dict(self) -> dict[str, Any]:
        """Return a mapping of link name to link for all global extra links.

        :raise AirflowException: If the plugin manager failed to load global links.
        """
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.global_operator_extra_links is None:
            raise AirflowException("Can't load operators")
        return {link.name: link for link in plugins_manager.global_operator_extra_links}

    @cached_property
    def extra_links(self) -> list[str]:
        """Return the names of all operator-specific and global extra links (unordered)."""
        return list(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))

    def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
        """For an operator, gets the URLs that the ``extra_links`` entry points to.

        :meta private:

        :raise ValueError: The error message of a ValueError will be passed on through to
            the frontend to show up as a tooltip on the disabled link.
        :param ti: The TaskInstance for the URL being searched for.
        :param link_name: The name of the link we're looking for the URL for. Should be
            one of the options specified in ``extra_links``.
        :return: The URL, or *None* if no link with that name exists.
        """
        # Operator-specific links are looked up before global ones.
        link: BaseOperatorLink | None = self.operator_extra_link_dict.get(link_name)
        if not link:
            link = self.global_operator_extra_link_dict.get(link_name)
            if not link:
                return None

        # Links whose ``get_link`` takes no ``ti_key`` parameter use the older
        # (operator, logical date) calling convention.
        parameters = inspect.signature(link.get_link).parameters
        old_signature = all(name != "ti_key" for name, p in parameters.items() if p.kind != p.VAR_KEYWORD)

        if old_signature:
            return link.get_link(self.unmap(None), ti.dag_run.logical_date)  # type: ignore[misc]
        return link.get_link(self.unmap(None), ti_key=ti.key)

    # NOTE(review): caching on a method keys on ``self`` and (presumably, like
    # functools.cache) keeps every instance alive for the cache's lifetime.
    @cache
    def get_parse_time_mapped_ti_count(self) -> int:
        """Return the number of mapped task instances that can be created on DAG run creation.

        This only considers literal mapped arguments; if any non-literal values
        are used for mapping, the count cannot be known at parse time and
        ``NotFullyPopulated`` is raised instead.

        :raise NotFullyPopulated: If non-literal mapped arguments are encountered.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped
        return group.get_parse_time_mapped_ti_count()

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """Return the number of mapped TaskInstances that can be created at run time.

        This considers both literal and non-literal mapped arguments, and the
        result is therefore available when all depended tasks have finished. The
        return value should be identical to ``get_parse_time_mapped_ti_count()``
        if all mapped arguments are literal.

        :raise NotFullyPopulated: If upstream tasks are not all complete yet.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped
        return group.get_mapped_ti_count(run_id, session=session)

    def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
        """Create the mapped task instances for mapped task.

        Depending on the map length resolved for *run_id*, an unfinished
        unmapped task instance (``map_index == -1``), if one exists, is marked
        UPSTREAM_FAILED (length unknown and the DAG is not partial), marked
        SKIPPED (length 0), converted in place into map index 0, or deleted when
        a map index 0 instance already exists. Missing map indexes are then
        created, and every task instance whose index is not below the new
        length is marked REMOVED. The session is flushed but not committed.

        :param run_id: The DAG run whose task instances are expanded.
        :param session: Database session to use.
        :raise NotMapped: If this task does not need expansion.
        :raise RuntimeError: If this is neither a BaseOperator nor a MappedOperator.
        :return: The newly created mapped task instances (if any) in ascending
            order by map index, and the maximum map index value (-1 if none).
        """
        from sqlalchemy import func, or_

        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator
        from airflow.models.taskinstance import TaskInstance
        from airflow.settings import task_instance_mutation_hook

        if not isinstance(self, (BaseOperator, MappedOperator)):
            raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}")

        try:
            total_length: int | None = self.get_mapped_ti_count(run_id, session=session)
        except NotFullyPopulated as e:
            # It's possible that the upstream tasks are not yet done, but we
            # don't have upstream of upstreams in partial DAGs (possible in the
            # mini-scheduler), so we ignore this exception.
            if not self.dag or not self.dag.partial:
                self.log.error(
                    "Cannot expand %r for run %s; missing upstream values: %s",
                    self,
                    run_id,
                    sorted(e.missing),
                )
            total_length = None

        state: TaskInstanceState | None = None
        unmapped_ti: TaskInstance | None = (
            session.query(TaskInstance)
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
            )
            .one_or_none()
        )

        all_expanded_tis: list[TaskInstance] = []

        if unmapped_ti:
            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length is None:
                # If the DAG is partial, it's likely that the upstream tasks
                # are not done yet, so the task can't fail yet.
                if not self.dag or not self.dag.partial:
                    unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
            elif total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark
                # the unmapped task instance as SKIPPED (if needed).
                self.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
            else:
                zero_index_ti_exists = (
                    session.query(TaskInstance)
                    .filter(
                        TaskInstance.dag_id == self.dag_id,
                        TaskInstance.task_id == self.task_id,
                        TaskInstance.run_id == run_id,
                        TaskInstance.map_index == 0,
                    )
                    .count()
                    > 0
                )
                if not zero_index_ti_exists:
                    # Otherwise convert this into the first mapped index, and create
                    # TaskInstance for other indexes.
                    unmapped_ti.map_index = 0
                    self.log.debug("Updated in place to become %s", unmapped_ti)
                    all_expanded_tis.append(unmapped_ti)
                    session.flush()
                else:
                    self.log.debug("Deleting the original task instance: %s", unmapped_ti)
                    session.delete(unmapped_ti)
                # Newly created mapped TIs inherit the unmapped TI's state.
                state = unmapped_ti.state

        if total_length is None or total_length < 1:
            # Nothing to fixup.
            indexes_to_map: Iterable[int] = ()
        else:
            # Only create "missing" ones.
            current_max_mapping = (
                session.query(func.max(TaskInstance.map_index))
                .filter(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                )
                .scalar()
            )
            # MAX over zero rows yields NULL (None); with no task instances at
            # all, nothing is mapped yet, so start creating from index 0.
            if current_max_mapping is None:
                current_max_mapping = -1
            indexes_to_map = range(current_max_mapping + 1, total_length)

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(self)  # session.merge() loses task information.
            all_expanded_tis.append(ti)

        # Coerce the None case to 0 -- these two are almost treated identically,
        # except the unmapped ti (if exists) is marked to different states.
        total_expanded_ti_count = total_length or 0

        # Any (old) task instances with inapplicable indexes (>= the total
        # number we need) are set to "REMOVED".
        query = session.query(TaskInstance).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_expanded_ti_count,
        )
        to_update = with_row_locks(query, of=TaskInstance, session=session, **skip_locked(session=session))
        for ti in to_update:
            ti.state = TaskInstanceState.REMOVED
        session.flush()
        return all_expanded_tis, total_expanded_ti_count - 1

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """Template all attributes listed in *self.template_fields*.

        If the operator is mapped, this should return the unmapped, fully
        rendered, and map-expanded operator. The mapped operator should not be
        modified. However, *context* may be modified in-place to reference the
        unmapped operator for template rendering.

        If the operator is not mapped, this should modify the operator in-place.
        """
        raise NotImplementedError()

    def _render(self, template, context, dag: DAG | None = None):
        """Render *template*, defaulting *dag* to the DAG this operator is attached to."""
        if dag is None:
            dag = self.get_dag()
        return super()._render(template, context, dag=dag)

    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """Get the template environment for rendering templates.

        :param dag: DAG whose settings to use; defaults to the attached DAG.
        """
        if dag is None:
            dag = self.get_dag()
        return super().get_template_env(dag=dag)

    @provide_session
    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
        *,
        session: Session = NEW_SESSION,
    ) -> None:
        """Render each of *template_fields* on *parent* in place.

        Overrides the base to use custom error logging. Falsy values are left
        untouched; values whose truthiness cannot be evaluated are still
        rendered. On a rendering error the (redacted) template is logged and the
        exception is re-raised. ``session`` is accepted for the decorator but
        not used here.

        :raise AttributeError: If *parent* lacks one of the listed fields.
        """
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError as err:
                # ``parent`` is typed ``Any`` and is not necessarily an operator
                # (presumably nested templated objects), so ``task_type`` may be
                # missing; fall back to the class name instead of raising a
                # confusing secondary AttributeError.
                parent_type = getattr(parent, "task_type", type(parent).__name__)
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent_type} does not have this attribute."
                ) from err

            try:
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                self.log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__

            try:
                rendered_content = self.render_template(
                    value,
                    context,
                    jinja_env,
                    seen_oids,
                )
            except Exception:
                # Redact before logging so secrets in the template are not leaked.
                value_masked = redact(name=attr_name, value=value)
                self.log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    value_masked,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)