1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import copy
21import os
22from collections.abc import MutableMapping
23from typing import TYPE_CHECKING, Any, NamedTuple, TypedDict, cast
24
25from typing_extensions import NotRequired
26
27if TYPE_CHECKING:
28 import jinja2
29 from pendulum import DateTime
30
31 from airflow.sdk.bases.operator import BaseOperator
32 from airflow.sdk.definitions.dag import DAG
33 from airflow.sdk.execution_time.context import InletEventsAccessors
34 from airflow.sdk.types import (
35 DagRunProtocol,
36 Operator,
37 OutletEventAccessorsProtocol,
38 RuntimeTaskInstanceProtocol,
39 )
40
41
class Context(TypedDict, total=False):
    """
    Jinja2 template context for task rendering.

    Declares the keys, and the value types expected for them, that may appear in
    the context mapping used when rendering task templates. ``KNOWN_CONTEXT_KEYS``
    in this module is built from these annotations, so adding a field here also
    registers it as a known key.
    """

    # NOTE(review): the class is declared ``total=False``, so every key is already
    # optional for type checkers; the ``NotRequired[...]`` wrappers below are
    # redundant in that sense and presumably mark keys that are often absent --
    # confirm before relying on the distinction.
    conn: Any
    dag: DAG
    dag_run: DagRunProtocol
    data_interval_end: NotRequired[DateTime | None]
    data_interval_start: NotRequired[DateTime | None]
    outlet_events: OutletEventAccessorsProtocol
    ds: str
    ds_nodash: str
    expanded_ti_count: NotRequired[int | None]
    exception: NotRequired[None | str | BaseException]
    inlets: list
    inlet_events: InletEventsAccessors
    logical_date: DateTime
    macros: Any
    map_index_template: NotRequired[str | None]
    outlets: list
    params: dict[str, Any]
    prev_data_interval_start_success: NotRequired[DateTime | None]
    prev_data_interval_end_success: NotRequired[DateTime | None]
    prev_start_date_success: NotRequired[DateTime | None]
    prev_end_date_success: NotRequired[DateTime | None]
    reason: NotRequired[str | None]
    run_id: str
    start_date: DateTime
    # TODO: Remove Operator from below once we have MappedOperator to the Task SDK
    # and once we can remove context related code from the Scheduler/models.TaskInstance
    task: BaseOperator | Operator
    task_reschedule_count: int
    task_instance: RuntimeTaskInstanceProtocol
    task_instance_key_str: str
    # `templates_dict` is only set in PythonOperator
    templates_dict: NotRequired[dict[str, Any] | None]
    test_mode: bool
    # ``ti`` has the same declared type as ``task_instance``; presumably it is an
    # alias for the same object -- confirm where the context is populated.
    ti: RuntimeTaskInstanceProtocol
    # Declared as Any here; the intended (unenforced) shape is
    # Mapping[str, Collection[AssetEvent | AssetEventPydantic]].
    triggering_asset_events: Any
    try_number: NotRequired[int | None]
    ts: str
    ts_nodash: str
    ts_nodash_with_tz: str
    var: Any
86
87
# Every key declared on ``Context``, derived from its annotations (iterating the
# annotations mapping yields its keys).
KNOWN_CONTEXT_KEYS: set[str] = set(Context.__annotations__)
89
90
def context_merge(context: Context, *args: Any, **kwargs: Any) -> Context:
    """
    Merge parameters into an existing context.

    Like ``dict.update()``, this takes the same parameters and updates
    ``context`` in-place.

    This is implemented as a free function because ``Context`` is a
    ``TypedDict``, which cannot define custom methods.

    :param context: The context to update in place. If ``None``, a new empty
        ``Context`` is created and updated instead.
    :param args: Positional arguments forwarded to ``dict.update()``.
    :param kwargs: Keyword arguments forwarded to ``dict.update()``.
    :returns: The updated context. This is the same object as ``context``
        unless ``context`` was ``None``.

    :meta private:
    """
    # Only replace a *missing* context. An empty context must still be updated
    # in place: rebinding it to a fresh dict would silently discard the merge,
    # because the caller keeps its reference to the original (empty) object.
    if context is None:
        context = Context()

    context.update(*args, **kwargs)
    return context
