Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/abstractoperator.py: 32%

353 statements  

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import inspect
from abc import abstractproperty
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence

import methodtools
from sqlalchemy import select

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.template.templater import Templater
from airflow.utils.context import Context
from airflow.utils.db import exists_query
from airflow.utils.log.secrets_masker import redact
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.weight_rule import WeightRule

TaskStateChangeCallback = Callable[[Context], None]

if TYPE_CHECKING:
    import jinja2  # Slow import.
    from sqlalchemy.orm import Session

    from airflow.models.baseoperator import BaseOperator
    from airflow.models.baseoperatorlink import BaseOperatorLink
    from airflow.models.dag import DAG
    from airflow.models.mappedoperator import MappedOperator
    from airflow.models.operator import Operator
    from airflow.models.taskinstance import TaskInstance
    from airflow.task.priority_strategy import PriorityWeightStrategy
    from airflow.utils.task_group import TaskGroup

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_PRIORITY_WEIGHT: int = 1
DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = conf.getboolean(
    "scheduler", "ignore_first_depends_on_past_by_default"
)
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)


class NotMapped(Exception):
    """Raised when a task is neither mapped nor inside any mapped task group."""


class AbstractOperator(Templater, DAGNode):
    """Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use this as the common base class
    for typing due to BaseOperator carrying too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    operator_class: type[BaseOperator] | dict[str, Any]

    weight_rule: PriorityWeightStrategy
    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink]

    owner: str
    task_id: str

    outlets: list
    inlets: list
    trigger_rule: TriggerRule
    _needs_expansion: bool | None = None
    _on_failure_fail_dagrun = False

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            "start_trigger",
            "next_method",
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    def get_dag(self) -> DAG | None:
        raise NotImplementedError()

    @property
    def task_type(self) -> str:
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        raise NotImplementedError()

    @property
    def inherits_from_empty_operator(self) -> bool:
        raise NotImplementedError()

    @property
    def dag_id(self) -> str:
        """Return the DAG ID, or an ad-hoc ID based on the owner if the task has no DAG."""
        dag = self.get_dag()
        if dag:
            return dag.dag_id
        return f"adhoc_{self.owner}"

    @property
    def node_id(self) -> str:
        return self.task_id

    @abstractproperty
    def task_display_name(self) -> str: ...

    @property
    def label(self) -> str | None:
        if self.task_display_name and self.task_display_name != self.task_id:
            return self.task_display_name
        # Prefix handling if no display is given is cloned from taskmixin for compatibility
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.task_id[len(tg.node_id) + 1 :]
        return self.task_id

    @property
    def is_setup(self) -> bool:
        raise NotImplementedError()

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        raise NotImplementedError()

    @property
    def is_teardown(self) -> bool:
        raise NotImplementedError()

    @is_teardown.setter
    def is_teardown(self, value: bool) -> None:
        raise NotImplementedError()

    @property
    def on_failure_fail_dagrun(self):
        """
        Whether the operator should fail the dagrun on failure.

        :meta private:
        """
        return self._on_failure_fail_dagrun

    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, value):
        """
        Setter for on_failure_fail_dagrun property.

        :meta private:
        """
        if value is True and self.is_teardown is not True:
            raise ValueError(
                f"Cannot set task on_failure_fail_dagrun for "
                f"'{self.task_id}' because it is not a teardown task."
            )
        self._on_failure_fail_dagrun = value

    def as_setup(self):
        self.is_setup = True
        return self

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
        on_failure_fail_dagrun=NOTSET,
    ):
        self.is_teardown = True
        self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
        if on_failure_fail_dagrun is not NOTSET:
            self.on_failure_fail_dagrun = on_failure_fail_dagrun
        if not isinstance(setups, ArgNotSet):
            setups = [setups] if isinstance(setups, DependencyMixin) else setups
            for s in setups:
                s.is_setup = True
                s >> self
        return self
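
    # A minimal usage sketch for as_setup()/as_teardown() (task names are
    # hypothetical; all three are ordinary tasks in the same DAG):
    #
    #     create_cluster = create_cluster.as_setup()
    #     delete_cluster = delete_cluster.as_teardown(setups=create_cluster)
    #     create_cluster >> run_job >> delete_cluster
    #
    # Passing setups= marks each given task as a setup and wires it upstream of
    # the teardown via ``s >> self``, as implemented above.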

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get direct relative IDs to the current task, upstream or downstream."""
        if upstream:
            return self.upstream_task_ids
        return self.downstream_task_ids

    def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]:
        """Get a flat set of relative IDs, upstream or downstream.

        Will recurse each relative found in the direction specified.

        :param upstream: Whether to look for upstream or downstream relatives.
        """
        dag = self.get_dag()
        if not dag:
            return set()

        relatives: set[str] = set()

        # This is intentionally implemented as a loop, instead of calling
        # get_direct_relative_ids() recursively, since Python has a significant
        # limitation on stack depth, and a recursive implementation can blow up
        # if a DAG contains very long routes.
        task_ids_to_trace = self.get_direct_relative_ids(upstream)
        while task_ids_to_trace:
            task_ids_to_trace_next: set[str] = set()
            for task_id in task_ids_to_trace:
                if task_id in relatives:
                    continue
                task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
                relatives.add(task_id)
            task_ids_to_trace = task_ids_to_trace_next

        return relatives
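
    # Illustrative sketch of the traversal (hypothetical tasks in one DAG):
    #
    #     t1 >> t2 >> t3
    #     t3.get_flat_relative_ids(upstream=True)   # {"t1", "t2"}
    #     t1.get_flat_relative_ids(upstream=False)  # {"t2", "t3"}
    #
    # The breadth-first loop above visits each task ID at most once, so diamond
    # dependencies cause neither duplicates nor infinite loops.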

    def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
        """Get a flat list of relatives, either upstream or downstream."""
        dag = self.get_dag()
        if not dag:
            return set()
        return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)]

    def get_upstreams_follow_setups(self) -> Iterable[Operator]:
        """All upstreams and, for each upstream setup, its respective teardowns."""
        for task in self.get_flat_relatives(upstream=True):
            yield task
            if task.is_setup:
                for t in task.downstream_list:
                    if t.is_teardown and t != self:
                        yield t

    def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]:
        """
        Only *relevant* upstream setups and their teardowns.

        This method is meant to be used when we are clearing the task (non-upstream) and we need
        to add in the *relevant* setups and their teardowns.

        Relevant in this case means the setup has a teardown that is downstream of ``self``,
        or the setup has no teardowns.
        """
        downstream_teardown_ids = {
            x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown
        }
        for task in self.get_flat_relatives(upstream=True):
            if not task.is_setup:
                continue
            has_no_teardowns = not any(True for x in task.downstream_list if x.is_teardown)
            # If the task has no teardowns, or has teardowns downstream of self:
            if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids):
                yield task
                for t in task.downstream_list:
                    if t.is_teardown and t != self:
                        yield t

    def get_upstreams_only_setups(self) -> Iterable[Operator]:
        """
        Return relevant upstream setups.

        This method is meant to be used when we are checking task dependencies where we need
        to wait for all the upstream setups to complete before we can run the task.
        """
        for task in self.get_upstreams_only_setups_and_teardowns():
            if task.is_setup:
                yield task
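
    # A hedged illustration of "relevant" (hypothetical task names): given
    #
    #     s1.as_setup() >> work >> t1.as_teardown(setups=s1)
    #
    # clearing ``work`` also pulls in s1 and t1, because s1's teardown t1 is
    # downstream of ``work``. A setup with no teardowns at all is always
    # considered relevant; a setup whose teardowns are all elsewhere is not.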

    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """Return mapped nodes that are direct dependencies of the current task.

        For now, this walks the entire DAG to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a DAG node's downstream nodes instead.

        Note that this does not guarantee the returned tasks actually use the
        current task for task mapping; it only checks that those tasks are
        mapped operators and are downstream of the current task.

        To get a list of tasks that use the current task for task mapping, use
        :meth:`iter_mapped_dependants` instead.
        """
        from airflow.models.mappedoperator import MappedOperator
        from airflow.utils.task_group import TaskGroup

        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
            """Recursively walk children in a task group.

            This yields all direct children (including both tasks and task
            groups), and all children of any task groups.
            """
            for key, child in group.children.items():
                yield key, child
                if isinstance(child, TaskGroup):
                    yield from _walk_group(child)

        dag = self.get_dag()
        if not dag:
            raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG")
        for key, child in _walk_group(dag.task_group):
            if key == self.node_id:
                continue
            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
                continue
            if self.node_id in child.upstream_task_ids:
                yield child

    def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """Return mapped nodes that depend on the current task for expansion.

        For now, this walks the entire DAG to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a DAG node's downstream nodes instead.
        """
        return (
            downstream
            for downstream in self._iter_all_mapped_downstreams()
            if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
        )

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """Return mapped task groups this task belongs to.

        Groups are returned from the innermost to the outermost.

        :meta private:
        """
        if (group := self.task_group) is None:
            return
        yield from group.iter_mapped_task_groups()

    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
        """Get the mapped task group "closest" to this task in the DAG.

        :meta private:
        """
        return next(self.iter_mapped_task_groups(), None)
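
    # Iteration-order sketch (hypothetical nested mapped task groups): for a
    # task inside outer_group.inner_group, with both groups mapped,
    # iter_mapped_task_groups() yields inner_group first, then outer_group, and
    # get_closest_mapped_task_group() therefore returns inner_group.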

    def get_needs_expansion(self) -> bool:
        """
        Return true if the task is a MappedOperator, or is in a mapped task group.

        :meta private:
        """
        if self._needs_expansion is None:
            if self.get_closest_mapped_task_group() is not None:
                self._needs_expansion = True
            else:
                self._needs_expansion = False
        return self._needs_expansion

    def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
        """Get the "normal" operator from the current abstract operator.

        MappedOperator uses this to unmap itself based on the map index. A non-
        mapped operator (i.e. BaseOperator subclass) simply returns itself.

        :meta private:
        """
        raise NotImplementedError()

    @property
    def priority_weight_total(self) -> int:
        """
        Total priority weight for the task. It might include all upstream or downstream tasks.

        Depending on the weight rule:

        - WeightRule.ABSOLUTE - only own weight
        - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks
        - WeightRule.UPSTREAM - adds priority weight of all upstream tasks
        """
        from airflow.task.priority_strategy import (
            _AbsolutePriorityWeightStrategy,
            _DownstreamPriorityWeightStrategy,
            _UpstreamPriorityWeightStrategy,
        )

        if type(self.weight_rule) == _AbsolutePriorityWeightStrategy:
            return self.priority_weight
        elif type(self.weight_rule) == _DownstreamPriorityWeightStrategy:
            upstream = False
        elif type(self.weight_rule) == _UpstreamPriorityWeightStrategy:
            upstream = True
        else:
            upstream = False
        dag = self.get_dag()
        if dag is None:
            return self.priority_weight
        return self.priority_weight + sum(
            dag.task_dict[task_id].priority_weight
            for task_id in self.get_flat_relative_ids(upstream=upstream)
        )
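
    # Worked example (hypothetical weights): in a chain a >> b >> c with
    # priority_weight 1, 2, and 3, the default DOWNSTREAM rule gives
    # a.priority_weight_total == 1 + 2 + 3 == 6 and b's == 2 + 3 == 5, while
    # UPSTREAM sums the other direction (c's == 3 + 2 + 1 == 6) and ABSOLUTE
    # keeps each task's own weight.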

    @cached_property
    def operator_extra_link_dict(self) -> dict[str, Any]:
        """Returns dictionary of all extra links for the operator."""
        op_extra_links_from_plugin: dict[str, Any] = {}
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.operator_extra_links is None:
            raise AirflowException("Can't load operators")
        for ope in plugins_manager.operator_extra_links:
            if ope.operators and self.operator_class in ope.operators:
                op_extra_links_from_plugin.update({ope.name: ope})

        operator_extra_links_all = {link.name: link for link in self.operator_extra_links}
        # Extra links defined in plugins override links defined on the operator.
        operator_extra_links_all.update(op_extra_links_from_plugin)

        return operator_extra_links_all

    @cached_property
    def global_operator_extra_link_dict(self) -> dict[str, Any]:
        """Returns dictionary of all global extra links."""
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.global_operator_extra_links is None:
            raise AirflowException("Can't load operators")
        return {link.name: link for link in plugins_manager.global_operator_extra_links}

    @cached_property
    def extra_links(self) -> list[str]:
        return sorted(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))

    def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
        """For an operator, gets the URLs that the ``extra_links`` entry points to.

        :meta private:

        :raise ValueError: The error message of a ValueError will be passed on through to
            the frontend to show up as a tooltip on the disabled link.
        :param ti: The TaskInstance for the URL being searched for.
        :param link_name: The name of the link we're looking for the URL for. Should be
            one of the options specified in ``extra_links``.
        """
        link: BaseOperatorLink | None = self.operator_extra_link_dict.get(link_name)
        if not link:
            link = self.global_operator_extra_link_dict.get(link_name)
            if not link:
                return None

        parameters = inspect.signature(link.get_link).parameters
        old_signature = all(name != "ti_key" for name, p in parameters.items() if p.kind != p.VAR_KEYWORD)

        if old_signature:
            return link.get_link(self.unmap(None), ti.dag_run.logical_date)  # type: ignore[misc]
        return link.get_link(self.unmap(None), ti_key=ti.key)
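
    # A hedged sketch of where these links come from (class name and URL are
    # made up):
    #
    #     class MyDashboardLink(BaseOperatorLink):
    #         name = "My Dashboard"
    #
    #         def get_link(self, operator, *, ti_key):
    #             return f"https://example.com/{ti_key.dag_id}/{ti_key.task_id}"
    #
    # An operator listing this class in ``operator_extra_links`` exposes
    # "My Dashboard" via ``extra_links``, and ``get_extra_links`` resolves the
    # name to a URL, checking per-operator links before global ones
    # (plugin-provided links override same-named operator links, per the dict
    # update above).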

    @methodtools.lru_cache(maxsize=None)
    def get_parse_time_mapped_ti_count(self) -> int:
        """
        Return the number of mapped task instances that can be created on DAG run creation.

        This only considers literal mapped arguments, and raises
        ``NotFullyPopulated`` when any non-literal values are used for mapping.

        :raise NotFullyPopulated: If non-literal mapped arguments are encountered.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped
        return group.get_parse_time_mapped_ti_count()

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """
        Return the number of mapped TaskInstances that can be created at run time.

        This considers both literal and non-literal mapped arguments, and the
        result is therefore available when all depended-upon tasks have
        finished. The return value should be identical to
        ``parse_time_mapped_ti_count`` if all mapped arguments are literal.

        :raise NotFullyPopulated: If upstream tasks are not all complete yet.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped
        return group.get_mapped_ti_count(run_id, session=session)
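
    # Hedged example of the parse-time vs. run-time distinction (hypothetical
    # TaskFlow tasks):
    #
    #     process.expand(x=[1, 2, 3])    # literal: parse-time count is 3
    #     process.expand(x=make_list())  # XCom-backed: count only known at run
    #                                    # time, after make_list finishes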

    def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
        """Create the mapped task instances for a mapped task.

        :raise NotMapped: If this task does not need expansion.
        :return: The newly created mapped task instances (if any) in ascending
            order by map index, and the maximum map index value.
        """
        from sqlalchemy import func, or_

        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator
        from airflow.models.taskinstance import TaskInstance
        from airflow.settings import task_instance_mutation_hook

        if not isinstance(self, (BaseOperator, MappedOperator)):
            raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}")

        try:
            total_length: int | None = self.get_mapped_ti_count(run_id, session=session)
        except NotFullyPopulated as e:
            # It's possible that the upstream tasks are not yet done, but we
            # don't have upstream of upstreams in partial DAGs (possible in the
            # mini-scheduler), so we ignore this exception.
            if not self.dag or not self.dag.partial:
                self.log.error(
                    "Cannot expand %r for run %s; missing upstream values: %s",
                    self,
                    run_id,
                    sorted(e.missing),
                )
            total_length = None

        state: TaskInstanceState | None = None
        unmapped_ti: TaskInstance | None = session.scalars(
            select(TaskInstance).where(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
            )
        ).one_or_none()

        all_expanded_tis: list[TaskInstance] = []

        if unmapped_ti:
            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length is None:
                # If the DAG is partial, it's likely that the upstream tasks
                # are not done yet, so the task can't fail yet.
                if not self.dag or not self.dag.partial:
                    unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
            elif total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark
                # the unmapped task instance as SKIPPED (if needed).
                self.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
            else:
                zero_index_ti_exists = exists_query(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                    TaskInstance.map_index == 0,
                    session=session,
                )
                if not zero_index_ti_exists:
                    # Otherwise convert this into the first mapped index, and create
                    # TaskInstance for other indexes.
                    unmapped_ti.map_index = 0
                    self.log.debug("Updated in place to become %s", unmapped_ti)
                    all_expanded_tis.append(unmapped_ti)
                    session.flush()
                else:
                    self.log.debug("Deleting the original task instance: %s", unmapped_ti)
                    session.delete(unmapped_ti)
                state = unmapped_ti.state

        if total_length is None or total_length < 1:
            # Nothing to fix up.
            indexes_to_map: Iterable[int] = ()
        else:
            # Only create "missing" ones.
            current_max_mapping = session.scalar(
                select(func.max(TaskInstance.map_index)).where(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                )
            )
            indexes_to_map = range(current_max_mapping + 1, total_length)

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(self)  # session.merge() loses task information.
            all_expanded_tis.append(ti)

        # Coerce the None case to 0 -- these two are almost treated identically,
        # except the unmapped ti (if it exists) is marked to different states.
        total_expanded_ti_count = total_length or 0

        # Any (old) task instances with inapplicable indexes (>= the total
        # number we need) are set to "REMOVED".
        query = select(TaskInstance).where(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_expanded_ti_count,
        )
        query = with_row_locks(query, of=TaskInstance, session=session, skip_locked=True)
        to_update = session.scalars(query)
        for ti in to_update:
            ti.state = TaskInstanceState.REMOVED
        session.flush()
        return all_expanded_tis, total_expanded_ti_count - 1

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """Template all attributes listed in *self.template_fields*.

        If the operator is mapped, this should return the unmapped, fully
        rendered, and map-expanded operator. The mapped operator should not be
        modified. However, *context* may be modified in-place to reference the
        unmapped operator for template rendering.

        If the operator is not mapped, this should modify the operator in-place.
        """
        raise NotImplementedError()

    def _render(self, template, context, dag: DAG | None = None):
        if dag is None:
            dag = self.get_dag()
        return super()._render(template, context, dag=dag)

    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """Get the template environment for rendering templates."""
        if dag is None:
            dag = self.get_dag()
        return super().get_template_env(dag=dag)

    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """Override the base to use custom error logging."""
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            try:
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                self.log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__
                pass

            try:
                rendered_content = self.render_template(
                    value,
                    context,
                    jinja_env,
                    seen_oids,
                )
            except Exception:
                value_masked = redact(name=attr_name, value=value)
                self.log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    value_masked,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)
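
    # Hedged sketch of what feeds this machinery (operator name is made up):
    #
    #     class MyOperator(BaseOperator):
    #         template_fields = ("my_param",)
    #
    #         def __init__(self, my_param: str, **kwargs):
    #             super().__init__(**kwargs)
    #             self.my_param = my_param
    #
    #     MyOperator(task_id="t", my_param="run {{ ds }}")
    #
    # Each name in template_fields is read from the task, rendered through
    # render_template(), and written back; falsy values are skipped, and values
    # that cannot be truth-tested (e.g. DataFrames) are still rendered.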

    def __enter__(self):
        if not self.is_setup and not self.is_teardown:
            raise AirflowException("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(self)
        return SetupTeardownContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        SetupTeardownContext.set_work_task_roots_and_leaves()
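
# A minimal usage sketch of the context-manager protocol above, following the
# pattern in the Airflow setup/teardown docs (task names are hypothetical):
#
#     with setup_task >> teardown_task:
#         work_task
#
# ``>>`` returns the teardown, whose __enter__ pushes the pair onto
# SetupTeardownContext; on __exit__, tasks created inside the block are wired
# up as the work roots and leaves between the setup and the teardown.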