1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import copy
21import itertools
22import re
23import signal
24from collections.abc import Callable, Generator, Iterable, MutableMapping
25from functools import cache
26from typing import TYPE_CHECKING, Any, TypeVar, overload
27from urllib.parse import urljoin
28
29from lazy_object_proxy import Proxy
30
31from airflow.configuration import conf
32from airflow.exceptions import AirflowException
33from airflow.serialization.definitions.notset import is_arg_set
34
35if TYPE_CHECKING:
36 from datetime import datetime
37 from typing import TypeGuard
38
39 import jinja2
40
41 from airflow.models.taskinstance import TaskInstance
42
43 CT = TypeVar("CT", str, datetime)
44
# Generic type variables shared by helpers in this module.
T = TypeVar("T")
S = TypeVar("S")

# Valid keys (see validate_key): word characters, dots and dashes only.
KEY_REGEX = re.compile(r"^[\w.-]+$")
# Like KEY_REGEX, but dots are not allowed.
GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
# Each run of uppercase letters that is not at the very start of the string;
# convert_camel_to_snake prefixes every such run with "_".
CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")
51
52
def validate_key(k: str, max_length: int = 250):
    """
    Check that ``k`` is usable as a key.

    A valid key is a ``str`` of at most ``max_length`` characters made only of
    word characters (letters, digits, underscore), dots and dashes, as defined
    by ``KEY_REGEX``.

    :param k: the candidate key.
    :param max_length: maximum allowed length (inclusive).
    :raises TypeError: if ``k`` is not a string.
    :raises AirflowException: if ``k`` is too long or contains disallowed characters.
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")

    error_message = None
    if len(k) > max_length:
        error_message = f"The key: {k} has to be less than {max_length} characters"
    elif KEY_REGEX.match(k) is None:
        error_message = (
            f"The key {k!r} has to be made of alphanumeric characters, dashes, "
            f"dots and underscores exclusively"
        )

    if error_message is not None:
        raise AirflowException(error_message)
64
65
def ask_yesno(question: str, default: bool | None = None) -> bool:
    """
    Prompt on stdin until the user answers yes or no.

    Accepted answers are ``y``/``yes`` and ``n``/``no``, case-insensitively.
    An empty answer returns ``default`` when one is given. Otherwise any
    unrecognised answer prints a hint and the prompt repeats.

    :param question: text printed once before answers are read.
    :param default: value returned for an empty answer, or ``None`` to require
        an explicit answer.
    :return: ``True`` for yes, ``False`` for no.
    """
    answers = {"yes": True, "y": True, "no": False, "n": False}

    print(question)
    while True:
        reply = input().lower()
        if not reply and default is not None:
            return default
        if reply in answers:
            return answers[reply]
        print("Please respond with y/yes or n/no.")
81
82
def prompt_with_timeout(question: str, timeout: int, default: bool | None = None) -> bool:
    """
    Ask the user a yes/no question, failing if no answer arrives within ``timeout`` seconds.

    The timeout is implemented with ``SIGALRM``. That signal exists only on Unix,
    and ``signal.signal`` may only be called from the main thread. The previously
    installed ``SIGALRM`` handler is restored before returning.

    :param question: text shown to the user.
    :param timeout: seconds to wait before giving up.
    :param default: passed to ``ask_yesno`` and returned for an empty answer.
    :return: the user's answer as a bool.
    :raises AirflowException: if ``timeout`` seconds pass without an answer.
    """

    def handler(signum, frame):
        raise AirflowException(f"Timeout {timeout}s reached")

    previous_handler = signal.signal(signal.SIGALRM, handler)
    try:
        signal.alarm(timeout)
        return ask_yesno(question, default)
    finally:
        # Cancel any pending alarm first so it cannot fire into the restored handler.
        signal.alarm(0)
        # signal.signal() returns None when the previous handler was not installed
        # from Python; such a handler cannot be reinstated, so leave ours in place.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
95
96
@overload
def is_container(obj: None | int | Iterable[int] | range) -> TypeGuard[Iterable[int]]: ...


@overload
def is_container(obj: None | CT | Iterable[CT]) -> TypeGuard[Iterable[CT]]: ...


def is_container(obj) -> bool:
    """
    Tell whether ``obj`` is an iterable other than a string.

    Anything with an ``__iter__`` attribute counts, except ``str``. A
    ``lazy_object_proxy.Proxy`` is judged by the object it wraps rather than by
    the proxy itself, because the proxy implements ``__iter__`` to forward to the
    lazily initialised target and would otherwise always look iterable.
    """
    # Look at the proxied object (via __wrapped__) instead of the Proxy wrapper.
    target = obj.__wrapped__ if isinstance(obj, Proxy) else obj
    if isinstance(target, str):
        return False
    return hasattr(target, "__iter__")
113
114
def chunks(items: list[T], chunk_size: int) -> Generator[list[T], None, None]:
    """
    Lazily yield consecutive slices of ``items``, each ``chunk_size`` long.

    The last slice holds whatever remains and may be shorter. Because this is a
    generator, the ``chunk_size`` check runs on the first ``next()``, not at call time.

    :raises ValueError: if ``chunk_size`` is zero or negative.
    """
    if chunk_size <= 0:
        raise ValueError("Chunk size must be a positive integer")
    starts = range(0, len(items), chunk_size)
    yield from (items[start : start + chunk_size] for start in starts)
121
122
def as_flattened_list(iterable: Iterable[Iterable[T]]) -> list[T]:
    """
    Flatten exactly one level of nesting into a new list.

    >>> as_flattened_list((("blue", "red"), ("green", "yellow", "pink")))
    ['blue', 'red', 'green', 'yellow', 'pink']
    """
    return list(itertools.chain.from_iterable(iterable))
131
132
def parse_template_string(template_string: str) -> tuple[str, None] | tuple[None, jinja2.Template]:
    """
    Classify a template string as plain text or Jinja.

    A string containing ``{{`` is compiled into a ``jinja2.Template``. Anything
    else, including strings that use only ``{%`` tags, is returned untouched.

    :return: ``(template_string, None)`` for plain strings, or
        ``(None, jinja2.Template(template_string))`` for Jinja strings.
    """
    import jinja2

    is_jinja = "{{" in template_string
    if not is_jinja:
        return template_string, None
    return None, jinja2.Template(template_string)
140
141
@cache
def log_filename_template_renderer() -> Callable[..., str]:
    """
    Return a callable that renders the configured task log filename.

    The ``[logging] log_filename_template`` option is read on the first call only,
    because the result is memoised with ``functools.cache``.

    * If the template contains ``{{``, it is treated as Jinja and the bound
      ``jinja2.Template.render`` is returned.
    * Otherwise it is treated as a ``str.format`` template filled with
      ``dag_id``, ``task_id``, ``logical_date`` (ISO format) and ``try_number``.
    """
    template = conf.get("logging", "log_filename_template")

    if "{{" not in template:

        def render_with_str_format(ti: TaskInstance, try_number: int | None = None):
            # Any falsy try_number (None, and also 0) falls back to ti.try_number.
            # NOTE(review): assumes ti.logical_date is not None here -- confirm for
            # runs that have no logical date.
            return template.format(
                dag_id=ti.dag_id,
                task_id=ti.task_id,
                logical_date=ti.logical_date.isoformat(),
                try_number=try_number or ti.try_number,
            )

        return render_with_str_format

    import jinja2

    return jinja2.Template(template).render
160
161
def convert_camel_to_snake(camel_str: str) -> str:
    """
    Turn a CamelCase identifier into snake_case.

    An underscore is inserted before every run of uppercase letters except at
    the very start, then the whole string is lowercased, e.g.
    ``"CamelCaseString"`` -> ``"camel_case_string"``. Runs of capitals stay
    together, so ``"HTTPServer"`` becomes ``"h_ttpserver"``.
    """
    underscored = CAMELCASE_TO_SNAKE_CASE_REGEX.sub(lambda match: "_" + match.group(1), camel_str)
    return underscored.lower()
165
166
def merge_dicts(dict1: dict, dict2: dict) -> dict:
    """
    Merge two dicts recursively, returning a new dict (neither input is mutated).

    When a key maps to a dict in *both* inputs, the two values are merged
    recursively. In every other case the value from ``dict2`` wins. This includes
    the case where ``dict2`` holds a dict but ``dict1`` holds a non-dict
    (e.g. ``None``) for that key. Lists are not concatenated.

    Nested values are not deep-copied: values taken from ``dict2``, and nested
    dicts from ``dict1`` that did not need merging, are shared with the inputs.

    :param dict1: the base dict.
    :param dict2: the dict whose values take precedence.
    :return: the merged dict.
    """
    merged = dict1.copy()
    for key, value in dict2.items():
        existing = merged.get(key)
        # Only recurse when both sides are dicts; recursing into a non-dict
        # (None, int, list, ...) would crash instead of overwriting it.
        if isinstance(value, dict) and isinstance(existing, dict):
            merged[key] = merge_dicts(existing, value)
        else:
            merged[key] = value
    return merged
180
181
def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
    """
    Lazily split ``iterable`` into entries failing and entries passing ``pred``.

    The source is read once and shared between both results via
    ``itertools.tee``. ``pred`` is evaluated separately for each result.

    :return: ``(false_entries, true_entries)``. Note the order: failures first.
    """
    for_false, for_true = itertools.tee(iterable, 2)
    rejected = itertools.filterfalse(pred, for_false)
    accepted = filter(pred, for_true)
    return rejected, accepted
186
187
def build_airflow_dagrun_url(dag_id: str, run_id: str) -> str:
    """
    Build the UI URL of a dag run from ``[api] base_url`` (default ``"/"``).

    The base URL is normalised to exactly one trailing slash before joining, so
    any path prefix in it (e.g. ``http://host/airflow``) is kept. For example:
    http://localhost:8080/dags/hi/runs/manual__2025-02-23T18:27:39.051358+00:00_RZa1at4Q
    """
    base = conf.get("api", "base_url", fallback="/")
    normalized_base = f"{base.rstrip('/')}/"
    # NOTE(review): dag_id and run_id are inserted verbatim (not URL-encoded).
    relative_path = f"dags/{dag_id}/runs/{run_id}"
    return urljoin(normalized_base, relative_path)
197
198
# The 'template' argument is typed as Any because the jinja2.Template is too
# dynamic to be effectively type-checked.
def render_template(template: Any, context: MutableMapping[str, Any], *, native: bool) -> Any:
    """
    Render a Jinja2 template with given Airflow context.

    The default implementation of ``jinja2.Template.render()`` converts the
    input context into dict eagerly many times, which triggers deprecation
    messages in our custom context class. This function takes that
    implementation apart and keeps the context mapping unresolved instead.

    :param template: A Jinja2 template to render.
    :param context: The Airflow task context to render the template with. It is
        shallow-copied, so the caller's mapping is not modified.
    :param native: If set to *True*, render the template into a native type. A
        DAG can enable this with ``render_template_as_native_obj=True``.
    :returns: The render result.
    """
    context = copy.copy(context)
    env = template.environment
    if template.globals:
        # Template globals only fill in keys the context does not already provide.
        context.update((k, v) for k, v in template.globals.items() if k not in context)
    if native:
        import jinja2.nativetypes
    try:
        # root_render_func returns a lazy generator: template code only runs while
        # the nodes are consumed. The concatenation must therefore happen inside
        # this try (as in jinja2.Template.render) for rendering errors to reach
        # handle_exception.
        nodes = template.root_render_func(env.context_class(env, context, template.name, template.blocks))
        if native:
            return jinja2.nativetypes.native_concat(nodes)
        return "".join(nodes)
    except Exception:
        env.handle_exception()  # Re-raise with the traceback rewritten to point to the template.
229
230
def exactly_one(*args) -> bool:
    """
    Return True if exactly one of args is "truthy", and False otherwise.

    Calling with no arguments returns False, because zero values are truthy.

    If the first argument is itself an iterable (as judged by ``is_container``),
    ValueError is raised to force the caller to unpack it with ``*``.

    :raises ValueError: if ``args[0]`` is a non-string iterable.
    """
    # Guard on ``args`` so that ``exactly_one(*empty)`` returns False instead of
    # raising IndexError.
    if args and is_container(args[0]):
        raise ValueError(
            "Not supported for iterable args. Use `*` to unpack your iterable in the function call."
        )
    return sum(map(bool, args)) == 1
242
243
def at_most_one(*args) -> bool:
    """
    Return True if at most one of args is "truthy", and False otherwise.

    NOTSET is treated the same as None: it never counts as set.

    Iterables are not rejected (unlike ``exactly_one``). Each argument, including
    a list or tuple, simply counts by its own truthiness. Calling with no
    arguments returns True.
    """
    set_count = sum(1 for arg in args if is_arg_set(arg) and arg)
    return set_count <= 1
253
254
def prune_dict(val: Any, mode="strict"):
    """
    Return a pruned copy of a dict or list, with "empty" elements removed recursively.

    Nested dicts and lists are pruned first. The pruned result then goes through
    the same emptiness test, so in ``truthy`` mode a container that becomes
    empty is dropped. A ``val`` that is neither a dict nor a list is returned
    unchanged.

    ``mode`` decides what counts as empty:

    * ``"strict"`` (default): only ``None``.
    * ``"truthy"``: anything for which ``bool(x)`` is False.

    :raises ValueError: if ``mode`` is not one of the above. The check is lazy: it
        only fires once an element actually has to be tested, so an empty
        container or a non-container ``val`` is returned without validating ``mode``.
    """

    def is_empty(x):
        if mode == "strict":
            return x is None
        if mode == "truthy":
            return not x
        raise ValueError("allowable values for `mode` include 'truthy' and 'strict'")

    drop = object()  # marker meaning "leave this element out"

    def prune_child(child):
        # Return the (possibly pruned) child, or ``drop`` if it should be removed.
        if is_empty(child):
            return drop
        if isinstance(child, (list, dict)):
            child = prune_dict(child, mode=mode)
            if is_empty(child):
                return drop
        return child

    if isinstance(val, dict):
        pruned = {}
        for key, child in val.items():
            kept = prune_child(child)
            if kept is not drop:
                pruned[key] = kept
        return pruned
    if isinstance(val, list):
        return [kept for kept in map(prune_child, val) if kept is not drop]
    return val
296
297
# Maps each moved name to the module that now defines it.
__deprecated_imports = {
    "render_template_as_native": "airflow.sdk.definitions.context",
    "render_template_to_string": "airflow.sdk.definitions.context",
    "prevent_duplicates": "airflow.sdk.definitions.mappedoperator",
}


def __getattr__(name: str):
    """
    Provide backward compatibility for moved functions in this module.

    Names listed in ``__deprecated_imports`` are resolved lazily from their new
    module and emit a ``DeprecationWarning``. Any other name raises ``AttributeError``.
    """
    try:
        modpath = __deprecated_imports[name]
    except KeyError:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'") from None

    import importlib
    import warnings

    warnings.warn(
        f"{__name__}.{name} is deprecated. Use {modpath}.{name} instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # __import__("a.b.c") returns the top-level package ``a``, not ``a.b.c``;
    # import_module returns the submodule that actually defines ``name``.
    return getattr(importlib.import_module(modpath), name)