Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/abstractoperator.py: 27%

318 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import datetime 

21import inspect 

22from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence 

23 

24from airflow.compat.functools import cache, cached_property 

25from airflow.configuration import conf 

26from airflow.exceptions import AirflowException 

27from airflow.models.expandinput import NotFullyPopulated 

28from airflow.models.taskmixin import DAGNode 

29from airflow.utils.context import Context 

30from airflow.utils.helpers import render_template_as_native, render_template_to_string 

31from airflow.utils.log.logging_mixin import LoggingMixin 

32from airflow.utils.mixins import ResolveMixin 

33from airflow.utils.session import NEW_SESSION, provide_session 

34from airflow.utils.state import State, TaskInstanceState 

35from airflow.utils.task_group import MappedTaskGroup 

36from airflow.utils.trigger_rule import TriggerRule 

37from airflow.utils.weight_rule import WeightRule 

38 

# Signature of task state-change callbacks (e.g. on_failure_callback): the
# callback receives the task Context and its return value is ignored.
TaskStateChangeCallback = Callable[[Context], None]

if TYPE_CHECKING:
    import jinja2  # Slow import.
    from sqlalchemy.orm import Session

    from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
    from airflow.models.dag import DAG
    from airflow.models.mappedoperator import MappedOperator
    from airflow.models.operator import Operator
    from airflow.models.taskinstance import TaskInstance

# Operator-level defaults, resolved once from the Airflow configuration at
# import time. get_mandatory_value() raises if the option is missing.
DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_PRIORITY_WEIGHT: int = 1
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = conf.getboolean(
    "scheduler", "ignore_first_depends_on_past_by_default"
)
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
# Upper bound for a task's retry delay, in seconds (fallback: 24 hours).
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)

71 

72 

class NotMapped(Exception):
    """Raised when a task is neither mapped itself nor inside any mapped task group."""

75 

76 

