#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import inspect
from abc import abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence

import methodtools
from sqlalchemy import select

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.template.templater import Templater
from airflow.utils.context import Context
from airflow.utils.db import exists_query
from airflow.utils.log.secrets_masker import redact
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.weight_rule import WeightRule

TaskStateChangeCallback = Callable[[Context], None]

if TYPE_CHECKING:
    import jinja2  # Slow import.
    from sqlalchemy.orm import Session

    from airflow.models.baseoperator import BaseOperator
    from airflow.models.baseoperatorlink import BaseOperatorLink
    from airflow.models.dag import DAG
    from airflow.models.mappedoperator import MappedOperator
    from airflow.models.operator import Operator
    from airflow.models.taskinstance import TaskInstance
    from airflow.task.priority_strategy import PriorityWeightStrategy
    from airflow.utils.task_group import TaskGroup

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_PRIORITY_WEIGHT: int = 1
DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = conf.getboolean(
    "scheduler", "ignore_first_depends_on_past_by_default"
)
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)


class NotMapped(Exception):
    """Raised when a task is neither mapped nor has any parent mapped task group."""


class AbstractOperator(Templater, DAGNode):
    """Common implementation for operators, including unmapped and mapped.

    This base class is about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use it as the common base class
    for typing, since BaseOperator carries too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    operator_class: type[BaseOperator] | dict[str, Any]

    weight_rule: PriorityWeightStrategy
    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink]

    owner: str
    task_id: str

    outlets: list
    inlets: list
    trigger_rule: TriggerRule
    _needs_expansion: bool | None = None
    _on_failure_fail_dagrun = False

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            "start_trigger",
            "next_method",
            # For compatibility with TaskGroup; for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    def get_dag(self) -> DAG | None:
        raise NotImplementedError()

    @property
    def task_type(self) -> str:
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        raise NotImplementedError()

    @property
    def inherits_from_empty_operator(self) -> bool:
        raise NotImplementedError()

    @property
    def dag_id(self) -> str:
        """Return the DAG ID if the task has a DAG, or ``adhoc_<owner>`` otherwise."""
        dag = self.get_dag()
        if dag:
            return dag.dag_id
        return f"adhoc_{self.owner}"

    @property
    def node_id(self) -> str:
        return self.task_id

    @property
    @abstractmethod
    def task_display_name(self) -> str: ...

    @property
    def label(self) -> str | None:
        if self.task_display_name and self.task_display_name != self.task_id:
            return self.task_display_name
        # Prefix handling if no display name is given is cloned from taskmixin for compatibility.
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.task_id[len(tg.node_id) + 1 :]
        return self.task_id

    @property
    def is_setup(self) -> bool:
        raise NotImplementedError()

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        raise NotImplementedError()

    @property
    def is_teardown(self) -> bool:
        raise NotImplementedError()

    @is_teardown.setter
    def is_teardown(self, value: bool) -> None:
        raise NotImplementedError()

    @property
    def on_failure_fail_dagrun(self):
        """
        Whether the operator should fail the dagrun on failure.

        :meta private:
        """
        return self._on_failure_fail_dagrun

    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, value):
        """
        Setter for on_failure_fail_dagrun property.

        :meta private:
        """
        if value is True and self.is_teardown is not True:
            raise ValueError(
                f"Cannot set task on_failure_fail_dagrun for "
                f"'{self.task_id}' because it is not a teardown task."
            )
        self._on_failure_fail_dagrun = value

    def as_setup(self):
        self.is_setup = True
        return self

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
        on_failure_fail_dagrun=NOTSET,
    ):
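        """Mark this task as a teardown, switch its trigger rule to
        ALL_DONE_SETUP_SUCCESS, and optionally wire up its setups.

        A minimal usage sketch (``create_cluster``, ``run_job``, and
        ``delete_cluster`` are hypothetical tasks)::

            create_cluster >> run_job >> delete_cluster.as_teardown(setups=create_cluster)
        """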
        self.is_teardown = True
        self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
        if on_failure_fail_dagrun is not NOTSET:
            self.on_failure_fail_dagrun = on_failure_fail_dagrun
        if not isinstance(setups, ArgNotSet):
            setups = [setups] if isinstance(setups, DependencyMixin) else setups
            for s in setups:
                s.is_setup = True
                s >> self
        return self

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get direct relative IDs of the current task, upstream or downstream."""
        if upstream:
            return self.upstream_task_ids
        return self.downstream_task_ids

    def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]:
        """Get a flat set of relative IDs, upstream or downstream.

        Will recurse through each relative found in the direction specified.
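
        For example (an illustrative sketch), given ``t1 >> t2 >> t3``,
        ``t1.get_flat_relative_ids()`` returns ``{"t2", "t3"}``.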

        :param upstream: Whether to look for upstream or downstream relatives.
        """
        dag = self.get_dag()
        if not dag:
            return set()

        relatives: set[str] = set()

        # This is intentionally implemented as a loop rather than calling
        # get_direct_relative_ids() recursively, since Python limits stack
        # depth, and a recursive implementation can blow up if a DAG contains
        # very long chains.
        task_ids_to_trace = self.get_direct_relative_ids(upstream)
        while task_ids_to_trace:
            task_ids_to_trace_next: set[str] = set()
            for task_id in task_ids_to_trace:
                if task_id in relatives:
                    continue
                task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
                relatives.add(task_id)
            task_ids_to_trace = task_ids_to_trace_next

        return relatives

    def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
        """Get a flat list of relatives, either upstream or downstream."""
        dag = self.get_dag()
        if not dag:
            return set()
        return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)]

    def get_upstreams_follow_setups(self) -> Iterable[Operator]:
        """All upstreams and, for each upstream setup, its respective teardowns."""
        for task in self.get_flat_relatives(upstream=True):
            yield task
            if task.is_setup:
                for t in task.downstream_list:
                    if t.is_teardown and t != self:
                        yield t

    def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]:
        """
        Only *relevant* upstream setups and their teardowns.

        This method is meant to be used when we are clearing the task (non-upstream) and we need
        to add in the *relevant* setups and their teardowns.

        Relevant here means the setup has a teardown that is downstream of ``self``,
        or the setup has no teardowns at all.
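
        For example (an illustrative sketch; ``s1``, ``w1``, and ``t1`` are
        hypothetical tasks), given ``s1 >> w1 >> t1.as_teardown(setups=s1)``,
        clearing ``w1`` yields ``s1`` (whose teardown ``t1`` is downstream of
        ``w1``) and ``t1`` itself.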
300 """
301 downstream_teardown_ids = {
302 x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown
303 }
304 for task in self.get_flat_relatives(upstream=True):
305 if not task.is_setup:
306 continue
307 has_no_teardowns = not any(True for x in task.downstream_list if x.is_teardown)
308 # if task has no teardowns or has teardowns downstream of self
309 if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids):
310 yield task
311 for t in task.downstream_list:
312 if t.is_teardown and t != self:
313 yield t
314
    def get_upstreams_only_setups(self) -> Iterable[Operator]:
        """
        Return relevant upstream setups.

        This method is meant to be used when we are checking task dependencies where we need
        to wait for all the upstream setups to complete before we can run the task.
        """
        for task in self.get_upstreams_only_setups_and_teardowns():
            if task.is_setup:
                yield task

    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """Return mapped nodes that are direct dependencies of the current task.

        For now, this walks the entire DAG to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a DAG node's downstream nodes instead.

        Note that this does not guarantee the returned tasks actually use the
        current task for task mapping; it only checks that they are mapped
        operators (or mapped task groups) downstream of the current task.

        To get a list of tasks that use the current task for task mapping, use
        :meth:`iter_mapped_dependants` instead.
        """
        from airflow.models.mappedoperator import MappedOperator
        from airflow.utils.task_group import TaskGroup

        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
            """Recursively walk children in a task group.

            This yields all direct children (including both tasks and task
            groups), and all children of any task groups.
            """
            for key, child in group.children.items():
                yield key, child
                if isinstance(child, TaskGroup):
                    yield from _walk_group(child)

        dag = self.get_dag()
        if not dag:
            raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG")
        for key, child in _walk_group(dag.task_group):
            if key == self.node_id:
                continue
            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
                continue
            if self.node_id in child.upstream_task_ids:
                yield child

    def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """Return mapped nodes that depend on the current task for expansion.

        For now, this walks the entire DAG to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a DAG node's downstream nodes instead.
        """
        return (
            downstream
            for downstream in self._iter_all_mapped_downstreams()
            if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
        )

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """Return mapped task groups this task belongs to.

        Groups are returned from the innermost to the outermost.

        :meta private:
        """
        if (group := self.task_group) is None:
            return
        yield from group.iter_mapped_task_groups()

    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
        """Get the mapped task group "closest" to this task in the DAG.

        :meta private:
        """
        return next(self.iter_mapped_task_groups(), None)

    def get_needs_expansion(self) -> bool:
        """
        Return True if the task is a mapped operator or is in a mapped task group.

        :meta private:
        """
        if self._needs_expansion is None:
            self._needs_expansion = self.get_closest_mapped_task_group() is not None
        return self._needs_expansion

    def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
        """Get the "normal" operator from the current abstract operator.

        MappedOperator uses this to unmap itself based on the map index. A
        non-mapped operator (i.e. a BaseOperator subclass) simply returns itself.

        :meta private:
        """
        raise NotImplementedError()

    @property
    def priority_weight_total(self) -> int:
        """
        Total priority weight for the task. It might include all upstream or downstream tasks.

        Depending on the weight rule:

        - WeightRule.ABSOLUTE - only own weight
        - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks
        - WeightRule.UPSTREAM - adds priority weight of all upstream tasks
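
        For example, with ``t1 >> t2 >> t3`` and every task's own
        ``priority_weight`` left at 1, the downstream rule gives ``t1`` a
        total of 3, while the upstream rule gives ``t3`` a total of 3.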
431 """
432 from airflow.task.priority_strategy import (
433 _AbsolutePriorityWeightStrategy,
434 _DownstreamPriorityWeightStrategy,
435 _UpstreamPriorityWeightStrategy,
436 )
437
438 if type(self.weight_rule) == _AbsolutePriorityWeightStrategy:
439 return self.priority_weight
440 elif type(self.weight_rule) == _DownstreamPriorityWeightStrategy:
441 upstream = False
442 elif type(self.weight_rule) == _UpstreamPriorityWeightStrategy:
443 upstream = True
444 else:
445 upstream = False
446 dag = self.get_dag()
447 if dag is None:
448 return self.priority_weight
449 return self.priority_weight + sum(
450 dag.task_dict[task_id].priority_weight
451 for task_id in self.get_flat_relative_ids(upstream=upstream)
452 )
453
    @cached_property
    def operator_extra_link_dict(self) -> dict[str, Any]:
        """Return a dictionary of all extra links for the operator."""
        op_extra_links_from_plugin: dict[str, Any] = {}
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.operator_extra_links is None:
            raise AirflowException("Can't load operators")
        for ope in plugins_manager.operator_extra_links:
            if ope.operators and self.operator_class in ope.operators:
                op_extra_links_from_plugin.update({ope.name: ope})

        operator_extra_links_all = {link.name: link for link in self.operator_extra_links}
        # Extra links defined in plugins override links defined on the operator.
        operator_extra_links_all.update(op_extra_links_from_plugin)

        return operator_extra_links_all

    @cached_property
    def global_operator_extra_link_dict(self) -> dict[str, Any]:
        """Return a dictionary of all global extra links."""
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.global_operator_extra_links is None:
            raise AirflowException("Can't load operators")
        return {link.name: link for link in plugins_manager.global_operator_extra_links}

    @cached_property
    def extra_links(self) -> list[str]:
        return sorted(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))

    def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
        """For an operator, get the URL that the ``extra_links`` entry points to.

        :meta private:

        :param ti: The TaskInstance for the URL being searched for.
        :param link_name: The name of the link we're looking for the URL for. Should be
            one of the options specified in ``extra_links``.
        :raise ValueError: The error message of a ValueError will be passed on through to
            the frontend to show up as a tooltip on the disabled link.
        """
        link: BaseOperatorLink | None = self.operator_extra_link_dict.get(link_name)
        if not link:
            link = self.global_operator_extra_link_dict.get(link_name)
            if not link:
                return None

        parameters = inspect.signature(link.get_link).parameters
        old_signature = all(name != "ti_key" for name, p in parameters.items() if p.kind != p.VAR_KEYWORD)

        if old_signature:
            return link.get_link(self.unmap(None), ti.dag_run.logical_date)  # type: ignore[misc]
        return link.get_link(self.unmap(None), ti_key=ti.key)

    @methodtools.lru_cache(maxsize=None)
    def get_parse_time_mapped_ti_count(self) -> int:
        """
        Return the number of mapped task instances that can be created on DAG run creation.

        This only considers literal mapped arguments; when any non-literal
        values are used for mapping, the count is unknown at parse time and
        ``NotFullyPopulated`` is raised instead.
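
        For example (an illustrative sketch), ``task.expand(x=[1, 2, 3])``
        gives 3 at parse time, while ``task.expand(x=other_task.output)``
        cannot be counted until ``other_task`` has produced its result.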

        :raise NotFullyPopulated: If non-literal mapped arguments are encountered.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped
        return group.get_parse_time_mapped_ti_count()

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """
        Return the number of mapped TaskInstances that can be created at run time.

        This considers both literal and non-literal mapped arguments, and the
        result is therefore only available once all upstream dependencies have
        finished. The return value should be identical to
        ``parse_time_mapped_ti_count`` if all mapped arguments are literal.

        :raise NotFullyPopulated: If upstream tasks are not all complete yet.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped
        return group.get_mapped_ti_count(run_id, session=session)

    def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
        """Create the mapped task instances for a mapped task.

        :raise NotMapped: If this task does not need expansion.
        :return: The newly created mapped task instances (if any) in ascending
            order by map index, and the maximum map index value.
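
        For example, a task mapped over three values ends up with task
        instances at map indexes 0 through 2, and this returns those task
        instances together with the maximum index, 2.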
554 """
555 from sqlalchemy import func, or_
556
557 from airflow.models.baseoperator import BaseOperator
558 from airflow.models.mappedoperator import MappedOperator
559 from airflow.models.taskinstance import TaskInstance
560 from airflow.settings import task_instance_mutation_hook
561
562 if not isinstance(self, (BaseOperator, MappedOperator)):
563 raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}")
564
565 try:
566 total_length: int | None = self.get_mapped_ti_count(run_id, session=session)
567 except NotFullyPopulated as e:
568 # It's possible that the upstream tasks are not yet done, but we
569 # don't have upstream of upstreams in partial DAGs (possible in the
570 # mini-scheduler), so we ignore this exception.
571 if not self.dag or not self.dag.partial:
572 self.log.error(
573 "Cannot expand %r for run %s; missing upstream values: %s",
574 self,
575 run_id,
576 sorted(e.missing),
577 )
578 total_length = None
579
        state: TaskInstanceState | None = None
        unmapped_ti: TaskInstance | None = session.scalars(
            select(TaskInstance).where(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
            )
        ).one_or_none()

        all_expanded_tis: list[TaskInstance] = []

        if unmapped_ti:
            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length is None:
                # If the DAG is partial, it's likely that the upstream tasks
                # are not done yet, so the task can't fail yet.
                if not self.dag or not self.dag.partial:
                    unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
            elif total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark
                # the unmapped task instance as SKIPPED (if needed).
                self.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
            else:
                zero_index_ti_exists = exists_query(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                    TaskInstance.map_index == 0,
                    session=session,
                )
                if not zero_index_ti_exists:
                    # Otherwise convert this into the first mapped index, and create
                    # TaskInstance for other indexes.
                    unmapped_ti.map_index = 0
                    self.log.debug("Updated in place to become %s", unmapped_ti)
                    all_expanded_tis.append(unmapped_ti)
                    session.flush()
                else:
                    self.log.debug("Deleting the original task instance: %s", unmapped_ti)
                    session.delete(unmapped_ti)
            state = unmapped_ti.state

        if total_length is None or total_length < 1:
            # Nothing to fixup.
            indexes_to_map: Iterable[int] = ()
        else:
            # Only create "missing" ones.
            current_max_mapping = session.scalar(
                select(func.max(TaskInstance.map_index)).where(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                )
            )
            indexes_to_map = range(current_max_mapping + 1, total_length)

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(self)  # session.merge() loses task information.
            all_expanded_tis.append(ti)

        # Coerce the None case to 0 -- the two are treated almost identically,
        # except the unmapped ti (if it exists) is given a different state.
        total_expanded_ti_count = total_length or 0

        # Any (old) task instances with inapplicable indexes (>= the total
        # number we need) are set to "REMOVED".
        query = select(TaskInstance).where(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_expanded_ti_count,
        )
        query = with_row_locks(query, of=TaskInstance, session=session, skip_locked=True)
        to_update = session.scalars(query)
        for ti in to_update:
            ti.state = TaskInstanceState.REMOVED
        session.flush()
        return all_expanded_tis, total_expanded_ti_count - 1

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """Template all attributes listed in *self.template_fields*.

        If the operator is mapped, this should return the unmapped, fully
        rendered, and map-expanded operator. The mapped operator should not be
        modified. However, *context* may be modified in-place to reference the
        unmapped operator for template rendering.

        If the operator is not mapped, this should modify the operator in-place.
        """
        raise NotImplementedError()

    def _render(self, template, context, dag: DAG | None = None):
        if dag is None:
            dag = self.get_dag()
        return super()._render(template, context, dag=dag)

    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """Get the template environment for rendering templates."""
        if dag is None:
            dag = self.get_dag()
        return super().get_template_env(dag=dag)

    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """Override the base to use custom error logging."""
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            try:
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support
                # `__bool__`, such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                self.log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__.
                pass

            try:
                rendered_content = self.render_template(
                    value,
                    context,
                    jinja_env,
                    seen_oids,
                )
            except Exception:
                value_masked = redact(name=attr_name, value=value)
                self.log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    value_masked,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)

    def __enter__(self):
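        """Enter a setup/teardown context manager.

        A minimal usage sketch (``create_cluster``, ``run_job``, and
        ``delete_cluster`` are hypothetical tasks)::

            with create_cluster.as_setup() >> delete_cluster.as_teardown():
                run_job
        """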
        if not self.is_setup and not self.is_teardown:
            raise AirflowException("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(self)
        return SetupTeardownContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        SetupTeardownContext.set_work_task_roots_and_leaves()