Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/taskgroup.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

328 statements  

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

18"""A collection of closely related tasks on the same Dag that should be grouped together visually.""" 

19 

20from __future__ import annotations 

21 

22import copy 

23import functools 

24import operator 

25import re 

26import weakref 

27from collections.abc import Generator, Iterator, Sequence 

28from typing import TYPE_CHECKING, Any 

29 

30import attrs 

31import methodtools 

32 

33from airflow.sdk import TriggerRule 

34from airflow.sdk.definitions._internal.node import DAGNode, validate_group_key 

35from airflow.sdk.exceptions import ( 

36 AirflowDagCycleException, 

37 DuplicateTaskIdFound, 

38 TaskAlreadyInTaskGroup, 

39) 

40 

41if TYPE_CHECKING: 

42 from airflow.sdk.bases.operator import BaseOperator 

43 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator 

44 from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput 

45 from airflow.sdk.definitions._internal.mixins import DependencyMixin 

46 from airflow.sdk.definitions.dag import DAG 

47 from airflow.sdk.definitions.edges import EdgeModifier 

48 from airflow.sdk.types import Operator 

49 from airflow.serialization.enums import DagAttributeTypes 

50 

51 

def _default_parent_group() -> TaskGroup | None:
    """
    Return the TaskGroup that is currently active in ``TaskGroupContext``, if any.

    Used as the attrs factory for ``TaskGroup.parent_group``. The import is done lazily at call time.
    """
    from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext as _GroupContext

    active_group = _GroupContext.get_current()
    return active_group

56 

57 

58def _parent_used_group_ids(tg: TaskGroup) -> set: 

59 if tg.parent_group: 

60 return tg.parent_group.used_group_ids 

61 return set() 

62 

63 

# This could be achieved with `@dag.default` and make this a method, but for some unknown reason when we do
# that it makes Mypy (1.9.0 and 1.13.0 tested) seem to entirely lose track that this is an Attrs class. So
# we've gone with this and moved on with our lives, mypy is too much of a dark beast to battle over this.
def _default_dag(instance: TaskGroup):
    """
    Compute the default ``dag`` for a TaskGroup under construction.

    The parent group's Dag wins if a parent group is set; otherwise the Dag currently held by
    ``DagContext`` is returned (which may be None).
    """
    from airflow.sdk.definitions._internal.contextmanager import DagContext

    parent = instance.parent_group
    if parent is None:
        return DagContext.get_current()
    return parent.dag

73 

74 

# Mypy does not like a lambda for some reason. An explicit annotated function makes it happy.
def _validate_group_id(instance, attribute, value: str) -> None:
    """attrs validator for ``TaskGroup.group_id``; the actual check is delegated to ``validate_group_key``."""
    # ``instance`` and ``attribute`` are mandated by the attrs validator signature but unused here.
    del instance, attribute
    validate_group_key(value)

78 

79 

