Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/utils/task_group.py: 35%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

352 statements  

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""A collection of closely related tasks on the same DAG that should be grouped together visually."""

19 

from __future__ import annotations

import copy
import functools
import operator
import weakref
from collections import deque
from typing import TYPE_CHECKING, Any, Generator, Iterator, Sequence

import methodtools
import re2

from airflow.exceptions import (
    AirflowDagCycleException,
    AirflowException,
    DuplicateTaskIdFound,
    TaskAlreadyInTaskGroup,
)
from airflow.models.taskmixin import DAGNode
from airflow.serialization.enums import DagAttributeTypes
from airflow.utils.helpers import validate_group_key

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

    from airflow.models.abstractoperator import AbstractOperator
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.dag import DAG
    from airflow.models.expandinput import ExpandInput
    from airflow.models.operator import Operator
    from airflow.models.taskmixin import DependencyMixin
    from airflow.utils.edgemodifier import EdgeModifier

51 

52 

class TaskGroup(DAGNode):
    """
    A collection of tasks.

    When set_downstream() or set_upstream() are called on the TaskGroup, it is applied across
    all tasks within the group if necessary.

    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
        with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id
        set to None.
    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
        Default is True.
    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
        for the root TaskGroup.
    :param dag: The DAG that this TaskGroup belongs to.
    :param default_args: A dictionary of default parameters to be used
        as constructor keyword parameters when initialising operators,
        will override default_args defined in the DAG level.
        Note that operators have the same hook, and precede those defined
        here, meaning that if your dict contains `'depends_on_past': True`
        here and `'depends_on_past': False` in the operator's call
        `default_args`, the actual value will be `False`.
    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
    :param ui_color: The fill color of the TaskGroup node when displayed in the UI
    :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
    :param add_suffix_on_collision: If this task group name already exists,
        automatically add `__1` etc suffixes
    """

    # Shared (same set object) between a root TaskGroup and every group created beneath it.
    used_group_ids: set[str | None]

    def __init__(
        self,
        group_id: str | None,
        prefix_group_id: bool = True,
        parent_group: TaskGroup | None = None,
        dag: DAG | None = None,
        default_args: dict[str, Any] | None = None,
        tooltip: str = "",
        ui_color: str = "CornflowerBlue",
        ui_fgcolor: str = "#000",
        add_suffix_on_collision: bool = False,
    ):
        from airflow.models.dag import DagContext

        self.prefix_group_id = prefix_group_id
        # Deep-copied so later mutation of the caller's dict does not leak into this group.
        self.default_args = copy.deepcopy(default_args or {})

        dag = dag or DagContext.get_current_dag()

        if group_id is None:
            # This creates a root TaskGroup.
            if parent_group:
                raise AirflowException("Root TaskGroup cannot have parent_group")
            # used_group_ids is shared across all TaskGroups in the same DAG to keep track
            # of used group_id to avoid duplication.
            self.used_group_ids = set()
            self.dag = dag
        else:
            if prefix_group_id:
                # If group id is used as prefix, it should not contain spaces nor dots
                # because it is used as prefix in the task_id
                validate_group_key(group_id)
            else:
                if not isinstance(group_id, str):
                    raise ValueError("group_id must be str")
                if not group_id:
                    raise ValueError("group_id must not be empty")

            if not parent_group and not dag:
                raise AirflowException("TaskGroup can only be used inside a dag")

            parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
            if not parent_group:
                raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
            if dag is not parent_group.dag:
                # Exceptions do not %-interpolate extra positional args, so format explicitly.
                raise RuntimeError(
                    f"Cannot mix TaskGroups from different DAGs: {dag} and {parent_group.dag}"
                )

            self.used_group_ids = parent_group.used_group_ids

        # if given group_id already used assign suffix by incrementing largest used suffix integer
        # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
        self._group_id = group_id
        self._check_for_group_id_collisions(add_suffix_on_collision)

        self.children: dict[str, DAGNode] = {}

        if parent_group:
            # add() also assigns self.task_group and self.dag from the parent.
            parent_group.add(self)
            self._update_default_args(parent_group)

        self.used_group_ids.add(self.group_id)
        if self.group_id:
            # Reserve the ids of the proxy join nodes as well.
            self.used_group_ids.add(self.downstream_join_id)
            self.used_group_ids.add(self.upstream_join_id)

        self.tooltip = tooltip
        self.ui_color = ui_color
        self.ui_fgcolor = ui_fgcolor

        # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
        # so that we can optimize the number of edges when entire TaskGroups depend on each other.
        self.upstream_group_ids: set[str | None] = set()
        self.downstream_group_ids: set[str | None] = set()
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()

    def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
        """
        Resolve a clash between ``self._group_id`` and the ids in ``used_group_ids``.

        :param add_suffix_on_collision: When True a ``__<n>`` suffix is appended (or bumped)
            instead of raising.
        :raise DuplicateTaskIdFound: If the id is taken and suffixing is disabled.
        """
        if self._group_id is None:
            return
        # if given group_id already used assign suffix by incrementing largest used suffix integer
        # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
        if self._group_id in self.used_group_ids:
            if not add_suffix_on_collision:
                raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")
            base = re2.split(r"__\d+$", self._group_id)[0]
            # Escape the base: with prefix_group_id=False the id is not validated and may
            # contain regex metacharacters.
            suffix_pattern = rf"^{re2.escape(base)}__\d+$"
            suffixes = [
                int(re2.split(r"^.+__", used_group_id)[1])
                for used_group_id in self.used_group_ids
                if used_group_id is not None and re2.match(suffix_pattern, used_group_id)
            ]
            if not suffixes:
                self._group_id += "__1"
            else:
                self._group_id = f"{base}__{max(suffixes) + 1}"

    def _update_default_args(self, parent_group: TaskGroup):
        """Merge the parent's default_args underneath this group's own (own values win)."""
        if parent_group.default_args:
            self.default_args = {**parent_group.default_args, **self.default_args}

    @classmethod
    def create_root(cls, dag: DAG) -> TaskGroup:
        """Create a root TaskGroup with no group_id or parent."""
        return cls(group_id=None, dag=dag)

    @property
    def node_id(self):
        """Identifier of this node; for a TaskGroup this is its ``group_id``."""
        return self.group_id

    @property
    def is_root(self) -> bool:
        """Returns True if this TaskGroup is the root TaskGroup. Otherwise False."""
        return not self.group_id

    @property
    def parent_group(self) -> TaskGroup | None:
        """The TaskGroup this group was added to (alias of ``task_group``)."""
        return self.task_group

    def __iter__(self):
        """Yield every non-TaskGroup node in this group, descending into child TaskGroups."""
        for child in self.children.values():
            if isinstance(child, TaskGroup):
                yield from child
            else:
                yield child

    def add(self, task: DAGNode) -> DAGNode:
        """Add a task to this TaskGroup.

        :meta private:
        """
        from airflow.models.abstractoperator import AbstractOperator

        if TaskGroupContext.active:
            # Inside a context-managed group: move the node out of its previous group.
            if task.task_group and task.task_group != self:
                task.task_group.children.pop(task.node_id, None)
                task.task_group = self
        existing_tg = task.task_group
        if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
            raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)

        # Set the TG first, as setting it might change the return value of node_id!
        task.task_group = weakref.proxy(self)
        key = task.node_id

        if key in self.children:
            node_type = "Task" if hasattr(task, "task_id") else "Task Group"
            raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG")

        if isinstance(task, TaskGroup):
            if self.dag:
                if task.dag is not None and self.dag is not task.dag:
                    # Exceptions do not %-interpolate extra positional args, so format explicitly.
                    raise RuntimeError(
                        f"Cannot mix TaskGroups from different DAGs: {self.dag} and {task.dag}"
                    )
                task.dag = self.dag
            if task.children:
                raise AirflowException("Cannot add a non-empty TaskGroup")

        self.children[key] = task
        return task

    def _remove(self, task: DAGNode) -> None:
        """
        Remove a direct child and release its id from ``used_group_ids``.

        :raise KeyError: If the node is not a direct child of this group (or its id is
            not present in ``used_group_ids``).
        """
        key = task.node_id

        if key not in self.children:
            raise KeyError(f"Node id {key!r} not part of this task group")

        self.used_group_ids.remove(key)
        del self.children[key]

    @property
    def group_id(self) -> str | None:
        """group_id of this TaskGroup."""
        if self.task_group and self.task_group.prefix_group_id and self.task_group._group_id:
            # defer to parent whether it adds a prefix
            return self.task_group.child_id(self._group_id)

        return self._group_id

    @property
    def label(self) -> str | None:
        """group_id excluding parent's group_id used as the node label in UI."""
        return self._group_id

    def update_relative(
        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
    ) -> None:
        """
        Override TaskMixin.update_relative.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
        accordingly so that we can reduce the number of edges when displaying Graph view.
        """
        if isinstance(other, TaskGroup):
            # Handles setting relationship between a TaskGroup and another TaskGroup
            if upstream:
                parent, child = (self, other)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
            else:
                parent, child = (other, self)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)

            parent.upstream_group_ids.add(child.group_id)
            child.downstream_group_ids.add(parent.group_id)
        else:
            # Handles setting relationship between a TaskGroup and a task
            for task in other.roots:
                if not isinstance(task, DAGNode):
                    raise AirflowException(
                        "Relationships can only be set between TaskGroup "
                        f"or operators; received {task.__class__.__name__}"
                    )

                # Do not set a relationship between a TaskGroup and a Label's roots
                if self == task:
                    continue

                if upstream:
                    self.upstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
                else:
                    self.downstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
        """
        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        for task_like in task_or_task_list:
            self.update_relative(task_like, upstream, edge_modifier=edge_modifier)

        if upstream:
            for task in self.get_roots():
                task.set_upstream(task_or_task_list)
        else:
            for task in self.get_leaves():
                task.set_downstream(task_or_task_list)

    def __enter__(self) -> TaskGroup:
        """Make this group the current context-managed TaskGroup."""
        TaskGroupContext.push_context_managed_task_group(self)
        return self

    def __exit__(self, _type, _value, _tb):
        """Restore the previously current context-managed TaskGroup."""
        TaskGroupContext.pop_context_managed_task_group()

    def has_task(self, task: BaseOperator) -> bool:
        """Return True if this TaskGroup or its children TaskGroups contains the given task."""
        if task.task_id in self.children:
            return True

        return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

    @property
    def roots(self) -> list[BaseOperator]:
        """Required by TaskMixin."""
        return list(self.get_roots())

    @property
    def leaves(self) -> list[BaseOperator]:
        """Required by TaskMixin."""
        return list(self.get_leaves())

    def get_roots(self) -> Generator[BaseOperator, None, None]:
        """Return a generator of tasks with no upstream dependencies within the TaskGroup."""
        tasks = list(self)
        ids = {x.task_id for x in tasks}
        for task in tasks:
            if task.upstream_task_ids.isdisjoint(ids):
                yield task

    def get_leaves(self) -> Generator[BaseOperator, None, None]:
        """Return a generator of tasks with no downstream dependencies within the TaskGroup."""
        tasks = list(self)
        ids = {x.task_id for x in tasks}

        def has_non_teardown_downstream(task, exclude: str):
            # True if `task` has an in-group, non-teardown downstream other than `exclude`.
            for down_task in task.downstream_list:
                if down_task.task_id == exclude:
                    continue
                elif down_task.task_id not in ids:
                    continue
                elif not down_task.is_teardown:
                    return True
            return False

        def recurse_for_first_non_teardown(task):
            # Walk upstream from a teardown until non-teardown in-group tasks are reached.
            for upstream_task in task.upstream_list:
                if upstream_task.task_id not in ids:
                    # upstream task is not in task group
                    continue
                elif upstream_task.is_teardown:
                    yield from recurse_for_first_non_teardown(upstream_task)
                elif task.is_teardown and upstream_task.is_setup:
                    # don't go through the teardown-to-setup path
                    continue
                # return unless upstream task already has non-teardown downstream in group
                elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id):
                    yield upstream_task

        for task in tasks:
            if task.downstream_task_ids.isdisjoint(ids):
                if not task.is_teardown:
                    yield task
                else:
                    yield from recurse_for_first_non_teardown(task)

    def child_id(self, label):
        """Prefix label with group_id if prefix_group_id is True. Otherwise return the label as-is."""
        if self.prefix_group_id:
            group_id = self.group_id
            if group_id:
                return f"{group_id}.{label}"

        return label

    @property
    def upstream_join_id(self) -> str:
        """
        Creates a unique ID for upstream dependencies of this TaskGroup.

        If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
        upstream_join_id will be created in Graph view to join the incoming edges to this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.upstream_join_id"

    @property
    def downstream_join_id(self) -> str:
        """
        Creates a unique ID for downstream dependencies of this TaskGroup.

        If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
        downstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.downstream_join_id"

    def get_task_group_dict(self) -> dict[str, TaskGroup]:
        """Return a flat dictionary of group_id: TaskGroup."""
        task_group_map = {}

        def build_map(task_group):
            # Non-TaskGroup children (plain tasks) are ignored.
            if not isinstance(task_group, TaskGroup):
                return

            task_group_map[task_group.group_id] = task_group

            for child in task_group.children.values():
                build_map(child)

        build_map(self)
        return task_group_map

    def get_child_by_label(self, label: str) -> DAGNode:
        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
        return self.children[self.child_id(label)]

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Serialize task group; required by DAGNode."""
        from airflow.serialization.serialized_objects import TaskGroupSerialization

        return DagAttributeTypes.TASK_GROUP, TaskGroupSerialization.serialize_task_group(self)

    def hierarchical_alphabetical_sort(self):
        """
        Sort children in hierarchical alphabetical order.

        - groups in alphabetical order first
        - tasks in alphabetical order after them.

        :return: list of tasks in hierarchical alphabetical order
        """
        return sorted(
            self.children.values(), key=lambda node: (not isinstance(node, TaskGroup), node.node_id)
        )

    def topological_sort(self, _include_subdag_tasks: bool = False):
        """
        Sorts children in topological order, such that a task comes after any of its upstream dependencies.

        :return: list of tasks in topological order
        :raise AirflowDagCycleException: If a full pass resolves no node.
        """
        # This uses a modified version of Kahn's Topological Sort algorithm to
        # not have to pre-compute the "in-degree" of the nodes.
        from airflow.operators.subdag import SubDagOperator  # Avoid circular import

        graph_unsorted = copy.copy(self.children)

        graph_sorted: list[DAGNode] = []

        # special case
        if not self.children:
            return graph_sorted

        # Run until the unsorted graph is empty.
        while graph_unsorted:
            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
            # pair from the unsorted graph, and append it to the sorted graph. Note here that by iterating
            # over a list copy of values(), we can modify the unsorted graph as we move through it.
            #
            # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
            # during each pass through the graph. If not, we need to exit as the graph therefore can't be
            # sorted.
            acyclic = False
            for node in list(graph_unsorted.values()):
                for edge in node.upstream_list:
                    if edge.node_id in graph_unsorted:
                        break
                    # Check for task's group is a child (or grand child) of this TG,
                    tg = edge.task_group
                    while tg:
                        if tg.node_id in graph_unsorted:
                            break
                        tg = tg.task_group

                    if tg:
                        # We are already going to visit that TG
                        break
                else:
                    # No unresolved upstream was found: the node can be emitted now.
                    acyclic = True
                    del graph_unsorted[node.node_id]
                    graph_sorted.append(node)
                    if _include_subdag_tasks and isinstance(node, SubDagOperator):
                        graph_sorted.extend(
                            node.subdag.task_group.topological_sort(_include_subdag_tasks=True)
                        )

            if not acyclic:
                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")

        return graph_sorted

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """Return mapped task groups in the hierarchy.

        Groups are returned from the closest to the outmost. If *self* is a
        mapped task group, it is returned first.

        :meta private:
        """
        group: TaskGroup | None = self
        while group is not None:
            if isinstance(group, MappedTaskGroup):
                yield group
            group = group.task_group

    def iter_tasks(self) -> Iterator[AbstractOperator]:
        """Return an iterator of the child tasks.

        Groups are visited breadth-first (FIFO queue).

        :raise ValueError: If a child is neither a TaskGroup nor an AbstractOperator.
        """
        from airflow.models.abstractoperator import AbstractOperator

        # deque gives O(1) pops from the left; list.pop(0) is O(n).
        groups_to_visit = deque([self])

        while groups_to_visit:
            visiting = groups_to_visit.popleft()

            for child in visiting.children.values():
                if isinstance(child, AbstractOperator):
                    yield child
                elif isinstance(child, TaskGroup):
                    groups_to_visit.append(child)
                else:
                    raise ValueError(
                        f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child)}"
                    )

