Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/abstractoperator.py: 31%
264 statements
Report generated by coverage.py v7.2.7 at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import datetime
21import inspect
22from functools import cached_property
23from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence
25from airflow.compat.functools import cache
26from airflow.configuration import conf
27from airflow.exceptions import AirflowException
28from airflow.models.expandinput import NotFullyPopulated
29from airflow.models.taskmixin import DAGNode
30from airflow.template.templater import Templater
31from airflow.utils.context import Context
32from airflow.utils.log.secrets_masker import redact
33from airflow.utils.session import NEW_SESSION, provide_session
34from airflow.utils.sqlalchemy import skip_locked, with_row_locks
35from airflow.utils.state import State, TaskInstanceState
36from airflow.utils.task_group import MappedTaskGroup
37from airflow.utils.trigger_rule import TriggerRule
38from airflow.utils.weight_rule import WeightRule
# Type alias for task state-change callbacks: callables that accept a ``Context``
# and return None.
TaskStateChangeCallback = Callable[[Context], None]
42if TYPE_CHECKING:
43 import jinja2 # Slow import.
44 from sqlalchemy.orm import Session
46 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
47 from airflow.models.dag import DAG
48 from airflow.models.mappedoperator import MappedOperator
49 from airflow.models.operator import Operator
50 from airflow.models.taskinstance import TaskInstance
# Operator defaults that are fixed literals (not read from configuration).
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_PRIORITY_WEIGHT: int = 1
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS

# Operator defaults read from the Airflow configuration when this module is imported.
# ``get_mandatory_value`` is used for the options that have no fallback here.
DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = conf.getboolean(
    "scheduler",
    "ignore_first_depends_on_past_by_default",
)
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300),
)
# Upper bound for the retry delay, in seconds (fallback: one day).
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=60 * 60 * 24)
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM),
)
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core",
    "default_task_execution_timeout",
)
class NotMapped(Exception):
    """Signals that a task is not mapped itself and sits in no mapped task group."""