class AbstractOperator(LoggingMixin, DAGNode):
    """Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use this as the common base class
    for typing due to BaseOperator carrying too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    # Concrete operator class, or (for a serialized mapped operator) its
    # dict representation.
    operator_class: type[BaseOperator] | dict[str, Any]

    weight_rule: str
    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink]
    # For derived classes to define which fields will get jinjaified.
    template_fields: Collection[str]
    # Defines which files extensions to look for in the templated fields.
    template_ext: Sequence[str]

    owner: str
    task_id: str

    outlets: list
    inlets: list

    # Attribute names hidden from the task-details view in the web UI.
    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    def get_dag(self) -> DAG | None:
        """Return the DAG this operator belongs to, or *None* (subclass hook)."""
        raise NotImplementedError()

    @property
    def task_type(self) -> str:
        """Name of the underlying operator class (subclass hook)."""
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        """Display name of the operator (subclass hook)."""
        raise NotImplementedError()

    @property
    def inherits_from_empty_operator(self) -> bool:
        """Whether the operator is (a subclass of) EmptyOperator (subclass hook)."""
        raise NotImplementedError()

    @property
    def dag_id(self) -> str:
        """Return the attached DAG's id, or ``adhoc_<owner>`` when there is no DAG."""
        dag = self.get_dag()
        if dag:
            return dag.dag_id
        return f"adhoc_{self.owner}"

    @property
    def node_id(self) -> str:
        # DAGNode identity is simply the task id.
        return self.task_id

    def get_template_env(self) -> jinja2.Environment:
        """Fetch a Jinja template environment from the DAG or instantiate empty environment if no DAG."""
        # This is imported locally since Jinja2 is heavy and we don't need it
        # for most of the functionalities. It is imported by get_template_env()
        # though, so we don't need to put this after the 'if dag' check.
        from airflow.templates import SandboxedEnvironment

        dag = self.get_dag()
        if dag:
            return dag.get_template_env(force_sandboxed=False)
        return SandboxedEnvironment(cache_size=0)

    def prepare_template(self) -> None:
        """Hook triggered after the templated fields get replaced by their content.

        If you need your operator to alter the content of the file before the
        template is rendered, it should override this method to do so.
        """

    def resolve_template_files(self) -> None:
        """Getting the content of files for template_field / template_ext.

        For each template field whose value (string, or string inside a list)
        ends in one of ``template_ext``, replace the value in place with the
        file's source loaded via the Jinja environment. Errors are logged and
        swallowed so one bad field does not abort DAG parsing.
        """
        if self.template_ext:
            for field in self.template_fields:
                content = getattr(self, field, None)
                if content is None:
                    continue
                elif isinstance(content, str) and any(content.endswith(ext) for ext in self.template_ext):
                    # Single string value that looks like a template filepath.
                    env = self.get_template_env()
                    try:
                        setattr(self, field, env.loader.get_source(env, content)[0])  # type: ignore
                    except Exception:
                        self.log.exception("Failed to resolve template field %r", field)
                elif isinstance(content, list):
                    # Replace any filepath-like items of the list in place.
                    env = self.get_template_env()
                    for i, item in enumerate(content):
                        if isinstance(item, str) and any(item.endswith(ext) for ext in self.template_ext):
                            try:
                                content[i] = env.loader.get_source(env, item)[0]  # type: ignore
                            except Exception:
                                self.log.exception("Failed to get source %s", item)
        self.prepare_template()

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get direct relative IDs to the current task, upstream or downstream."""
        if upstream:
            return self.upstream_task_ids
        return self.downstream_task_ids

    def get_flat_relative_ids(
        self,
        upstream: bool = False,
        found_descendants: set[str] | None = None,
    ) -> set[str]:
        """Get a flat set of relative IDs, upstream or downstream.

        Performs a breadth-first traversal over the DAG's task dict starting
        from this task's direct relatives; returns an empty set when the task
        is not attached to a DAG.
        """
        dag = self.get_dag()
        if not dag:
            return set()

        if found_descendants is None:
            found_descendants = set()

        task_ids_to_trace = self.get_direct_relative_ids(upstream)
        while task_ids_to_trace:
            task_ids_to_trace_next: set[str] = set()
            for task_id in task_ids_to_trace:
                if task_id in found_descendants:
                    continue  # Already visited; avoids re-walking shared subgraphs.
                task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
                found_descendants.add(task_id)
            task_ids_to_trace = task_ids_to_trace_next

        return found_descendants

    def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
        """Get a flat list of relatives, either upstream or downstream."""
        dag = self.get_dag()
        if not dag:
            return set()
        return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream)]

    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """Return mapped nodes that are direct dependencies of the current task.

        For now, this walks the entire DAG to find mapped nodes that has this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record an DAG node's all downstream nodes instead.

        Note that this does not guarantee the returned tasks actually use the
        current task for task mapping, but only checks those task are mapped
        operators, and are downstreams of the current task.

        To get a list of tasks that uses the current task for task mapping, use
        :meth:`iter_mapped_dependants` instead.
        """
        from airflow.models.mappedoperator import MappedOperator
        from airflow.utils.task_group import TaskGroup

        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
            """Recursively walk children in a task group.

            This yields all direct children (including both tasks and task
            groups), and all children of any task groups.
            """
            for key, child in group.children.items():
                yield key, child
                if isinstance(child, TaskGroup):
                    yield from _walk_group(child)

        dag = self.get_dag()
        if not dag:
            raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG")
        for key, child in _walk_group(dag.task_group):
            if key == self.node_id:
                continue  # Skip the current task itself.
            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
                continue  # Only mapped nodes are of interest.
            if self.node_id in child.upstream_task_ids:
                yield child

    def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """Return mapped nodes that depend on the current task the expansion.

        For now, this walks the entire DAG to find mapped nodes that has this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record an DAG node's all downstream nodes instead.
        """
        # Filter the mapped downstreams to those that actually reference this
        # task in their mapped (expansion) dependencies.
        return (
            downstream
            for downstream in self._iter_all_mapped_downstreams()
            if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
        )

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """Return mapped task groups this task belongs to.

        Groups are returned from the closest to the outmost.

        :meta private:
        """
        # Walk up the task-group chain; only MappedTaskGroup ancestors are yielded.
        parent = self.task_group
        while parent is not None:
            if isinstance(parent, MappedTaskGroup):
                yield parent
            parent = parent.task_group

    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
        """:meta private:"""
        # First item of iter_mapped_task_groups() is the innermost group.
        return next(self.iter_mapped_task_groups(), None)

    def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
        """Get the "normal" operator from current abstract operator.

        MappedOperator uses this to unmap itself based on the map index. A non-
        mapped operator (i.e. BaseOperator subclass) simply returns itself.

        :meta private:
        """
        raise NotImplementedError()

    @property
    def priority_weight_total(self) -> int:
        """
        Total priority weight for the task. It might include all upstream or downstream tasks.

        Depending on the weight rule:

        - WeightRule.ABSOLUTE - only own weight
        - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks
        - WeightRule.UPSTREAM - adds priority weight of all upstream tasks
        """
        if self.weight_rule == WeightRule.ABSOLUTE:
            return self.priority_weight
        elif self.weight_rule == WeightRule.DOWNSTREAM:
            upstream = False
        elif self.weight_rule == WeightRule.UPSTREAM:
            upstream = True
        else:
            # Unknown rule: silently fall back to DOWNSTREAM behavior.
            upstream = False
        dag = self.get_dag()
        if dag is None:
            return self.priority_weight
        return self.priority_weight + sum(
            dag.task_dict[task_id].priority_weight
            for task_id in self.get_flat_relative_ids(upstream=upstream)
        )

    @cached_property
    def operator_extra_link_dict(self) -> dict[str, Any]:
        """Returns dictionary of all extra links for the operator, keyed by link name."""
        op_extra_links_from_plugin: dict[str, Any] = {}
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.operator_extra_links is None:
            raise AirflowException("Can't load operators")
        for ope in plugins_manager.operator_extra_links:
            # Only keep plugin links registered for this operator class.
            if ope.operators and self.operator_class in ope.operators:
                op_extra_links_from_plugin.update({ope.name: ope})

        operator_extra_links_all = {link.name: link for link in self.operator_extra_links}
        # Extra links defined in Plugins overrides operator links defined in operator
        operator_extra_links_all.update(op_extra_links_from_plugin)

        return operator_extra_links_all

    @cached_property
    def global_operator_extra_link_dict(self) -> dict[str, Any]:
        """Returns dictionary of all global extra links, keyed by link name."""
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()
        if plugins_manager.global_operator_extra_links is None:
            raise AirflowException("Can't load operators")
        return {link.name: link for link in plugins_manager.global_operator_extra_links}

    @cached_property
    def extra_links(self) -> list[str]:
        # Names of all extra links (operator-level plus global), deduplicated.
        return list(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))

    def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
        """For an operator, gets the URLs that the ``extra_links`` entry points to.

        :meta private:

        :raise ValueError: The error message of a ValueError will be passed on through to
            the fronted to show up as a tooltip on the disabled link.
        :param ti: The TaskInstance for the URL being searched for.
        :param link_name: The name of the link we're looking for the URL for. Should be
            one of the options specified in ``extra_links``.
        """
        # Operator-level links take lookup precedence over global ones.
        link: BaseOperatorLink | None = self.operator_extra_link_dict.get(link_name)
        if not link:
            link = self.global_operator_extra_link_dict.get(link_name)
            if not link:
                return None

        # Detect legacy get_link implementations that do not accept the
        # ``ti_key`` keyword (ignoring **kwargs catch-alls).
        parameters = inspect.signature(link.get_link).parameters
        old_signature = all(name != "ti_key" for name, p in parameters.items() if p.kind != p.VAR_KEYWORD)

        if old_signature:
            return link.get_link(self.unmap(None), ti.dag_run.logical_date)  # type: ignore[misc]
        return link.get_link(self.unmap(None), ti_key=ti.key)

    # NOTE(review): @cache on an instance method keys the cache on ``self`` and
    # keeps each operator alive for the interpreter's lifetime — confirm this is
    # intended for DAG-parsing-scoped objects.
    @cache
    def get_parse_time_mapped_ti_count(self) -> int:
        """Number of mapped task instances that can be created on DAG run creation.

        This only considers literal mapped arguments, and would return *None*
        when any non-literal values are used for mapping.

        :raise NotFullyPopulated: If non-literal mapped arguments are encountered.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped
        return group.get_parse_time_mapped_ti_count()

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """Number of mapped TaskInstances that can be created at run time.

        This considers both literal and non-literal mapped arguments, and the
        result is therefore available when all depended tasks have finished. The
        return value should be identical to ``parse_time_mapped_ti_count`` if
        all mapped arguments are literal.

        :raise NotFullyPopulated: If upstream tasks are not all complete yet.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped
        return group.get_mapped_ti_count(run_id, session=session)

    def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
        """Create the mapped task instances for mapped task.

        :raise NotMapped: If this task does not need expansion.
        :return: The newly created mapped task instances (if any) in ascending
            order by map index, and the maximum map index value.
        """
        from sqlalchemy import func, or_

        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator
        from airflow.models.taskinstance import TaskInstance
        from airflow.settings import task_instance_mutation_hook

        if not isinstance(self, (BaseOperator, MappedOperator)):
            raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}")

        try:
            total_length: int | None = self.get_mapped_ti_count(run_id, session=session)
        except NotFullyPopulated as e:
            # It's possible that the upstream tasks are not yet done, but we
            # don't have upstream of upstreams in partial DAGs (possible in the
            # mini-scheduler), so we ignore this exception.
            if not self.dag or not self.dag.partial:
                self.log.error(
                    "Cannot expand %r for run %s; missing upstream values: %s",
                    self,
                    run_id,
                    sorted(e.missing),
                )
            total_length = None

        state: TaskInstanceState | None = None
        # Look up the still-unfinished placeholder TI (map_index == -1), if any.
        unmapped_ti: TaskInstance | None = (
            session.query(TaskInstance)
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
            )
            .one_or_none()
        )

        all_expanded_tis: list[TaskInstance] = []

        if unmapped_ti:
            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length is None:
                # If the DAG is partial, it's likely that the upstream tasks
                # are not done yet, so the task can't fail yet.
                if not self.dag or not self.dag.partial:
                    unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
            elif total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark
                # the unmapped task instance as SKIPPED (if needed).
                self.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
            else:
                zero_index_ti_exists = (
                    session.query(TaskInstance)
                    .filter(
                        TaskInstance.dag_id == self.dag_id,
                        TaskInstance.task_id == self.task_id,
                        TaskInstance.run_id == run_id,
                        TaskInstance.map_index == 0,
                    )
                    .count()
                    > 0
                )
                if not zero_index_ti_exists:
                    # Otherwise convert this into the first mapped index, and create
                    # TaskInstance for other indexes.
                    unmapped_ti.map_index = 0
                    self.log.debug("Updated in place to become %s", unmapped_ti)
                    all_expanded_tis.append(unmapped_ti)
                    session.flush()
                else:
                    self.log.debug("Deleting the original task instance: %s", unmapped_ti)
                    session.delete(unmapped_ti)
                # Carry the placeholder's state over to the newly created TIs.
                state = unmapped_ti.state

        if total_length is None or total_length < 1:
            # Nothing to fixup.
            indexes_to_map: Iterable[int] = ()
        else:
            # Only create "missing" ones.
            # NOTE(review): scalar() is None when no TI rows exist for this
            # task/run; this code appears to rely on at least one TI (the
            # placeholder or a mapped index) being present — confirm.
            current_max_mapping = (
                session.query(func.max(TaskInstance.map_index))
                .filter(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                )
                .scalar()
            )
            indexes_to_map = range(current_max_mapping + 1, total_length)

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(self)  # session.merge() loses task information.
            all_expanded_tis.append(ti)

        # Coerce the None case to 0 -- these two are almost treated identically,
        # except the unmapped ti (if exists) is marked to different states.
        total_expanded_ti_count = total_length or 0

        # Any (old) task instances with inapplicable indexes (>= the total
        # number we need) are set to "REMOVED".
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_expanded_ti_count,
        ).update({TaskInstance.state: TaskInstanceState.REMOVED})

        session.flush()
        return all_expanded_tis, total_expanded_ti_count - 1

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """Template all attributes listed in *self.template_fields*.

        If the operator is mapped, this should return the unmapped, fully
        rendered, and map-expanded operator. The mapped operator should not be
        modified. However, *context* may be modified in-place to reference the
        unmapped operator for template rendering.

        If the operator is not mapped, this should modify the operator in-place.
        """
        raise NotImplementedError()

    @provide_session
    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
        *,
        session: Session = NEW_SESSION,
    ) -> None:
        """Render each named attribute of *parent* in place.

        :param parent: Object holding the attributes to template (usually the
            unmapped operator).
        :param template_fields: Attribute names to render.
        :raise AttributeError: If a declared template field does not exist on
            *parent*.
        """
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            if not value:
                # Skip empty/None values; nothing to render.
                continue
            try:
                rendered_content = self.render_template(
                    value,
                    context,
                    jinja_env,
                    seen_oids,
                )
            except Exception:
                # Log which task/field failed before re-raising, since the
                # Jinja traceback alone does not identify the field.
                self.log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    value,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)

    def render_template(
        self,
        content: Any,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
        seen_oids: set[int] | None = None,
    ) -> Any:
        """Render a templated string.

        If *content* is a collection holding multiple templated strings, strings
        in the collection will be templated recursively.

        :param content: Content to template. Only strings can be templated (may
            be inside a collection).
        :param context: Dict with values to apply on templated content
        :param jinja_env: Jinja environment. Can be provided to avoid
            re-creating Jinja environments during recursion.
        :param seen_oids: template fields already rendered (to avoid
            *RecursionError* on circular dependencies)
        :return: Templated content
        """
        # "content" is a bad name, but we're stuck to it being public API.
        value = content
        del content

        if seen_oids is not None:
            oids = seen_oids
        else:
            oids = set()

        if id(value) in oids:
            # Already rendered on this recursion path; break the cycle.
            return value

        if not jinja_env:
            jinja_env = self.get_template_env()

        if isinstance(value, str):
            if any(value.endswith(ext) for ext in self.template_ext):  # A filepath.
                template = jinja_env.get_template(value)
            else:
                template = jinja_env.from_string(value)
            dag = self.get_dag()
            if dag and dag.render_template_as_native_obj:
                return render_template_as_native(template, context)
            return render_template_to_string(template, context)

        if isinstance(value, ResolveMixin):
            return value.resolve(context)

        # Fast path for common built-in collections.
        if value.__class__ is tuple:
            return tuple(self.render_template(element, context, jinja_env, oids) for element in value)
        elif isinstance(value, tuple):  # Special case for named tuples.
            return value.__class__(*(self.render_template(el, context, jinja_env, oids) for el in value))
        elif isinstance(value, list):
            return [self.render_template(element, context, jinja_env, oids) for element in value]
        elif isinstance(value, dict):
            return {k: self.render_template(v, context, jinja_env, oids) for k, v in value.items()}
        elif isinstance(value, set):
            return {self.render_template(element, context, jinja_env, oids) for element in value}

        # More complex collections.
        self._render_nested_template_fields(value, context, jinja_env, oids)
        return value

    def _render_nested_template_fields(
        self,
        value: Any,
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """Render *value*'s own ``template_fields`` in place, if it declares any.

        Objects without a ``template_fields`` attribute are silently skipped;
        *seen_oids* guards against revisiting the same object.
        """
        if id(value) in seen_oids:
            return
        seen_oids.add(id(value))
        try:
            nested_template_fields = value.template_fields
        except AttributeError:
            # content has no inner template fields
            return
        self._do_render_template_fields(value, nested_template_fields, context, jinja_env, seen_oids)