Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/taskgroup.py: 27%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

323 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""A collection of closely related tasks on the same Dag that should be grouped together visually.""" 

19 

20from __future__ import annotations 

21 

22import copy 

23import re 

24import weakref 

25from collections import deque 

26from collections.abc import Generator, Iterator, Sequence 

27from typing import TYPE_CHECKING, Any 

28 

29import attrs 

30 

31from airflow.sdk import TriggerRule 

32from airflow.sdk.definitions._internal.node import DAGNode, validate_group_key 

33from airflow.sdk.exceptions import ( 

34 AirflowDagCycleException, 

35 DuplicateTaskIdFound, 

36 TaskAlreadyInTaskGroup, 

37) 

38 

39if TYPE_CHECKING: 

40 from airflow.sdk.api.datamodels._generated import DagAttributeTypes 

41 from airflow.sdk.bases.operator import BaseOperator 

42 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator 

43 from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput 

44 from airflow.sdk.definitions._internal.mixins import DependencyMixin 

45 from airflow.sdk.definitions.dag import DAG 

46 from airflow.sdk.definitions.edges import EdgeModifier 

47 from airflow.sdk.types import Operator 

48 

49 

def _default_parent_group() -> TaskGroup | None:
    """Return the TaskGroup currently on the context-manager stack, if any."""
    # Local import avoids a circular dependency at module import time.
    from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

    current_group = TaskGroupContext.get_current()
    return current_group

54 

55 

def _parent_used_group_ids(tg: TaskGroup) -> set:
    """Return the parent's ``used_group_ids`` set, or a fresh one for a root group."""
    parent = tg.parent_group
    # The parent's set is returned by reference, so child groups record their
    # ids into the same shared collection as the parent.
    return parent.used_group_ids if parent else set()

60 

61 

# This could be achieved with `@dag.default` and make this a method, but for some unknown reason when we do
# that it makes Mypy (1.9.0 and 1.13.0 tested) seem to entirely lose track that this is an Attrs class. So
# we've gone with this and moved on with our lives, mypy is too much of a dark beast to battle over this.
def _default_dag(instance: TaskGroup):
    """Inherit the dag from the parent group, falling back to the active ``DagContext``."""
    from airflow.sdk.definitions._internal.contextmanager import DagContext

    parent = instance.parent_group
    if parent is None:
        return DagContext.get_current()
    return parent.dag

71 

72 

# Mypy does not like a lambda for some reason. An explicit annotated function makes it happy.
def _validate_group_id(instance, attribute, value: str) -> None:
    # attrs-style validator signature (instance, attribute, value); delegates
    # to the shared group-key validation helper, which raises on invalid keys.
    validate_group_key(value)

76 

77 