class AbstractOperator(Templater, DAGNode):
    """Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use this as the common base class
    for typing due to BaseOperator carrying too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    operator_class: type[BaseOperator] | dict[str, Any]

    weight_rule: str
    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink]

    owner: str
    task_id: str

    outlets: list
    inlets: list

    # Attribute names not worth showing in the UI (see the reason next to each entry).
    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    def get_dag(self) -> DAG | None:
        """Return the DAG this operator is attached to, if any. Must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def task_type(self) -> str:
        """Type name of the task. Must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        """Name of the operator. Must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def inherits_from_empty_operator(self) -> bool:
        """Whether the operator inherits from the empty operator. Must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def dag_id(self) -> str:
        """Return the attached DAG's ID, or ``adhoc_<owner>`` when there is no DAG."""
        dag = self.get_dag()
        if dag:
            return dag.dag_id
        return f"adhoc_{self.owner}"

    @property
    def node_id(self) -> str:
        """Node ID of the operator within the DAG; for operators this is the task ID."""
        return self.task_id

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get direct relative IDs to the current task, upstream or downstream.

        :param upstream: Return upstream task IDs if True, downstream ones otherwise.
        """
        if upstream:
            return self.upstream_task_ids
        return self.downstream_task_ids

    def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]:
        """
        Get a flat set of relative IDs, upstream or downstream.

        Will recurse each relative found in the direction specified. Each task ID
        is expanded at most once, so tasks reachable through several paths are
        not traversed repeatedly.

        :param upstream: Whether to look for upstream or downstream relatives.
        :return: All transitive relative task IDs, or an empty set if the
            operator is not attached to a DAG.
        """
        dag = self.get_dag()
        if not dag:
            return set()

        relatives: set[str] = set()

        # Level-by-level walk: trace the current frontier, collect the next one.
        task_ids_to_trace = self.get_direct_relative_ids(upstream)
        while task_ids_to_trace:
            task_ids_to_trace_next: set[str] = set()
            for task_id in task_ids_to_trace:
                if task_id in relatives:
                    continue
                task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
                relatives.add(task_id)
            task_ids_to_trace = task_ids_to_trace_next

        return relatives

    def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
        """Get a flat list of relatives, either upstream or downstream.

        :param upstream: Whether to look for upstream or downstream relatives.
        :return: The relative operators as a list; empty when the operator is
            not attached to a DAG.
        """
        dag = self.get_dag()
        if not dag:
            # Return the same container type as the populated case below.
            return []
        return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)]

    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """Return mapped nodes that are direct dependencies of the current task.

        For now, this walks the entire DAG to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a DAG node's downstream nodes instead.

        Note that this does not guarantee the returned tasks actually use the
        current task for task mapping, but only checks those tasks are mapped
        operators, and are downstreams of the current task.

        To get a list of tasks that use the current task for task mapping, use
        :meth:`iter_mapped_dependants` instead.

        :raise RuntimeError: If the operator is not attached to a DAG.
        """
        from airflow.models.mappedoperator import MappedOperator
        from airflow.utils.task_group import TaskGroup

        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
            """Recursively walk children in a task group.

            This yields all direct children (including both tasks and task
            groups), and all children of any task groups.
            """
            for key, child in group.children.items():
                yield key, child
                if isinstance(child, TaskGroup):
                    yield from _walk_group(child)

        dag = self.get_dag()
        if not dag:
            raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG")
        for key, child in _walk_group(dag.task_group):
            if key == self.node_id:
                continue
            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
                continue
            if self.node_id in child.upstream_task_ids:
                yield child

    def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """Return mapped nodes that depend on the current task for their expansion.

        For now, this walks the entire DAG to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a DAG node's downstream nodes instead.
        """
        return (
            downstream
            for downstream in self._iter_all_mapped_downstreams()
            if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
        )

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """Return mapped task groups this task belongs to.

        Groups are returned from the innermost to the outermost.

        :meta private:
        """
        parent = self.task_group
        while parent is not None:
            if isinstance(parent, MappedTaskGroup):
                yield parent
            parent = parent.task_group

    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
        """Get the mapped task group "closest" to this task in the DAG.

        :return: The innermost enclosing mapped task group, or None if there is none.

        :meta private:
        """
        return next(self.iter_mapped_task_groups(), None)

    def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
        """Get the "normal" operator from current abstract operator.

        MappedOperator uses this to unmap itself based on the map index. A non-
        mapped operator (i.e. BaseOperator subclass) simply returns itself.

        :meta private:
        """
        raise NotImplementedError()

    @property
    def priority_weight_total(self) -> int:
        """
        Total priority weight for the task. It might include all upstream or downstream tasks.

        Depending on the weight rule:

        - WeightRule.ABSOLUTE - only own weight
        - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks
        - WeightRule.UPSTREAM - adds priority weight of all upstream tasks

        Any other rule value is treated like WeightRule.DOWNSTREAM.
        """
        if self.weight_rule == WeightRule.ABSOLUTE:
            return self.priority_weight
        # Only UPSTREAM looks upstream; DOWNSTREAM and any unrecognised rule look downstream.
        upstream = self.weight_rule == WeightRule.UPSTREAM
        dag = self.get_dag()
        if dag is None:
            return self.priority_weight
        return self.priority_weight + sum(
            dag.task_dict[task_id].priority_weight
            for task_id in self.get_flat_relative_ids(upstream=upstream)
        )

    @cached_property
    def operator_extra_link_dict(self) -> dict[str, Any]:
        """Return a dictionary of all extra links for the operator, keyed by link name.

        Links registered by plugins for this operator class override links of
        the same name declared on the operator itself.

        :raise AirflowException: If the plugin-provided operator links could not be loaded.
        """
        op_extra_links_from_plugin: dict[str, Any] = {}
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.operator_extra_links is None:
            raise AirflowException("Can't load operators")
        for ope in plugins_manager.operator_extra_links:
            if ope.operators and self.operator_class in ope.operators:
                op_extra_links_from_plugin[ope.name] = ope

        operator_extra_links_all = {link.name: link for link in self.operator_extra_links}
        # Extra links defined in Plugins overrides operator links defined in operator
        operator_extra_links_all.update(op_extra_links_from_plugin)

        return operator_extra_links_all

    @cached_property
    def global_operator_extra_link_dict(self) -> dict[str, Any]:
        """Return a dictionary of all global extra links, keyed by link name.

        :raise AirflowException: If the global operator links could not be loaded.
        """
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.global_operator_extra_links is None:
            raise AirflowException("Can't load operators")
        return {link.name: link for link in plugins_manager.global_operator_extra_links}

    @cached_property
    def extra_links(self) -> list[str]:
        """Names of all operator-level and global extra links, without duplicates (unordered)."""
        return list(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))

    def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
        """For an operator, gets the URLs that the ``extra_links`` entry points to.

        Operator-level links take precedence over global links of the same name.
        Links whose ``get_link`` does not accept ``ti_key`` are called with the
        old ``(operator, logical_date)`` signature.

        :meta private:

        :raise ValueError: The error message of a ValueError will be passed on through to
            the frontend to show up as a tooltip on the disabled link.
        :param ti: The TaskInstance for the URL being searched for.
        :param link_name: The name of the link we're looking for the URL for. Should be
            one of the options specified in ``extra_links``.
        :return: The URL, or None if no link with that name exists.
        """
        link: BaseOperatorLink | None = self.operator_extra_link_dict.get(link_name)
        if not link:
            link = self.global_operator_extra_link_dict.get(link_name)
            if not link:
                return None

        parameters = inspect.signature(link.get_link).parameters
        # ``**kwargs`` is ignored here: only an explicitly named ``ti_key`` selects the new signature.
        old_signature = all(name != "ti_key" for name, p in parameters.items() if p.kind != p.VAR_KEYWORD)

        if old_signature:
            return link.get_link(self.unmap(None), ti.dag_run.logical_date)  # type: ignore[misc]
        return link.get_link(self.unmap(None), ti_key=ti.key)

    # NOTE(review): assuming ``airflow.compat.functools.cache`` behaves like
    # ``functools.cache``, caching a method keys on ``self`` and keeps every
    # operator it has seen alive for the life of the process; consider a
    # per-instance cache instead.
    @cache
    def get_parse_time_mapped_ti_count(self) -> int:
        """Number of mapped task instances that can be created on DAG run creation.

        This only considers literal mapped arguments; if any non-literal values
        are used for mapping, the count cannot be determined at parse time and
        ``NotFullyPopulated`` is raised.

        :raise NotFullyPopulated: If non-literal mapped arguments are encountered.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped
        return group.get_parse_time_mapped_ti_count()

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """Number of mapped TaskInstances that can be created at run time.

        This considers both literal and non-literal mapped arguments, and the
        result is therefore available when all depended tasks have finished. The
        return value should be identical to ``get_parse_time_mapped_ti_count()``
        if all mapped arguments are literal.

        :param run_id: The DAG run to count mapped task instances for.
        :param session: SQLAlchemy session to use.
        :raise NotFullyPopulated: If upstream tasks are not all complete yet.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped
        return group.get_mapped_ti_count(run_id, session=session)

    def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
        """Create the mapped task instances for mapped task.

        The unfinished unmapped task instance (``map_index == -1``), if there is
        one, is marked UPSTREAM_FAILED when the count is unknown (unless the DAG
        is partial), marked SKIPPED for a zero-length map, converted in place to
        map index 0, or deleted when index 0 already exists. Missing map indexes
        are then created, and task instances whose map index is not below the
        expanded count are marked REMOVED.

        :param run_id: The DAG run whose task instances are expanded.
        :param session: SQLAlchemy session used for all queries and writes.
        :raise RuntimeError: If this is neither a BaseOperator nor a MappedOperator.
        :raise NotMapped: If this task does not need expansion.
        :return: The newly created mapped task instances (if any) in ascending
            order by map index, and the maximum map index value.
        """
        from sqlalchemy import func, or_

        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator
        from airflow.models.taskinstance import TaskInstance
        from airflow.settings import task_instance_mutation_hook

        if not isinstance(self, (BaseOperator, MappedOperator)):
            raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}")

        try:
            total_length: int | None = self.get_mapped_ti_count(run_id, session=session)
        except NotFullyPopulated as e:
            # It's possible that the upstream tasks are not yet done, but we
            # don't have upstream of upstreams in partial DAGs (possible in the
            # mini-scheduler), so we ignore this exception.
            if not self.dag or not self.dag.partial:
                self.log.error(
                    "Cannot expand %r for run %s; missing upstream values: %s",
                    self,
                    run_id,
                    sorted(e.missing),
                )
            total_length = None

        state: TaskInstanceState | None = None
        unmapped_ti: TaskInstance | None = (
            session.query(TaskInstance)
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
            )
            .one_or_none()
        )

        all_expanded_tis: list[TaskInstance] = []

        if unmapped_ti:
            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length is None:
                # If the DAG is partial, it's likely that the upstream tasks
                # are not done yet, so the task can't fail yet.
                if not self.dag or not self.dag.partial:
                    unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
            elif total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark
                # the unmapped task instance as SKIPPED (if needed).
                self.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
            else:
                zero_index_ti_exists = (
                    session.query(TaskInstance)
                    .filter(
                        TaskInstance.dag_id == self.dag_id,
                        TaskInstance.task_id == self.task_id,
                        TaskInstance.run_id == run_id,
                        TaskInstance.map_index == 0,
                    )
                    .count()
                    > 0
                )
                if not zero_index_ti_exists:
                    # Otherwise convert this into the first mapped index, and create
                    # TaskInstance for other indexes.
                    unmapped_ti.map_index = 0
                    self.log.debug("Updated in place to become %s", unmapped_ti)
                    all_expanded_tis.append(unmapped_ti)
                    session.flush()
                else:
                    self.log.debug("Deleting the original task instance: %s", unmapped_ti)
                    session.delete(unmapped_ti)
                # New mapped TIs start in the state the unmapped TI had.
                state = unmapped_ti.state

        if total_length is None or total_length < 1:
            # Nothing to fixup.
            indexes_to_map: Iterable[int] = ()
        else:
            # Only create "missing" ones.
            current_max_mapping = (
                session.query(func.max(TaskInstance.map_index))
                .filter(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                )
                .scalar()
            )
            if current_max_mapping is None:
                # MAX() over zero rows yields NULL: no TI exists for this task
                # in this run yet, so every index starting from 0 is missing.
                current_max_mapping = -1
            indexes_to_map = range(current_max_mapping + 1, total_length)

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(self)  # session.merge() loses task information.
            all_expanded_tis.append(ti)

        # Coerce the None case to 0 -- these two are almost treated identically,
        # except the unmapped ti (if exists) is marked to different states.
        total_expanded_ti_count = total_length or 0

        # Any (old) task instances with inapplicable indexes (>= the total
        # number we need) are set to "REMOVED".
        query = session.query(TaskInstance).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_expanded_ti_count,
        )
        to_update = with_row_locks(query, of=TaskInstance, session=session, **skip_locked(session=session))
        for ti in to_update:
            ti.state = TaskInstanceState.REMOVED
        session.flush()
        return all_expanded_tis, total_expanded_ti_count - 1

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """Template all attributes listed in *self.template_fields*.

        If the operator is mapped, this should return the unmapped, fully
        rendered, and map-expanded operator. The mapped operator should not be
        modified. However, *context* may be modified in-place to reference the
        unmapped operator for template rendering.

        If the operator is not mapped, this should modify the operator in-place.
        """
        raise NotImplementedError()

    def _render(self, template, context, dag: DAG | None = None):
        """Render *template* with *context*, defaulting *dag* to this operator's DAG."""
        if dag is None:
            dag = self.get_dag()
        return super()._render(template, context, dag=dag)

    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """Get the template environment for rendering templates.

        :param dag: DAG whose environment to use; defaults to this operator's DAG.
        """
        if dag is None:
            dag = self.get_dag()
        return super().get_template_env(dag=dag)

    @provide_session
    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
        *,
        session: Session = NEW_SESSION,
    ) -> None:
        """Override the base to use custom error logging.

        Renders each attribute of *parent* named in *template_fields* and
        assigns the rendered value back onto *parent*. Falsy values are skipped.

        :raise AttributeError: If a template field is not an attribute of *parent*.
        """
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError as err:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                ) from err

            try:
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                # Deliberately fall through: we may still want to render custom
                # classes which do not support __bool__.
                self.log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )

            try:
                rendered_content = self.render_template(
                    value,
                    context,
                    jinja_env,
                    seen_oids,
                )
            except Exception:
                # Mask secrets before logging the failing template value.
                value_masked = redact(name=attr_name, value=value)
                self.log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    value_masked,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)