Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/utils/task_group.py: 38%

305 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18""" 

19A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped 

20together when the DAG is displayed graphically. 

21""" 

22from __future__ import annotations 

23 

24import copy 

25import functools 

26import operator 

27import re 

28import weakref 

29from typing import TYPE_CHECKING, Any, Generator, Iterator, Sequence 

30 

31from airflow.compat.functools import cache 

32from airflow.exceptions import ( 

33 AirflowDagCycleException, 

34 AirflowException, 

35 DuplicateTaskIdFound, 

36 TaskAlreadyInTaskGroup, 

37) 

38from airflow.models.taskmixin import DAGNode, DependencyMixin 

39from airflow.serialization.enums import DagAttributeTypes 

40from airflow.utils.helpers import validate_group_key 

41 

42if TYPE_CHECKING: 

43 from sqlalchemy.orm import Session 

44 

45 from airflow.models.abstractoperator import AbstractOperator 

46 from airflow.models.baseoperator import BaseOperator 

47 from airflow.models.dag import DAG 

48 from airflow.models.expandinput import ExpandInput 

49 from airflow.models.operator import Operator 

50 from airflow.utils.edgemodifier import EdgeModifier 

51 

52 

class TaskGroup(DAGNode):
    """
    A collection of tasks. When set_downstream() or set_upstream() are called on the
    TaskGroup, it is applied across all tasks within the group if necessary.

    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
        with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id
        set to None.
    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
        Default is True.
    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
        for the root TaskGroup.
    :param dag: The DAG that this TaskGroup belongs to.
    :param default_args: A dictionary of default parameters to be used
        as constructor keyword parameters when initialising operators,
        will override default_args defined in the DAG level.
        Note that operators have the same hook, and precede those defined
        here, meaning that if your dict contains `'depends_on_past': True`
        here and `'depends_on_past': False` in the operator's call
        `default_args`, the actual value will be `False`.
        When TaskGroups are nested, a value set on the inner TaskGroup takes
        precedence over the same key inherited from the enclosing TaskGroup.
    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
    :param ui_color: The fill color of the TaskGroup node when displayed in the UI
    :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
    :param add_suffix_on_collision: If this task group name already exists,
        automatically add `__1` etc suffixes
    """

    # Ids already claimed in the DAG. The root TaskGroup creates the set and every
    # descendant re-uses the very same set object (see __init__).
    used_group_ids: set[str | None]

    def __init__(
        self,
        group_id: str | None,
        prefix_group_id: bool = True,
        parent_group: TaskGroup | None = None,
        dag: DAG | None = None,
        default_args: dict[str, Any] | None = None,
        tooltip: str = "",
        ui_color: str = "CornflowerBlue",
        ui_fgcolor: str = "#000",
        add_suffix_on_collision: bool = False,
    ):
        from airflow.models.dag import DagContext

        self.prefix_group_id = prefix_group_id
        # Deep-copied so that later mutation by the caller does not leak into the group.
        self.default_args = copy.deepcopy(default_args or {})

        dag = dag or DagContext.get_current_dag()

        if group_id is None:
            # This creates a root TaskGroup.
            if parent_group:
                raise AirflowException("Root TaskGroup cannot have parent_group")
            # used_group_ids is shared across all TaskGroups in the same DAG to keep track
            # of used group_id to avoid duplication.
            self.used_group_ids = set()
            self.dag = dag
        else:
            if prefix_group_id:
                # If group id is used as prefix, it should not contain spaces nor dots
                # because it is used as prefix in the task_id
                validate_group_key(group_id)
            else:
                if not isinstance(group_id, str):
                    raise ValueError("group_id must be str")
                if not group_id:
                    raise ValueError("group_id must not be empty")

            if not parent_group and not dag:
                raise AirflowException("TaskGroup can only be used inside a dag")

            parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
            if not parent_group:
                raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
            if dag is not parent_group.dag:
                # Exceptions do not interpolate %-style arguments, so build the message eagerly.
                raise RuntimeError(
                    f"Cannot mix TaskGroups from different DAGs: {dag} and {parent_group.dag}"
                )

            self.used_group_ids = parent_group.used_group_ids

        # May be rewritten with a numeric suffix by the collision check below.
        self._group_id = group_id
        self._check_for_group_id_collisions(add_suffix_on_collision)

        self.children: dict[str, DAGNode] = {}

        if parent_group:
            # add() sets self.task_group (and self.dag), which group_id below relies on.
            parent_group.add(self)
            self._update_default_args(parent_group)

        self.used_group_ids.add(self.group_id)
        if self.group_id:
            # Reserve the ids of the proxy join nodes as well.
            self.used_group_ids.add(self.downstream_join_id)
            self.used_group_ids.add(self.upstream_join_id)

        self.tooltip = tooltip
        self.ui_color = ui_color
        self.ui_fgcolor = ui_fgcolor

        # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
        # so that we can optimize the number of edges when entire TaskGroups depend on each other.
        self.upstream_group_ids: set[str | None] = set()
        self.downstream_group_ids: set[str | None] = set()
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()

    def _check_for_group_id_collisions(self, add_suffix_on_collision: bool) -> None:
        """
        Make ``self._group_id`` unique with respect to ``self.used_group_ids``.

        If the id is already used, either raise or append/increment a ``__<n>`` suffix.
        Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3

        :param add_suffix_on_collision: when False a collision raises instead of renaming.
        :raises DuplicateTaskIdFound: on collision when *add_suffix_on_collision* is False.
        """
        if self._group_id is None:
            return
        if self._group_id not in self.used_group_ids:
            return
        if not add_suffix_on_collision:
            raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")

        # Strip an existing numeric suffix so "tg__2" colliding again continues the "tg" series.
        base = re.split(r"__\d+$", self._group_id)[0]
        # base is escaped: a group id may contain regex metacharacters (e.g. "."), which
        # must only match themselves when looking for already-used suffixes.
        suffix_pattern = re.compile(rf"^{re.escape(base)}__(\d+)$")
        matches = (
            suffix_pattern.match(used_group_id)
            for used_group_id in self.used_group_ids
            if used_group_id is not None
        )
        suffixes = [int(found.group(1)) for found in matches if found]
        if suffixes:
            self._group_id = f"{base}__{max(suffixes) + 1}"
        else:
            self._group_id += "__1"

    def _update_default_args(self, parent_group: TaskGroup) -> None:
        """
        Inherit ``default_args`` from *parent_group*.

        Keys set on this TaskGroup win over the same keys coming from the parent, so a
        nested group can override what its enclosing group defines.
        """
        if parent_group.default_args:
            self.default_args = {**parent_group.default_args, **self.default_args}

    @classmethod
    def create_root(cls, dag: DAG) -> TaskGroup:
        """Create a root TaskGroup with no group_id or parent."""
        return cls(group_id=None, dag=dag)

    @property
    def node_id(self):
        """Identifier of this node; for a TaskGroup this is its ``group_id``."""
        return self.group_id

    @property
    def is_root(self) -> bool:
        """Returns True if this TaskGroup is the root TaskGroup. Otherwise False."""
        return not self.group_id

    @property
    def parent_group(self) -> TaskGroup | None:
        """The TaskGroup this group was added to (alias of ``task_group``)."""
        return self.task_group

    def __iter__(self):
        """Yield every non-TaskGroup child, descending depth-first into nested TaskGroups."""
        for child in self.children.values():
            if isinstance(child, TaskGroup):
                yield from child
            else:
                yield child

    def add(self, task: DAGNode) -> None:
        """Add a task to this TaskGroup.

        :raises TaskAlreadyInTaskGroup: if an operator already belongs to another TaskGroup.
        :raises DuplicateTaskIdFound: if a child with the same node id is already present.
        :raises RuntimeError: if a TaskGroup belonging to a different DAG is added.
        :raises AirflowException: if the added TaskGroup already has children.

        :meta private:
        """
        from airflow.models.abstractoperator import AbstractOperator

        existing_tg = task.task_group
        if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
            raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)

        # Set the TG first, as setting it might change the return value of node_id!
        # A weak proxy is stored so the child does not keep its parent alive.
        task.task_group = weakref.proxy(self)
        key = task.node_id

        if key in self.children:
            node_type = "Task" if hasattr(task, "task_id") else "Task Group"
            raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG")

        if isinstance(task, TaskGroup):
            if self.dag:
                if task.dag is not None and self.dag is not task.dag:
                    # Exceptions do not interpolate %-style arguments, so build the message eagerly.
                    raise RuntimeError(
                        f"Cannot mix TaskGroups from different DAGs: {self.dag} and {task.dag}"
                    )
                task.dag = self.dag
            if task.children:
                raise AirflowException("Cannot add a non-empty TaskGroup")

        self.children[key] = task

    def _remove(self, task: DAGNode) -> None:
        """
        Remove a direct child and release its id from ``used_group_ids``.

        :raises KeyError: if *task* is not a direct child of this TaskGroup (or its id is
            not present in ``used_group_ids``).
        """
        key = task.node_id

        if key not in self.children:
            raise KeyError(f"Node id {key!r} not part of this task group")

        self.used_group_ids.remove(key)
        del self.children[key]

    @property
    def group_id(self) -> str | None:
        """group_id of this TaskGroup."""
        if self.task_group and self.task_group.prefix_group_id and self.task_group._group_id:
            # defer to parent whether it adds a prefix
            return self.task_group.child_id(self._group_id)

        return self._group_id

    @property
    def label(self) -> str | None:
        """group_id excluding parent's group_id used as the node label in UI."""
        return self._group_id

    def update_relative(
        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
    ) -> None:
        """
        Overrides TaskMixin.update_relative.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
        accordingly so that we can reduce the number of edges when displaying Graph view.

        :param other: the TaskGroup or task(s) being related to this TaskGroup.
        :param upstream: True if *other* is upstream of this TaskGroup, False if downstream.
        :param edge_modifier: optional modifier whose edge info is recorded on the join nodes.
        :raises AirflowException: if one of ``other.roots`` is not a DAGNode.
        """
        if isinstance(other, TaskGroup):
            # Handles setting relationship between a TaskGroup and another TaskGroup
            if upstream:
                upstream_group, downstream_group = (other, self)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
            else:
                upstream_group, downstream_group = (self, other)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)

            downstream_group.upstream_group_ids.add(upstream_group.group_id)
            upstream_group.downstream_group_ids.add(downstream_group.group_id)
        else:
            # Handles setting relationship between a TaskGroup and a task
            for task in other.roots:
                if not isinstance(task, DAGNode):
                    raise AirflowException(
                        "Relationships can only be set between TaskGroup "
                        f"or operators; received {task.__class__.__name__}"
                    )

                # Do not set a relationship between a TaskGroup and a Label's roots
                if self == task:
                    continue

                if upstream:
                    self.upstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
                else:
                    self.downstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.
        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
        """
        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        # First record the group-level bookkeeping used for drawing the graph ...
        for task_like in task_or_task_list:
            self.update_relative(task_like, upstream, edge_modifier=edge_modifier)

        # ... then wire the real dependencies onto the boundary tasks of the group.
        if upstream:
            for task in self.get_roots():
                task.set_upstream(task_or_task_list)
        else:
            for task in self.get_leaves():
                task.set_downstream(task_or_task_list)

    def __enter__(self) -> TaskGroup:
        """Make this TaskGroup the current one for the duration of the ``with`` block."""
        TaskGroupContext.push_context_managed_task_group(self)
        return self

    def __exit__(self, _type, _value, _tb):
        """Restore the TaskGroup that was current before the ``with`` block was entered."""
        TaskGroupContext.pop_context_managed_task_group()

    def has_task(self, task: BaseOperator) -> bool:
        """Returns True if this TaskGroup or its children TaskGroups contains the given task."""
        if task.task_id in self.children:
            return True

        return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

    @property
    def roots(self) -> list[BaseOperator]:
        """Required by TaskMixin."""
        return list(self.get_roots())

    @property
    def leaves(self) -> list[BaseOperator]:
        """Required by TaskMixin."""
        return list(self.get_leaves())

    def get_roots(self) -> Generator[BaseOperator, None, None]:
        """
        Returns a generator of tasks that are root tasks, i.e. those with no upstream
        dependencies within the TaskGroup.
        """
        for task in self:
            if not any(self.has_task(parent) for parent in task.get_direct_relatives(upstream=True)):
                yield task

    def get_leaves(self) -> Generator[BaseOperator, None, None]:
        """
        Returns a generator of tasks that are leaf tasks, i.e. those with no downstream
        dependencies within the TaskGroup.
        """
        for task in self:
            if not any(self.has_task(child) for child in task.get_direct_relatives(upstream=False)):
                yield task

    def child_id(self, label):
        """
        Prefix label with group_id if prefix_group_id is True. Otherwise return the label
        as-is.
        """
        if self.prefix_group_id:
            group_id = self.group_id
            if group_id:
                return f"{group_id}.{label}"

        return label

    @property
    def upstream_join_id(self) -> str:
        """
        If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
        upstream_join_id will be created in Graph view to join the incoming edges to this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.upstream_join_id"

    @property
    def downstream_join_id(self) -> str:
        """
        If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
        downstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.downstream_join_id"

    def get_task_group_dict(self) -> dict[str, TaskGroup]:
        """Returns a flat dictionary of group_id: TaskGroup."""
        task_group_map = {}

        def build_map(task_group):
            # Plain tasks are skipped; only TaskGroups are recorded and descended into.
            if not isinstance(task_group, TaskGroup):
                return

            task_group_map[task_group.group_id] = task_group

            for child in task_group.children.values():
                build_map(child)

        build_map(self)
        return task_group_map

    def get_child_by_label(self, label: str) -> DAGNode:
        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
        return self.children[self.child_id(label)]

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Required by DAGNode."""
        from airflow.serialization.serialized_objects import TaskGroupSerialization

        return DagAttributeTypes.TASK_GROUP, TaskGroupSerialization.serialize_task_group(self)

    def topological_sort(self, _include_subdag_tasks: bool = False):
        """
        Sorts children in topological order, such that a task comes after any of its
        upstream dependencies.

        :return: list of tasks in topological order
        :raises AirflowDagCycleException: if a full pass resolves no node, i.e. there is a cycle.
        """
        # This uses a modified version of Kahn's Topological Sort algorithm to
        # not have to pre-compute the "in-degree" of the nodes.
        from airflow.operators.subdag import SubDagOperator  # Avoid circular import

        graph_unsorted = copy.copy(self.children)

        graph_sorted: list[DAGNode] = []

        # special case
        if len(self.children) == 0:
            return graph_sorted

        # Run until the unsorted graph is empty.
        while graph_unsorted:
            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
            # pair from the unsorted graph, and append it to the sorted graph. Note here that by iterating
            # over list(...values()), a copy of the unsorted graph is used, allowing us to modify
            # the unsorted graph as we move through it.
            #
            # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
            # during each pass through the graph. If not, we need to exit as the graph therefore can't be
            # sorted.
            acyclic = False
            for node in list(graph_unsorted.values()):
                for edge in node.upstream_list:
                    if edge.node_id in graph_unsorted:
                        break
                    # Check for task's group is a child (or grand child) of this TG,
                    tg = edge.task_group
                    while tg:
                        if tg.node_id in graph_unsorted:
                            break
                        tg = tg.task_group

                    if tg:
                        # We are already going to visit that TG
                        break
                else:
                    # for/else: reached only when no upstream edge is still unresolved.
                    acyclic = True
                    del graph_unsorted[node.node_id]
                    graph_sorted.append(node)
                    if _include_subdag_tasks and isinstance(node, SubDagOperator):
                        graph_sorted.extend(
                            node.subdag.task_group.topological_sort(_include_subdag_tasks=True)
                        )

            if not acyclic:
                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")

        return graph_sorted

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """Return mapped task groups in the hierarchy.

        Groups are returned from the closest to the outmost. If *self* is a
        mapped task group, it is returned first.

        :meta private:
        """
        group: TaskGroup | None = self
        while group is not None:
            if isinstance(group, MappedTaskGroup):
                yield group
            group = group.task_group

    def iter_tasks(self) -> Iterator[AbstractOperator]:
        """Returns an iterator of the child tasks.

        Groups are visited breadth-first: the operators directly inside a group are
        yielded before those of its nested groups.

        :raises ValueError: if a child is neither a TaskGroup nor an AbstractOperator.
        """
        from airflow.models.abstractoperator import AbstractOperator

        groups_to_visit = [self]

        while groups_to_visit:
            visiting = groups_to_visit.pop(0)

            for child in visiting.children.values():
                if isinstance(child, AbstractOperator):
                    yield child
                elif isinstance(child, TaskGroup):
                    groups_to_visit.append(child)
                else:
                    raise ValueError(
                        f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child)}"
                    )

