1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import copy
21import os
22from collections.abc import MutableMapping
23from typing import TYPE_CHECKING, Any, NamedTuple, TypedDict, cast
24
25if TYPE_CHECKING:
26 import jinja2
27 from pendulum import DateTime
28
29 from airflow.sdk.bases.operator import BaseOperator
30 from airflow.sdk.definitions.dag import DAG
31 from airflow.sdk.execution_time.context import InletEventsAccessors
32 from airflow.sdk.types import (
33 DagRunProtocol,
34 Operator,
35 OutletEventAccessorsProtocol,
36 RuntimeTaskInstanceProtocol,
37 )
38
39
class Context(TypedDict, total=False):
    """
    Jinja2 template context for task rendering.

    Declared with ``total=False``, so every key is optional as far as static
    type checkers are concerned. Most of the value types are imported only
    under ``TYPE_CHECKING`` and are therefore not available at runtime.
    """

    conn: Any
    dag: DAG
    dag_run: DagRunProtocol
    data_interval_end: DateTime | None
    data_interval_start: DateTime | None
    outlet_events: OutletEventAccessorsProtocol
    ds: str
    ds_nodash: str
    expanded_ti_count: int | None
    exception: None | str | BaseException
    inlets: list
    inlet_events: InletEventsAccessors
    logical_date: DateTime
    macros: Any
    map_index_template: str | None
    outlets: list
    params: dict[str, Any]
    prev_data_interval_start_success: DateTime | None
    prev_data_interval_end_success: DateTime | None
    prev_start_date_success: DateTime | None
    prev_end_date_success: DateTime | None
    reason: str | None
    run_id: str
    start_date: DateTime
    # TODO: Remove Operator from below once we have MappedOperator to the Task SDK
    #   and once we can remove context related code from the Scheduler/models.TaskInstance
    task: BaseOperator | Operator
    task_reschedule_count: int
    task_instance: RuntimeTaskInstanceProtocol
    task_instance_key_str: str
    # `templates_dict` is only set in PythonOperator
    templates_dict: dict[str, Any] | None
    test_mode: bool
    # Same type as ``task_instance``; presumably a shorthand alias for it -- TODO confirm.
    ti: RuntimeTaskInstanceProtocol
    # Typed as ``Any`` for now; the more precise type sketched for this key is:
    # triggering_asset_events: Mapping[str, Collection[AssetEvent | AssetEventPydantic]]
    triggering_asset_events: Any
    try_number: int | None
    ts: str
    ts_nodash: str
    ts_nodash_with_tz: str
    var: Any
84
85
# Every key name declared on ``Context``, collected once at import time.
KNOWN_CONTEXT_KEYS: set[str] = set(Context.__annotations__)
87
88
def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
    """
    Merge parameters into an existing context.

    Like ``dict.update()``, this takes the same parameters and updates
    ``context`` in-place.

    This is implemented as a free function because ``Context`` is a
    ``TypedDict``, which cannot define custom methods.

    :param context: The context to update in-place. An empty context is
        updated like any other mapping. ``None`` is tolerated and treated as a
        no-op, since there is no object to update in-place.
    :param args: Positional arguments forwarded to ``update()``.
    :param kwargs: Keyword arguments forwarded to ``update()``.

    :meta private:
    """
    # Guard against ``None`` only. A truthiness test would also match an empty
    # context, and rebinding the local name to a fresh mapping would apply the
    # update to a throwaway object instead of the one the caller passed in.
    if context is None:
        return

    context.update(*args, **kwargs)
106
107
def get_current_context() -> Context:
    """
    Return the execution context of the currently running task.

    Use this to reach the context from inside a task callable without having
    to declare ``**context`` (or individual context keys) in its signature.

    **Accepting the context through the signature:**

    .. code:: python

        def my_task(**context):
            ti = context["ti"]

    **Fetching the context on demand:**

    .. code:: python

        from airflow.sdk import get_current_context


        def my_task():
            context = get_current_context()
            ti = context["ti"]

    A context is only available when this is called after an operator has
    started executing.
    """
    # Imported lazily, inside the function, rather than at module import time.
    from airflow.sdk.definitions._internal.contextmanager import _get_current_context

    current = _get_current_context()
    return current