@attrs.define(repr=False)
class TaskGroup(DAGNode):
    """
    A collection of tasks.

    When set_downstream() or set_upstream() are called on the TaskGroup, it is applied across
    all tasks within the group if necessary.

    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
        with group_id of TaskGroup or task_id of tasks in the Dag. Root TaskGroup has group_id
        set to None.
    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
        Default is True.
    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
        for the root TaskGroup.
    :param dag: The Dag that this TaskGroup belongs to.
    :param default_args: A dictionary of default parameters to be used
        as constructor keyword parameters when initialising operators,
        will override default_args defined in the Dag level.
        Note that operators have the same hook, and precede those defined
        here, meaning that if your dict contains `'depends_on_past': True`
        here and `'depends_on_past': False` in the operator's call
        `default_args`, the actual value will be `False`.
    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
    :param ui_color: The fill color of the TaskGroup node when displayed in the UI
    :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
    :param add_suffix_on_collision: If this task group name already exists,
        automatically add `__1` etc suffixes
    :param group_display_name: If set, this will be the display name for the TaskGroup node in the UI.
    """

    _group_id: str | None = attrs.field(
        validator=attrs.validators.optional(_validate_group_id),
        # This is the default behaviour for attrs, but by specifying this it makes IDEs happier
        alias="group_id",
    )
    group_display_name: str = attrs.field(default="", validator=attrs.validators.instance_of(str))
    prefix_group_id: bool = attrs.field(default=True)
    parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group)
    dag: DAG = attrs.field(default=attrs.Factory(_default_dag, takes_self=True))
    default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy)
    tooltip: str = attrs.field(default="", validator=attrs.validators.instance_of(str))
    children: dict[str, DAGNode] = attrs.field(factory=dict, init=False)

    upstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
    downstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
    upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
    downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

    # Defaults to the parent group's set object (see ``_parent_used_group_ids``), so groups in one
    # hierarchy share a single registry of used ids. Frozen so it cannot be rebound by plain assignment.
    used_group_ids: set[str] = attrs.field(
        default=attrs.Factory(_parent_used_group_ids, takes_self=True),
        init=False,
        on_setattr=attrs.setters.frozen,
    )

    ui_color: str = attrs.field(default="CornflowerBlue", validator=attrs.validators.instance_of(str))
    ui_fgcolor: str = attrs.field(default="#000", validator=attrs.validators.instance_of(str))

    add_suffix_on_collision: bool = False

    @dag.validator
    def _validate_dag(self, _attr, dag):
        """Reject construction (or reassignment) without a Dag."""
        if not dag:
            raise RuntimeError("TaskGroup can only be used inside a dag")

    def __attrs_post_init__(self):
        """Attach this group to its parent, merge default_args and register its ids."""
        # TODO: If attrs supported init only args we could use that here
        # https://github.com/python-attrs/attrs/issues/342
        if self._group_id and not self.parent_group and self.dag:
            # Support `tg = TaskGroup(x, dag=dag)` outside of any TaskGroup/Dag context.
            self.parent_group = self.dag.task_group
            if self.parent_group:
                # ``used_group_ids`` was defaulted while ``parent_group`` was still None, so it is a
                # private empty set. Point it at the parent's registry (bypassing the ``frozen``
                # on_setattr hook) so collisions are checked against, and recorded in, the shared set.
                object.__setattr__(self, "used_group_ids", self.parent_group.used_group_ids)

        self._check_for_group_id_collisions(self.add_suffix_on_collision)

        if self.parent_group:
            self.parent_group.add(self)
            if self.parent_group.default_args:
                # This group's own default_args take precedence over the inherited ones.
                self.default_args = {
                    **self.parent_group.default_args,
                    **self.default_args,
                }

        if self._group_id:
            self.used_group_ids.add(self.group_id)
            self.used_group_ids.add(self.downstream_join_id)
            self.used_group_ids.add(self.upstream_join_id)

    def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
        """
        Detect (and optionally resolve) a clash of this group's id with an already-used id.

        If the fully-qualified group_id is already in ``used_group_ids``, either raise or, when
        *add_suffix_on_collision* is set, rename the group by appending ``__<n>``, where ``n`` is one
        more than the largest suffix already used for the same base id.
        Example: task_group ==> task_group__1 -> task_group__2 -> task_group__3

        :raises DuplicateTaskIdFound: if the id is taken and suffixing is not allowed.
        """
        if self._group_id is None:
            return
        full_group_id = self.group_id
        if full_group_id not in self.used_group_ids:
            return
        if not add_suffix_on_collision:
            raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")

        base = re.split(r"__\d+$", self._group_id)[0]
        # ``used_group_ids`` stores fully-qualified ids (prefixed with the parent's group_id when the
        # parent prefixes), so the suffix search must use the same prefix; matching the bare base would
        # never see e.g. ``parent.child__1`` and would hand out an id that is already taken.
        prefix = full_group_id[: len(full_group_id) - len(self._group_id)]
        suffix_pattern = re.compile(rf"{re.escape(prefix + base)}__(\d+)")
        suffixes = [
            int(found.group(1))
            for used_group_id in self.used_group_ids
            if used_group_id is not None and (found := suffix_pattern.fullmatch(used_group_id))
        ]
        if not suffixes:
            self._group_id += "__1"
        else:
            self._group_id = f"{base}__{max(suffixes) + 1}"

    @classmethod
    def create_root(cls, dag: DAG) -> TaskGroup:
        """Create a root TaskGroup with no group_id or parent."""
        return cls(group_id=None, dag=dag, parent_group=None)

    @property
    def node_id(self):
        """Node id of this group, i.e. its (possibly prefixed) group_id."""
        return self.group_id

    @property
    def is_root(self) -> bool:
        """Returns True if this TaskGroup is the root TaskGroup. Otherwise False."""
        return not self._group_id

    @property
    def task_group(self) -> TaskGroup | None:
        """Alias of ``parent_group`` so a TaskGroup can be treated like any other DAGNode."""
        return self.parent_group

    @task_group.setter
    def task_group(self, value: TaskGroup | None):
        self.parent_group = value

    def __iter__(self):
        """Iterate over all tasks in this group, descending into nested groups."""
        for child in self.children.values():
            yield from self._iter_child(child)

    @staticmethod
    def _iter_child(child):
        """Iterate over the children of this TaskGroup."""
        if isinstance(child, TaskGroup):
            yield from child
        else:
            yield child

    def add(self, task: DAGNode) -> DAGNode:
        """
        Add a task or TaskGroup to this TaskGroup.

        :raises TaskAlreadyInTaskGroup: if an operator already belongs to a different group.
        :raises DuplicateTaskIdFound: if a child with the same node id is already present.
        :raises ValueError: if a TaskGroup from another Dag, or a non-empty TaskGroup, is added.

        :meta private:
        """
        from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        if TaskGroupContext.active:
            # Inside a context manager, move the node from its previous group into this one.
            if task.task_group and task.task_group != self:
                task.task_group.children.pop(task.node_id, None)
                task.task_group = self
        existing_tg = task.task_group
        if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
            raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)

        # Set the TG first, as setting it might change the return value of node_id!
        # A weak proxy is stored, presumably to avoid a strong child -> parent reference cycle.
        task.task_group = weakref.proxy(self)
        key = task.node_id

        if key in self.children:
            node_type = "Task" if hasattr(task, "task_id") else "Task Group"
            raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the Dag")

        if isinstance(task, TaskGroup):
            if self.dag:
                if task.dag is not None and self.dag is not task.dag:
                    # Exception args are not %-formatted, so build the message explicitly.
                    raise ValueError(f"Cannot mix TaskGroups from different Dags: {self.dag} and {task.dag}")
                task.dag = self.dag
            if task.children:
                raise ValueError("Cannot add a non-empty TaskGroup")

        self.children[key] = task
        return task

    def _remove(self, task: DAGNode) -> None:
        """
        Remove a direct child node and release its id from ``used_group_ids``.

        :raises KeyError: if the node is not a direct child of this group.
        """
        key = task.node_id

        if key not in self.children:
            raise KeyError(f"Node id {key!r} not part of this task group")

        self.used_group_ids.remove(key)
        del self.children[key]

    @property
    def group_id(self) -> str | None:
        """group_id of this TaskGroup."""
        if (
            self._group_id
            and self.parent_group
            and self.parent_group.prefix_group_id
            and self.parent_group._group_id
        ):
            # defer to parent whether it adds a prefix
            return self.parent_group.child_id(self._group_id)
        return self._group_id

    @property
    def label(self) -> str | None:
        """group_id excluding parent's group_id used as the node label in UI."""
        return self.group_display_name or self._group_id

    def update_relative(
        self,
        other: DependencyMixin,
        upstream: bool = True,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Override TaskMixin.update_relative.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
        accordingly so that we can reduce the number of edges when displaying Graph view.
        """
        if isinstance(other, TaskGroup):
            # Handles setting relationship between a TaskGroup and another TaskGroup
            if upstream:
                parent, child = (self, other)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
            else:
                parent, child = (other, self)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)

            parent.upstream_group_ids.add(child.group_id)
            child.downstream_group_ids.add(parent.group_id)
        else:
            # Handles setting relationship between a TaskGroup and a task
            for task in other.roots:
                if not isinstance(task, DAGNode):
                    raise RuntimeError(
                        "Relationships can only be set between TaskGroup "
                        f"or operators; received {task.__class__.__name__}"
                    )

                # Do not set a relationship between a TaskGroup and a Label's roots
                if self == task:
                    continue

                if upstream:
                    self.upstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
                else:
                    self.downstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
        """
        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        # Helper function to find leaves from a task list or task group. An empty group falls back
        # to its recorded upstream tasks, then walks up the parent chain.
        def find_leaves(group_or_task) -> list[Any]:
            while group_or_task:
                group_or_task_leaves = list(group_or_task.get_leaves())
                if group_or_task_leaves:
                    return group_or_task_leaves
                if group_or_task.upstream_task_ids:
                    upstream_task_ids_list = list(group_or_task.upstream_task_ids)
                    return [self.dag.get_task(task_id) for task_id in upstream_task_ids_list]
                group_or_task = group_or_task.parent_group
            return []

        # Check if the current TaskGroup is empty
        leaves = find_leaves(self)

        for task_like in task_or_task_list:
            self.update_relative(task_like, upstream, edge_modifier=edge_modifier)

        if upstream:
            for task in self.get_roots():
                task.set_upstream(task_or_task_list)
        else:
            for task in leaves:  # Use the fetched leaves
                task.set_downstream(task_or_task_list)

    def __enter__(self) -> TaskGroup:
        """Push this group onto ``TaskGroupContext`` and return it."""
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        TaskGroupContext.push(self)
        return self

    def __exit__(self, _type, _value, _tb):
        """Pop the innermost group from ``TaskGroupContext``."""
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        TaskGroupContext.pop()

    def has_task(self, task: BaseOperator) -> bool:
        """Return True if this TaskGroup or its children TaskGroups contains the given task."""
        if task.task_id in self.children:
            return True

        return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

    @property
    def roots(self) -> list[BaseOperator]:
        """Required by DependencyMixin."""
        return list(self.get_roots())

    @property
    def leaves(self) -> list[BaseOperator]:
        """Required by DependencyMixin."""
        return list(self.get_leaves())

    def get_roots(self) -> Generator[BaseOperator, None, None]:
        """Return a generator of tasks with no upstream dependencies within the TaskGroup."""
        tasks = list(self)
        ids = {x.task_id for x in tasks}
        for task in tasks:
            if task.upstream_task_ids.isdisjoint(ids):
                yield task

    def get_leaves(self) -> Generator[BaseOperator, None, None]:
        """
        Return a generator of tasks with no downstream dependencies within the TaskGroup.

        A teardown task at the end of the group is not itself a leaf; instead the search walks
        upstream from it to the first non-teardown tasks that have no other in-group,
        non-teardown downstream.
        """
        tasks = list(self)
        ids = {x.task_id for x in tasks}

        def has_non_teardown_downstream(task, exclude: str):
            for down_task in task.downstream_list:
                if down_task.task_id == exclude:
                    continue
                if down_task.task_id not in ids:
                    continue
                if not down_task.is_teardown:
                    return True
            return False

        def recurse_for_first_non_teardown(task):
            for upstream_task in task.upstream_list:
                if upstream_task.task_id not in ids:
                    # upstream task is not in task group
                    continue
                elif upstream_task.is_teardown:
                    yield from recurse_for_first_non_teardown(upstream_task)
                elif task.is_teardown and upstream_task.is_setup:
                    # don't go through the teardown-to-setup path
                    continue
                # return unless upstream task already has non-teardown downstream in group
                elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id):
                    yield upstream_task

        for task in tasks:
            if task.downstream_task_ids.isdisjoint(ids):
                if not task.is_teardown:
                    yield task
                else:
                    yield from recurse_for_first_non_teardown(task)

    def child_id(self, label):
        """Prefix label with group_id if prefix_group_id is True. Otherwise return the label as-is."""
        if self.prefix_group_id:
            group_id = self.group_id
            if group_id:
                return f"{group_id}.{label}"

        return label

    @property
    def upstream_join_id(self) -> str:
        """
        Creates a unique ID for upstream dependencies of this TaskGroup.

        If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
        upstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.upstream_join_id"

    @property
    def downstream_join_id(self) -> str:
        """
        Creates a unique ID for downstream dependencies of this TaskGroup.

        If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
        downstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.downstream_join_id"

    def get_task_group_dict(self) -> dict[str, TaskGroup]:
        """Return a flat dictionary of group_id: TaskGroup."""
        task_group_map = {}

        def build_map(task_group):
            if not isinstance(task_group, TaskGroup):
                return

            task_group_map[task_group.group_id] = task_group

            for child in task_group.children.values():
                build_map(child)

        build_map(self)
        return task_group_map

    def get_child_by_label(self, label: str) -> DAGNode:
        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
        return self.children[self.child_id(label)]

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Serialize task group; required by DagNode."""
        from airflow.serialization.enums import DagAttributeTypes
        from airflow.serialization.serialized_objects import TaskGroupSerialization

        return (
            DagAttributeTypes.TASK_GROUP,
            TaskGroupSerialization.serialize_task_group(self),
        )

    def hierarchical_alphabetical_sort(self):
        """
        Sort children in hierarchical alphabetical order.

        - groups in alphabetical order first
        - tasks in alphabetical order after them.

        :return: list of direct child nodes (groups and tasks) in hierarchical alphabetical order
        """
        return sorted(
            self.children.values(),
            key=lambda node: (not isinstance(node, TaskGroup), node.node_id),
        )

    def topological_sort(self):
        """
        Sorts children in topological order, such that a task comes after any of its upstream dependencies.

        :return: list of tasks in topological order
        :raises AirflowDagCycleException: if a full pass resolves no node (cyclic dependency).
        """
        # This uses a modified version of Kahn's Topological Sort algorithm to
        # not have to pre-compute the "in-degree" of the nodes.
        graph_unsorted = copy.copy(self.children)

        graph_sorted: list[DAGNode] = []

        # special case
        if not self.children:
            return graph_sorted

        # Run until the unsorted graph is empty.
        while graph_unsorted:
            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
            # pair from the unsorted graph, and append it to the sorted graph. Note here that by iterating
            # over a list copy of the values, we can modify the unsorted graph as we move through it.
            #
            # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
            # during each pass through the graph. If not, we need to exit as the graph therefore can't be
            # sorted.
            acyclic = False
            for node in list(graph_unsorted.values()):
                for edge in node.upstream_list:
                    if edge.node_id in graph_unsorted:
                        break
                    # Check for task's group is a child (or grand child) of this TG,
                    tg = edge.task_group
                    while tg:
                        if tg.node_id in graph_unsorted:
                            break
                        tg = tg.parent_group

                    if tg:
                        # We are already going to visit that TG
                        break
                else:
                    # No unresolved upstream: the node can be emitted now.
                    acyclic = True
                    del graph_unsorted[node.node_id]
                    graph_sorted.append(node)

            if not acyclic:
                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")

        return graph_sorted

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """
        Return mapped task groups in the hierarchy.

        Groups are returned from the closest to the outmost. If *self* is a
        mapped task group, it is returned first.

        :meta private:
        """
        group: TaskGroup | None = self
        while group is not None:
            if isinstance(group, MappedTaskGroup):
                yield group
            group = group.parent_group

    def iter_tasks(self) -> Iterator[AbstractOperator]:
        """
        Return an iterator of the child tasks, visiting nested TaskGroups breadth-first.

        :raises ValueError: if a child is neither a TaskGroup nor an AbstractOperator.
        """
        from collections import deque

        from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator

        # deque gives O(1) pops from the left, unlike list.pop(0).
        groups_to_visit = deque([self])

        while groups_to_visit:
            visiting = groups_to_visit.popleft()

            for child in visiting.children.values():
                if isinstance(child, AbstractOperator):
                    yield child
                elif isinstance(child, TaskGroup):
                    groups_to_visit.append(child)
                else:
                    raise ValueError(
                        f"Encountered a DAGNode that is not a TaskGroup or an "
                        f"AbstractOperator: {type(child).__module__}.{type(child).__qualname__}"
                    )

607 

608 

@attrs.define(kw_only=True, repr=False)
class MappedTaskGroup(TaskGroup):
    """
    A mapped task group.

    This doesn't really do anything special, just holds some additional metadata
    for expansion later.

    Don't instantiate this class directly; call *expand* or *expand_kwargs* on
    a ``@task_group`` function instead.
    """

    _expand_input: DictOfListsExpandInput | ListOfDictsExpandInput = attrs.field(alias="expand_input")

    def __iter__(self):
        """
        Iterate over the tasks in this group, descending into nested groups.

        :raises ValueError: when a direct child has trigger rule ``always``.
        """
        forbidden_rule = TriggerRule.ALWAYS
        for node in self.children.values():
            rule = getattr(node, "trigger_rule", None)
            if rule == forbidden_rule:
                raise ValueError(
                    "Task-generated mapping within a mapped task group is not "
                    "allowed with trigger rule 'always'"
                )
            yield from self._iter_child(node)

    @methodtools.lru_cache(maxsize=None)
    def get_parse_time_mapped_ti_count(self) -> int:
        """
        Return the number of instances a task in this group should be mapped to when a Dag run is created.

        Only literal mapped arguments are considered. The counts of this group and of every enclosing
        mapped task group are multiplied together.

        :meta private:

        :raise NotFullyPopulated: If any non-literal mapped arguments are encountered (raised by the
            expand input's own parse-time count).
        :return: The total number of mapped instances each task should have.
        """
        per_group_counts = (
            mapped_group._expand_input.get_parse_time_mapped_ti_count()
            for mapped_group in self.iter_mapped_task_groups()
        )
        return functools.reduce(operator.mul, per_group_counts)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Make this group depend on the operators referenced by its expand input, then leave the context."""
        references = self._expand_input.iter_references()
        for upstream_op, _unused in references:
            self.set_upstream(upstream_op)
        super().__exit__(exc_type, exc_val, exc_tb)

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this mapped task group."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        yield from (producer for producer, _unused in XComArg.iter_xcom_references(self._expand_input))