108
109
def get_current_context() -> Context:
    """
    Return the execution context dictionary of the task that is currently running.

    This lets task code reach the context without having to declare ``**context``
    in its own signature.

    **Old style:**

    .. code:: python

        def my_task(**context):
            ti = context["ti"]

    **New style:**

    .. code:: python

        from airflow.sdk import get_current_context


        def my_task():
            context = get_current_context()
            ti = context["ti"]

    A context is only available once an operator has started executing; the
    behaviour outside that window is whatever ``_get_current_context`` does.
    """
    from airflow.sdk.definitions._internal.contextmanager import (
        _get_current_context as _lookup_active_context,
    )

    active_context = _lookup_active_context()
    return active_context
140
141
class AirflowParsingContext(NamedTuple):
    """
    Context in which a Dag file is being parsed.

    If these values are not None, they contain the specific Dag and Task ID that Airflow is
    requesting to execute. You can use them to optimize dynamically generated Dag files.

    You can obtain the current values via :py:func:`.get_parsing_context`.
    """

    # Requested Dag ID, or None. get_parsing_context() fills it from the
    # ``_AIRFLOW_PARSING_CONTEXT_DAG_ID`` environment variable.
    dag_id: str | None
    # Requested Task ID, or None. get_parsing_context() fills it from the
    # ``_AIRFLOW_PARSING_CONTEXT_TASK_ID`` environment variable.
    task_id: str | None
154
155
# Names of the environment variables that ``get_parsing_context()`` reads to
# populate ``AirflowParsingContext.dag_id`` and ``AirflowParsingContext.task_id``.
_AIRFLOW_PARSING_CONTEXT_DAG_ID = "_AIRFLOW_PARSING_CONTEXT_DAG_ID"
_AIRFLOW_PARSING_CONTEXT_TASK_ID = "_AIRFLOW_PARSING_CONTEXT_TASK_ID"
158
159
def get_parsing_context() -> AirflowParsingContext:
    """
    Build an ``AirflowParsingContext`` from the current process environment.

    Each field is read from its ``_AIRFLOW_PARSING_CONTEXT_*`` environment
    variable and is ``None`` when that variable is not set.
    """
    environment = os.environ
    requested_dag_id = environment.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID)
    requested_task_id = environment.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID)
    return AirflowParsingContext(dag_id=requested_dag_id, task_id=requested_task_id)
166
167
# The 'template' argument is typed as Any because the jinja2.Template is too
# dynamic to be effectively type-checked.
def render_template(template: Any, context: MutableMapping[str, Any], *, native: bool) -> Any:
    """
    Render a Jinja2 template with given Airflow context.

    The default implementation of ``jinja2.Template.render()`` converts the
    input context into dict eagerly many times, which triggers deprecation
    messages in our custom context class. This takes the implementation apart
    and retains the context mapping without resolving it instead.

    Template globals are added to a shallow copy of ``context``, and only for
    keys the context does not already define, so the caller's mapping is not
    modified. Any exception raised while rendering is passed to the
    environment's ``handle_exception()`` so its traceback points to the template.

    :param template: A Jinja2 template to render.
    :param context: The Airflow task context to render the template with.
    :param native: If set to *True*, render the template into a native type. A
        Dag can enable this with ``render_template_as_native_obj=True``.
    :returns: The render result.
    """
    context = copy.copy(context)
    env = template.environment
    if template.globals:
        context.update((k, v) for k, v in template.globals.items() if k not in context)

    # Resolve the concatenation strategy before rendering, so that an import
    # failure is not mistaken for a template error below.
    if native:
        import jinja2.nativetypes

        concat = jinja2.nativetypes.native_concat
    else:
        concat = "".join

    try:
        # root_render_func returns a generator: template code only runs (and
        # can only fail) while ``concat`` consumes it, so that must happen
        # inside the ``try`` for the traceback rewrite to apply.
        nodes = template.root_render_func(env.context_class(env, context, template.name, template.blocks))
        return concat(nodes)
    except Exception:
        env.handle_exception()  # Rewrite traceback to point to the template.
198
199
def render_template_as_native(template: jinja2.Template, context: Context) -> Any:
    """Render ``template`` into a native Python object; typed wrapper over ``render_template(native=True)``."""
    mapping = cast("MutableMapping[str, Any]", context)
    return render_template(template, mapping, native=True)
203
204
def render_template_to_string(template: jinja2.Template, context: Context) -> str:
    """Render ``template`` into a string; typed wrapper over ``render_template(native=False)``."""
    mapping = cast("MutableMapping[str, Any]", context)
    rendered = render_template(template, mapping, native=False)
    return rendered