566 

567 

class MappedTaskGroup(TaskGroup):
    """A mapped task group.

    This doesn't really do anything special, just holds some additional metadata
    for expansion later.

    Don't instantiate this class directly; call *expand* or *expand_kwargs* on
    a ``@task_group`` function instead.
    """

    def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
        """
        Build the underlying TaskGroup and remember the expansion input.

        :param expand_input: The mapping input this group is expanded over.
        :param kwargs: Forwarded unchanged to ``TaskGroup.__init__``.
        """
        super().__init__(**kwargs)
        self._expand_input = expand_input

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this mapped task group."""
        from airflow.models.xcom_arg import XComArg

        for op, _ in XComArg.iter_xcom_references(self._expand_input):
            yield op

    @methodtools.lru_cache(maxsize=None)
    def get_parse_time_mapped_ti_count(self) -> int:
        """
        Return the number of instances a task in this group should be mapped to, when a DAG run is created.

        This only considers literal mapped arguments. Non-literal values are expected to
        surface as ``NotFullyPopulated`` from the expand input rather than as a return
        value (the return annotation is ``int``, not optional).

        If this group is inside mapped task groups, all the nested counts are
        multiplied and accounted.

        :meta private:

        :raise NotFullyPopulated: If any non-literal mapped arguments are encountered.
        :return: The total number of mapped instances each task should have.
        """
        # iter_mapped_task_groups() yields self first, so the reduction is never empty.
        return functools.reduce(
            operator.mul,
            (g._expand_input.get_parse_time_mapped_ti_count() for g in self.iter_mapped_task_groups()),
        )

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """
        Return the number of instances a task in this group should be mapped to at run time.

        This considers both literal and non-literal mapped arguments, and the
        result is therefore available when all depended tasks have finished. The
        return value should be identical to ``parse_time_mapped_ti_count`` if
        all mapped arguments are literal.

        If this group is inside mapped task groups, all the nested counts are
        multiplied and accounted.

        :meta private:

        :raise NotFullyPopulated: If upstream tasks are not all complete yet.
        :return: Total number of mapped TIs this task should have.
        """
        groups = self.iter_mapped_task_groups()
        return functools.reduce(
            operator.mul,
            (g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Wire the expand input's referenced operators upstream, then leave the context."""
        try:
            for op, _ in self._expand_input.iter_references():
                self.set_upstream(op)
        finally:
            # Always pop the TaskGroup context, even if wiring a dependency fails;
            # otherwise this group would stay "current" for unrelated code.
            super().__exit__(exc_type, exc_val, exc_tb)