138
139
class AirflowParsingContext(NamedTuple):
    """
    Context of parsing for the Dag.

    If these values are not None, they will contain the specific Dag and Task ID that Airflow is requesting to
    execute. You can use these for optimizing dynamically generated Dag files.

    You can obtain the current values via :py:func:`.get_parsing_context`.
    """

    # ID of the Dag that Airflow is requesting to execute, or None if not set.
    dag_id: str | None
    # ID of the task that Airflow is requesting to execute, or None if not set.
    task_id: str | None
152
153
# Names of the environment variables that ``get_parsing_context`` reads to
# populate ``AirflowParsingContext.dag_id`` and ``AirflowParsingContext.task_id``.
_AIRFLOW_PARSING_CONTEXT_DAG_ID = "_AIRFLOW_PARSING_CONTEXT_DAG_ID"
_AIRFLOW_PARSING_CONTEXT_TASK_ID = "_AIRFLOW_PARSING_CONTEXT_TASK_ID"
156
157
def get_parsing_context() -> AirflowParsingContext:
    """
    Return the current (Dag) parsing context info.

    Both fields are read from the process environment; a field is ``None``
    when its environment variable is not set.
    """
    environ = os.environ
    requested_dag_id = environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID)
    requested_task_id = environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID)
    return AirflowParsingContext(dag_id=requested_dag_id, task_id=requested_task_id)
164
165
166# The 'template' argument is typed as Any because the jinja2.Template is too
167# dynamic to be effectively type-checked.
def render_template(template: Any, context: MutableMapping[str, Any], *, native: bool) -> Any:
    """
    Render a Jinja2 template against an Airflow context mapping.

    ``jinja2.Template.render()`` eagerly converts the supplied context into a
    ``dict`` several times, which triggers deprecation messages in our custom
    context class. This function performs the same rendering steps by hand so
    the context mapping is handed to Jinja2 without being resolved up front.

    The caller's mapping is never modified: rendering works on a shallow copy,
    into which template globals are added only for keys the context lacks.

    :param template: A Jinja2 template to render.
    :param context: The Airflow task context to render the template with.
    :param native: If set to *True*, render the template into a native type. A
        Dag can enable this with ``render_template_as_native_obj=True``.
    :returns: The render result.
    """
    render_context = copy.copy(context)
    environment = template.environment

    template_globals = template.globals
    if template_globals:
        # Context values take precedence over template globals of the same name.
        missing_globals = {
            name: value for name, value in template_globals.items() if name not in render_context
        }
        render_context.update(missing_globals)

    try:
        root_render = template.root_render_func
        jinja_context = environment.context_class(
            environment, render_context, template.name, template.blocks
        )
        nodes = root_render(jinja_context)
    except Exception:
        # Rewrite traceback to point to the template.
        environment.handle_exception()

    if not native:
        return "".join(nodes)

    import jinja2.nativetypes

    return jinja2.nativetypes.native_concat(nodes)
196
197
def render_template_as_native(template: jinja2.Template, context: Context) -> Any:
    """Render ``template`` via ``render_template(native=True)``, with a typed signature."""
    # The cast only informs the type checker; the context object is passed through as-is.
    mapping = cast("MutableMapping[str, Any]", context)
    return render_template(template, mapping, native=True)
201
202
def render_template_to_string(template: jinja2.Template, context: Context) -> str:
    """Render ``template`` via ``render_template(native=False)``, with a typed signature."""
    # The cast only informs the type checker; the context object is passed through as-is.
    mapping = cast("MutableMapping[str, Any]", context)
    return render_template(template, mapping, native=False)