519 

520 

class MappedTaskGroup(TaskGroup):
    """A task group that is expanded (mapped) over some input.

    Behaves like a regular TaskGroup; it only carries the expand input needed
    to work out, later on, how many times the contained tasks are mapped.

    Don't instantiate this class directly; call *expand* or *expand_kwargs* on
    a ``@task_group`` function instead.
    """

    def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._expand_input = expand_input
        # Each operator referenced by the expand input is set upstream of the whole group.
        for referenced_op, _ref_key in expand_input.iter_references():
            self.set_upstream(referenced_op)

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this mapped task group."""
        from airflow.models.xcom_arg import XComArg

        references = XComArg.iter_xcom_references(self._expand_input)
        yield from (dependency for dependency, _ref_key in references)

    # NOTE(review): caching on an instance method keys the cache on ``self``; depending on
    # how ``cache`` is implemented this may keep instances alive — confirm this is intended.
    @cache
    def get_parse_time_mapped_ti_count(self) -> int:
        """Number of instances a task in this group should be mapped to, when a DAG run is created.

        The result is the product of the parse-time counts of this group and of
        every mapped task group enclosing it.

        :meta private:

        :raise NotFullyPopulated: If any non-literal mapped arguments are encountered.
        :return: The total number of mapped instances each task should have.
        """
        per_group_counts = (
            mapped_group._expand_input.get_parse_time_mapped_ti_count()
            for mapped_group in self.iter_mapped_task_groups()
        )
        return functools.reduce(operator.mul, per_group_counts)

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """Number of instances a task in this group should be mapped to at run time.

        The result is the product of the total map lengths, for the given run,
        of this group and of every mapped task group enclosing it.

        :meta private:

        :raise NotFullyPopulated: If upstream tasks are not all complete yet.
        :return: Total number of mapped TIs this task should have.
        """
        per_group_lengths = (
            mapped_group._expand_input.get_total_map_length(run_id, session=session)
            for mapped_group in self.iter_mapped_task_groups()
        )
        return functools.reduce(operator.mul, per_group_lengths)

585 

586 

class TaskGroupContext:
    """Tracks which TaskGroup is current while TaskGroups are used as context managers."""

    # The TaskGroup whose ``with`` block is innermost right now, if any.
    _context_managed_task_group: TaskGroup | None = None
    # Enclosing TaskGroups, outermost first; used as a stack.
    _previous_context_managed_task_groups: list[TaskGroup] = []

    @classmethod
    def push_context_managed_task_group(cls, task_group: TaskGroup):
        """Make *task_group* current, remembering the one that was current before."""
        previous = cls._context_managed_task_group
        if previous:
            cls._previous_context_managed_task_groups.append(previous)
        cls._context_managed_task_group = task_group

    @classmethod
    def pop_context_managed_task_group(cls) -> TaskGroup | None:
        """Drop the current TaskGroup, reinstate the previous one (if any) and return the dropped one."""
        popped = cls._context_managed_task_group
        stack = cls._previous_context_managed_task_groups
        cls._context_managed_task_group = stack.pop() if stack else None
        return popped

    @classmethod
    def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
        """Return the current TaskGroup, falling back to the root TaskGroup of the DAG."""
        from airflow.models.dag import DagContext

        current = cls._context_managed_task_group
        if current:
            return current

        # No TaskGroup is being context-managed: use the given DAG, or else the current one.
        dag = dag or DagContext.get_current_dag()
        if dag:
            return dag.task_group
        return current

622 

623 

def task_group_to_dict(task_item_or_group):
    """
    Build the nested dict describing a task or a TaskGroup (with all of its children)
    that is used to construct the Graph.
    """
    from airflow.models.abstractoperator import AbstractOperator

    node = task_item_or_group

    # A single operator is a leaf of the structure.
    if isinstance(node, AbstractOperator):
        return {
            "id": node.task_id,
            "value": {
                "label": node.label,
                "labelStyle": f"fill:{node.ui_fgcolor};",
                "style": f"fill:{node.ui_color};",
                "rx": 5,
                "ry": 5,
            },
        }

    def join_node(join_id):
        # Label-less circle used as the join node that reduces the number of edges
        # between two TaskGroups.
        return {
            "id": join_id,
            "value": {
                "label": "",
                "labelStyle": f"fill:{node.ui_fgcolor};",
                "style": f"fill:{node.ui_color};",
                "shape": "circle",
            },
        }

    is_mapped = isinstance(node, MappedTaskGroup)

    # Children are emitted ordered by label, each converted recursively.
    ordered_children = sorted(node.children.values(), key=operator.attrgetter("label"))
    children = [task_group_to_dict(child) for child in ordered_children]

    has_upstream = node.upstream_group_ids or node.upstream_task_ids
    if has_upstream:
        children.append(join_node(node.upstream_join_id))

    has_downstream = node.downstream_group_ids or node.downstream_task_ids
    if has_downstream:
        children.append(join_node(node.downstream_join_id))

    value = {
        "label": node.label,
        "labelStyle": f"fill:{node.ui_fgcolor};",
        "style": f"fill:{node.ui_color}",
        "rx": 5,
        "ry": 5,
        "clusterLabelPos": "top",
        "tooltip": node.tooltip,
        "isMapped": is_mapped,
    }
    return {"id": node.group_id, "value": value, "children": children}
688 }