637 

638 

class TaskGroupContext:
    """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""

    # Set to True by push and to False by every pop.
    active: bool = False
    # The innermost TaskGroup currently entered via ``with``; None when there is none.
    _context_managed_task_group: TaskGroup | None = None
    # Stack of the enclosing groups, outermost first. Intentionally class-level shared state.
    _previous_context_managed_task_groups: list[TaskGroup] = []

    @classmethod
    def push_context_managed_task_group(cls, task_group: TaskGroup):
        """Push a TaskGroup into the list of managed TaskGroups."""
        if cls._context_managed_task_group:
            cls._previous_context_managed_task_groups.append(cls._context_managed_task_group)
        cls._context_managed_task_group = task_group
        cls.active = True

    @classmethod
    def pop_context_managed_task_group(cls) -> TaskGroup | None:
        """Pops the last TaskGroup from the list of managed TaskGroups and update the current TaskGroup.

        :return: The TaskGroup that was current before the pop (None if there was none).
        """
        old_task_group = cls._context_managed_task_group
        if cls._previous_context_managed_task_groups:
            cls._context_managed_task_group = cls._previous_context_managed_task_groups.pop()
        else:
            cls._context_managed_task_group = None
        # NOTE(review): active is cleared even when an outer group is still current after
        # the pop — confirm this is intended before changing it.
        cls.active = False
        return old_task_group

    @classmethod
    def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
        """Get the current TaskGroup.

        When no TaskGroup is context-managed, fall back to the root TaskGroup of *dag*
        (or of the current DAG from ``DagContext`` when *dag* is falsy), if any.
        """
        if not cls._context_managed_task_group:
            # Imported lazily: only this fallback branch needs DagContext.
            from airflow.models.dag import DagContext

            dag = dag or DagContext.get_current_dag()
            if dag:
                # If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag.
                return dag.task_group

        return cls._context_managed_task_group

677 

678 

def task_group_to_dict(task_item_or_group):
    """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.

    :param task_item_or_group: Either an operator (rendered as a leaf node) or a TaskGroup
        (rendered as a cluster with a ``children`` list).
    :return: A dict with ``id`` and ``value`` keys, plus ``children`` for a TaskGroup.
    """
    from airflow.models.abstractoperator import AbstractOperator
    from airflow.models.mappedoperator import MappedOperator

    if isinstance(task := task_item_or_group, AbstractOperator):
        # Optional keys are collected in small dicts so they are only emitted when set.
        setup_teardown_type = {}
        mapped_flag = {}
        if task.is_setup is True:
            setup_teardown_type["setupTeardownType"] = "setup"
        elif task.is_teardown is True:
            setup_teardown_type["setupTeardownType"] = "teardown"
        if isinstance(task, MappedOperator):
            mapped_flag["isMapped"] = True
        return {
            "id": task.task_id,
            "value": {
                "label": task.label,
                "labelStyle": f"fill:{task.ui_fgcolor};",
                "style": f"fill:{task.ui_color};",
                "rx": 5,
                "ry": 5,
                **mapped_flag,
                **setup_teardown_type,
            },
        }

    task_group = task_item_or_group
    is_mapped = isinstance(task_group, MappedTaskGroup)
    # Children are rendered recursively, ordered by their label.
    children = [
        task_group_to_dict(child) for child in sorted(task_group.children.values(), key=lambda t: t.label)
    ]

    if task_group.upstream_group_ids or task_group.upstream_task_ids:
        # Join node on the upstream side of the group.
        children.append(
            {
                "id": task_group.upstream_join_id,
                "value": {
                    "label": "",
                    "labelStyle": f"fill:{task_group.ui_fgcolor};",
                    "style": f"fill:{task_group.ui_color};",
                    "shape": "circle",
                },
            }
        )

    if task_group.downstream_group_ids or task_group.downstream_task_ids:
        # This is the join node used to reduce the number of edges between two TaskGroup.
        children.append(
            {
                "id": task_group.downstream_join_id,
                "value": {
                    "label": "",
                    "labelStyle": f"fill:{task_group.ui_fgcolor};",
                    "style": f"fill:{task_group.ui_color};",
                    "shape": "circle",
                },
            }
        )

    return {
        "id": task_group.group_id,
        "value": {
            "label": task_group.label,
            "labelStyle": f"fill:{task_group.ui_fgcolor};",
            "style": f"fill:{task_group.ui_color}",
            "rx": 5,
            "ry": 5,
            "clusterLabelPos": "top",
            "tooltip": task_group.tooltip,
            "isMapped": is_mapped,
        },
        "children": children,
    }