@attrs.define(repr=False)
class TaskGroup(DAGNode):
    """
    A collection of tasks.

    When set_downstream() or set_upstream() are called on the TaskGroup, it is applied across
    all tasks within the group if necessary.

    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
        with group_id of TaskGroup or task_id of tasks in the Dag. Root TaskGroup has group_id
        set to None.
    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
        Default is True.
    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
        for the root TaskGroup.
    :param dag: The Dag that this TaskGroup belongs to.
    :param default_args: A dictionary of default parameters to be used
        as constructor keyword parameters when initialising operators,
        will override default_args defined in the Dag level.
        Note that operators have the same hook, and precede those defined
        here, meaning that if your dict contains `'depends_on_past': True`
        here and `'depends_on_past': False` in the operator's call
        `default_args`, the actual value will be `False`.
    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
    :param ui_color: The fill color of the TaskGroup node when displayed in the UI
    :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
    :param add_suffix_on_collision: If this task group name already exists,
        automatically add `__1` etc suffixes
    :param group_display_name: If set, this will be the display name for the TaskGroup node in the UI.
    """

    _group_id: str | None = attrs.field(
        validator=attrs.validators.optional(_validate_group_id),
        # This is the default behaviour for attrs, but by specifying this it makes IDEs happier
        alias="group_id",
    )
    group_display_name: str = attrs.field(default="", validator=attrs.validators.instance_of(str))
    prefix_group_id: bool = attrs.field(default=True)
    parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group)
    dag: DAG = attrs.field(default=attrs.Factory(_default_dag, takes_self=True))
    default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy)
    tooltip: str = attrs.field(default="", validator=attrs.validators.instance_of(str))
    children: dict[str, DAGNode] = attrs.field(factory=dict, init=False)

    upstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
    downstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
    upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
    downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

    # Shared (by reference) with the parent group, so id collisions are
    # detected across the whole hierarchy. Frozen against reassignment.
    used_group_ids: set[str] = attrs.field(
        default=attrs.Factory(_parent_used_group_ids, takes_self=True),
        init=False,
        on_setattr=attrs.setters.frozen,
    )

    ui_color: str = attrs.field(default="CornflowerBlue", validator=attrs.validators.instance_of(str))
    ui_fgcolor: str = attrs.field(default="#000", validator=attrs.validators.instance_of(str))

    add_suffix_on_collision: bool = False

    @dag.validator
    def _validate_dag(self, _attr, dag):
        # A TaskGroup is meaningless outside a dag; fail fast at construction.
        if not dag:
            raise RuntimeError("TaskGroup can only be used inside a dag")

    def __attrs_post_init__(self):
        # TODO: If attrs supported init only args we could use that here
        # https://github.com/python-attrs/attrs/issues/342
        self._check_for_group_id_collisions(self.add_suffix_on_collision)

        if self._group_id and not self.parent_group and self.dag:
            # Support `tg = TaskGroup(x, dag=dag)`
            self.parent_group = self.dag.task_group

        if self.parent_group:
            self.parent_group.add(self)
            if self.parent_group.default_args:
                # Parent defaults apply first; this group's own defaults win.
                self.default_args = {
                    **self.parent_group.default_args,
                    **self.default_args,
                }

        if self._group_id:
            # Reserve this group's id and its proxy-join node ids.
            self.used_group_ids.add(self.group_id)
            self.used_group_ids.add(self.downstream_join_id)
            self.used_group_ids.add(self.upstream_join_id)

    def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
        """Raise on a duplicate group_id, or rename with a numeric suffix when allowed."""
        if self._group_id is None:
            return
        # if given group_id already used assign suffix by incrementing largest used suffix integer
        # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
        if self.group_id in self.used_group_ids:
            if not add_suffix_on_collision:
                raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")
            base = re.split(r"__\d+$", self._group_id)[0]
            suffixes = sorted(
                int(re.split(r"^.+__", used_group_id)[1])
                for used_group_id in self.used_group_ids
                if used_group_id is not None and re.match(rf"^{base}__\d+$", used_group_id)
            )
            if not suffixes:
                self._group_id += "__1"
            else:
                self._group_id = f"{base}__{suffixes[-1] + 1}"

    @classmethod
    def create_root(cls, dag: DAG) -> TaskGroup:
        """Create a root TaskGroup with no group_id or parent."""
        return cls(group_id=None, dag=dag, parent_group=None)

    @property
    def node_id(self):
        # The (possibly prefixed) group_id doubles as this node's id.
        return self.group_id

    @property
    def is_root(self) -> bool:
        """Returns True if this TaskGroup is the root TaskGroup. Otherwise False."""
        return not self._group_id

    @property
    def task_group(self) -> TaskGroup | None:
        """Alias of ``parent_group``, matching the attribute tasks expose."""
        return self.parent_group

    @task_group.setter
    def task_group(self, value: TaskGroup | None):
        self.parent_group = value

    def __iter__(self):
        # Yields tasks (not sub-groups): nested groups are flattened recursively.
        for child in self.children.values():
            yield from self._iter_child(child)

    @staticmethod
    def _iter_child(child):
        """Iterate over the children of this TaskGroup."""
        if isinstance(child, TaskGroup):
            yield from child
        else:
            yield child

    def add(self, task: DAGNode) -> DAGNode:
        """
        Add a task or TaskGroup to this TaskGroup.

        :meta private:
        """
        from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        if TaskGroupContext.active:
            # Inside an active `with TaskGroup(...)` block, re-parenting is
            # allowed: detach the task from its previous group first.
            if task.task_group and task.task_group != self:
                task.task_group.children.pop(task.node_id, None)
                task.task_group = self
        existing_tg = task.task_group
        if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
            raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)

        # Set the TG first, as setting it might change the return value of node_id!
        task.task_group = weakref.proxy(self)
        key = task.node_id

        if key in self.children:
            node_type = "Task" if hasattr(task, "task_id") else "Task Group"
            raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the Dag")

        if isinstance(task, TaskGroup):
            if self.dag:
                if task.dag is not None and self.dag is not task.dag:
                    # NOTE: fixed — previously the message and the two dag objects were
                    # passed as separate ValueError args and never interpolated.
                    raise ValueError(
                        f"Cannot mix TaskGroups from different Dags: {self.dag} and {task.dag}"
                    )
                task.dag = self.dag
            if task.children:
                raise ValueError("Cannot add a non-empty TaskGroup")

        self.children[key] = task
        return task

    def _remove(self, task: DAGNode) -> None:
        """Remove a direct child from this group, releasing its reserved id."""
        key = task.node_id

        if key not in self.children:
            raise KeyError(f"Node id {key!r} not part of this task group")

        self.used_group_ids.remove(key)
        del self.children[key]

    @property
    def group_id(self) -> str | None:
        """group_id of this TaskGroup."""
        if (
            self._group_id
            and self.parent_group
            and self.parent_group.prefix_group_id
            and self.parent_group._group_id
        ):
            # defer to parent whether it adds a prefix
            return self.parent_group.child_id(self._group_id)
        return self._group_id

    @property
    def label(self) -> str | None:
        """group_id excluding parent's group_id used as the node label in UI."""
        return self.group_display_name or self._group_id

    def update_relative(
        self,
        other: DependencyMixin,
        upstream: bool = True,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Override TaskMixin.update_relative.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
        accordingly so that we can reduce the number of edges when displaying Graph view.
        """
        if isinstance(other, TaskGroup):
            # Handles setting relationship between a TaskGroup and another TaskGroup
            if upstream:
                parent, child = (self, other)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
            else:
                parent, child = (other, self)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)

            parent.upstream_group_ids.add(child.group_id)
            child.downstream_group_ids.add(parent.group_id)
        else:
            # Handles setting relationship between a TaskGroup and a task
            for task in other.roots:
                if not isinstance(task, DAGNode):
                    raise RuntimeError(
                        "Relationships can only be set between TaskGroup "
                        f"or operators; received {task.__class__.__name__}"
                    )

                # Do not set a relationship between a TaskGroup and a Label's roots
                if self == task:
                    continue

                if upstream:
                    self.upstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
                else:
                    self.downstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
        """
        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        # Helper function to find leaves from a task list or task group
        def find_leaves(group_or_task) -> list[Any]:
            # Walk up through empty groups until we find concrete leaf tasks,
            # or fall back to already-recorded upstream task ids.
            while group_or_task:
                group_or_task_leaves = list(group_or_task.get_leaves())
                if group_or_task_leaves:
                    return group_or_task_leaves
                if group_or_task.upstream_task_ids:
                    upstream_task_ids_list = list(group_or_task.upstream_task_ids)
                    return [self.dag.get_task(task_id) for task_id in upstream_task_ids_list]
                group_or_task = group_or_task.parent_group
            return []

        # Check if the current TaskGroup is empty
        leaves = find_leaves(self)

        for task_like in task_or_task_list:
            self.update_relative(task_like, upstream, edge_modifier=edge_modifier)

        if upstream:
            for task in self.get_roots():
                task.set_upstream(task_or_task_list)
        else:
            for task in leaves:  # Use the fetched leaves
                task.set_downstream(task_or_task_list)

    def __enter__(self) -> TaskGroup:
        """Push this group onto the context stack so tasks created inside join it."""
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        TaskGroupContext.push(self)
        return self

    def __exit__(self, _type, _value, _tb):
        """Pop this group off the context stack."""
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        TaskGroupContext.pop()

    def has_task(self, task: BaseOperator) -> bool:
        """Return True if this TaskGroup or its children TaskGroups contains the given task."""
        if task.task_id in self.children:
            return True

        return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

    @property
    def roots(self) -> list[BaseOperator]:
        """Required by DependencyMixin."""
        return list(self.get_roots())

    @property
    def leaves(self) -> list[BaseOperator]:
        """Required by DependencyMixin."""
        return list(self.get_leaves())

    def get_roots(self) -> Generator[BaseOperator, None, None]:
        """Return a generator of tasks with no upstream dependencies within the TaskGroup."""
        tasks = list(self)
        ids = {x.task_id for x in tasks}
        for task in tasks:
            if task.upstream_task_ids.isdisjoint(ids):
                yield task

    def get_leaves(self) -> Generator[BaseOperator, None, None]:
        """Return a generator of tasks with no downstream dependencies within the TaskGroup."""
        tasks = list(self)
        ids = {x.task_id for x in tasks}

        def has_non_teardown_downstream(task, exclude: str):
            # True if `task` has any in-group downstream (other than `exclude`)
            # that is not a teardown task.
            for down_task in task.downstream_list:
                if down_task.task_id == exclude:
                    continue
                if down_task.task_id not in ids:
                    continue
                if not down_task.is_teardown:
                    return True
            return False

        def recurse_for_first_non_teardown(task):
            # A teardown task is not a "real" leaf; walk upstream to find the
            # nearest non-teardown tasks to treat as leaves instead.
            for upstream_task in task.upstream_list:
                if upstream_task.task_id not in ids:
                    # upstream task is not in task group
                    continue
                elif upstream_task.is_teardown:
                    yield from recurse_for_first_non_teardown(upstream_task)
                elif task.is_teardown and upstream_task.is_setup:
                    # don't go through the teardown-to-setup path
                    continue
                # return unless upstream task already has non-teardown downstream in group
                elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id):
                    yield upstream_task

        for task in tasks:
            if task.downstream_task_ids.isdisjoint(ids):
                if not task.is_teardown:
                    yield task
                else:
                    yield from recurse_for_first_non_teardown(task)

    def child_id(self, label):
        """Prefix label with group_id if prefix_group_id is True. Otherwise return the label as-is."""
        if self.prefix_group_id:
            group_id = self.group_id
            if group_id:
                return f"{group_id}.{label}"

        return label

    @property
    def upstream_join_id(self) -> str:
        """
        Creates a unique ID for upstream dependencies of this TaskGroup.

        If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
        upstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.upstream_join_id"

    @property
    def downstream_join_id(self) -> str:
        """
        Creates a unique ID for downstream dependencies of this TaskGroup.

        If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
        downstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.downstream_join_id"

    def get_task_group_dict(self) -> dict[str, TaskGroup]:
        """Return a flat dictionary of group_id: TaskGroup."""
        task_group_map = {}

        def build_map(task_group):
            if not isinstance(task_group, TaskGroup):
                return

            task_group_map[task_group.group_id] = task_group

            for child in task_group.children.values():
                build_map(child)

        build_map(self)
        return task_group_map

    def get_child_by_label(self, label: str) -> DAGNode:
        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
        return self.children[self.child_id(label)]

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Serialize task group; required by DagNode."""
        from airflow.sdk.api.datamodels._generated import DagAttributeTypes
        from airflow.serialization.serialized_objects import TaskGroupSerialization

        return (
            DagAttributeTypes.TASK_GROUP,
            TaskGroupSerialization.serialize_task_group(self),
        )

    def hierarchical_alphabetical_sort(self):
        """
        Sort children in hierarchical alphabetical order.

        - groups in alphabetical order first
        - tasks in alphabetical order after them.

        :return: list of tasks in hierarchical alphabetical order
        """
        return sorted(
            self.children.values(),
            key=lambda node: (not isinstance(node, TaskGroup), node.node_id),
        )

    def topological_sort(self):
        """
        Sorts children in topographical order, such that a task comes after any of its upstream dependencies.

        :return: list of tasks in topological order
        """
        # This uses a modified version of Kahn's Topological Sort algorithm to
        # not have to pre-compute the "in-degree" of the nodes.
        graph_unsorted = copy.copy(self.children)

        graph_sorted: list[DAGNode] = []

        # special case
        if not self.children:
            return graph_sorted

        # Run until the unsorted graph is empty.
        while graph_unsorted:
            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
            # pair from the unsorted graph, and append it to the sorted graph. Note here that by using
            # the values() method for iterating, a copy of the unsorted graph is used, allowing us to modify
            # the unsorted graph as we move through it.
            #
            # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
            # during each pass through the graph. If not, we need to exit as the graph therefore can't be
            # sorted.
            acyclic = False
            for node in list(graph_unsorted.values()):
                for edge in node.upstream_list:
                    if edge.node_id in graph_unsorted:
                        break
                    # Check for task's group is a child (or grand child) of this TG,
                    tg = edge.task_group
                    while tg:
                        if tg.node_id in graph_unsorted:
                            break
                        tg = tg.parent_group

                    if tg:
                        # We are already going to visit that TG
                        break
                else:
                    acyclic = True
                    del graph_unsorted[node.node_id]
                    graph_sorted.append(node)

            if not acyclic:
                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")

        return graph_sorted

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """
        Return mapped task groups in the hierarchy.

        Groups are returned from the closest to the outmost. If *self* is a
        mapped task group, it is returned first.

        :meta private:
        """
        group: TaskGroup | None = self
        while group is not None:
            if isinstance(group, MappedTaskGroup):
                yield group
            group = group.parent_group

    def iter_tasks(self) -> Iterator[AbstractOperator]:
        """Return an iterator of the child tasks."""
        from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator

        # Breadth-first traversal over nested groups.
        groups_to_visit = deque([self])

        while groups_to_visit:
            visiting = groups_to_visit.popleft()

            for child in visiting.children.values():
                if isinstance(child, AbstractOperator):
                    yield child
                elif isinstance(child, TaskGroup):
                    groups_to_visit.append(child)
                else:
                    raise ValueError(
                        f"Encountered a DAGNode that is not a TaskGroup or an "
                        f"AbstractOperator: {type(child).__module__}.{type(child)}"
                    )

605 

606 

@attrs.define(kw_only=True, repr=False)
class MappedTaskGroup(TaskGroup):
    """
    A mapped task group.

    This doesn't really do anything special, just holds some additional metadata
    for expansion later.

    Don't instantiate this class directly; call *expand* or *expand_kwargs* on
    a ``@task_group`` function instead.
    """

    # Expansion arguments captured from expand()/expand_kwargs(); iterated
    # below to discover the upstream references feeding this group.
    _expand_input: DictOfListsExpandInput | ListOfDictsExpandInput = attrs.field(alias="expand_input")

    def __iter__(self):
        # Same flattening as TaskGroup.__iter__, but children using trigger
        # rule ALWAYS are rejected (see the error message below).
        for child in self.children.values():
            if getattr(child, "trigger_rule", None) == TriggerRule.ALWAYS:
                raise ValueError(
                    "Task-generated mapping within a mapped task group is not "
                    "allowed with trigger rule 'always'"
                )
            yield from self._iter_child(child)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # On leaving the `with` block, wire every operator referenced by the
        # expand input as an upstream dependency, then pop the context stack.
        for op, _ in self._expand_input.iter_references():
            self.set_upstream(op)
        super().__exit__(exc_type, exc_val, exc_tb)

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this mapped task group."""
        # Imported lazily; only operators referenced via XComArg are yielded.
        from airflow.sdk.definitions.xcom_arg import XComArg

        for op, _ in XComArg.iter_xcom_references(self._expand_input):
            yield op