1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import copy
21import itertools
22import re
23import signal
24from collections.abc import Callable, Generator, Iterable, MutableMapping
25from functools import cache
26from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
27from urllib.parse import urljoin
28
29from lazy_object_proxy import Proxy
30
31from airflow.configuration import conf
32from airflow.exceptions import AirflowException
33from airflow.serialization.definitions.notset import is_arg_set
34
35if TYPE_CHECKING:
36 from datetime import datetime
37 from typing import TypeGuard
38
39 import jinja2
40
41 from airflow.models.taskinstance import TaskInstance
42 from airflow.sdk.definitions.context import Context
43
44 CT = TypeVar("CT", str, datetime)
45
# Keys accepted by validate_key(): one or more word characters (Unicode-aware
# ``\w``: letters, digits, underscore), dots or dashes.
# NOTE: `$` also matches just before a trailing newline, so `.match()` with this
# pattern accepts e.g. "key\n"; `.fullmatch()` does not.
KEY_REGEX = re.compile(r"^[\w.-]+$")
# Like KEY_REGEX but without dots. Not used in this part of the file;
# presumably validates group keys elsewhere -- TODO confirm.
GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
# Matches each run of uppercase letters except a run starting at position 0;
# convert_camel_to_snake() prefixes each match with "_" and lowercases the result.
CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")

# Generic type variables: T is used by chunks(), as_flattened_list() and
# partition(); S is not used in the code shown here.
T = TypeVar("T")
S = TypeVar("S")
52
53
def validate_key(k: str, max_length: int = 250):
    """
    Validate a value used as a key.

    A valid key is a string of at most ``max_length`` characters consisting only of
    word characters (Unicode letters, digits, underscore), dots and dashes.

    :param k: the key to validate
    :param max_length: maximum allowed length (inclusive)
    :raises TypeError: if ``k`` is not a string
    :raises AirflowException: if ``k`` is too long or contains disallowed characters
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
    # Keys of exactly max_length characters are accepted, despite the message wording.
    if len(k) > max_length:
        raise AirflowException(f"The key: {k} has to be less than {max_length} characters")
    # fullmatch rather than match: KEY_REGEX ends in `$`, which also matches just
    # before a trailing newline, so `.match()` would accept keys such as "my_key\n".
    if not KEY_REGEX.fullmatch(k):
        raise AirflowException(
            f"The key {k!r} has to be made of alphanumeric characters, dashes, "
            f"dots and underscores exclusively"
        )
65
66
def ask_yesno(question: str, default: bool | None = None) -> bool:
    """
    Print ``question`` and read answers from standard input until one is yes or no.

    Answers are case-insensitive and surrounding whitespace is ignored.

    :param question: text printed once before reading answers
    :param default: returned when the answer is empty; if ``None``, an empty
        answer is rejected and the user is asked again
    :return: ``True`` for ``y``/``yes``, ``False`` for ``n``/``no``
    """
    yes = {"yes", "y"}
    no = {"no", "n"}

    print(question)
    while True:
        # strip() so answers like "y " or " no" (stray spaces) are still accepted.
        choice = input().strip().lower()
        if choice == "" and default is not None:
            return default
        if choice in yes:
            return True
        if choice in no:
            return False
        print("Please respond with y/yes or n/no.")
82
83
def prompt_with_timeout(question: str, timeout: int, default: bool | None = None) -> bool:
    """
    Ask the user a yes/no question and raise if they don't answer within ``timeout`` seconds.

    The timeout is implemented with ``SIGALRM``. Whatever ``SIGALRM`` handler was
    installed before the call is restored afterwards, so the timeout handler does
    not leak into the rest of the process.

    :param question: text shown to the user
    :param timeout: number of seconds to wait for an answer
    :param default: answer used when the user submits an empty line (see ``ask_yesno``)
    :return: the user's answer
    :raises AirflowException: if ``timeout`` seconds elapse without an answer
    """

    def handler(signum, frame):
        raise AirflowException(f"Timeout {timeout}s reached")

    previous_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout)
    try:
        return ask_yesno(question, default)
    finally:
        # Cancel the pending alarm first so our handler cannot fire after we return,
        # then put back the handler that was active before this call.
        signal.alarm(0)
        # signal.signal() returns None when the previous handler was not installed
        # from Python; it cannot be re-installed in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
96
97
@overload
def is_container(obj: None | int | Iterable[int] | range) -> TypeGuard[Iterable[int]]: ...


@overload
def is_container(obj: None | CT | Iterable[CT]) -> TypeGuard[Iterable[CT]]: ...


def is_container(obj) -> bool:
    """
    Tell whether ``obj`` is iterable, treating strings as non-containers.

    A lazy ``Proxy`` always exposes ``__iter__`` (it forwards iteration to the
    lazily created object), so the wrapped object is inspected instead; reading
    ``__wrapped__`` evaluates the proxied object.
    """
    target = obj.__wrapped__ if isinstance(obj, Proxy) else obj
    if isinstance(target, str):
        return False
    return hasattr(target, "__iter__")
114
115
def chunks(items: list[T], chunk_size: int) -> Generator[list[T], None, None]:
    """
    Lazily split ``items`` into consecutive slices of ``chunk_size`` elements.

    The last slice holds whatever remains and may be shorter. Because this is a
    generator, the ``chunk_size`` check happens on first iteration, not at call time.

    :raises ValueError: if ``chunk_size`` is zero or negative
    """
    if chunk_size <= 0:
        raise ValueError("Chunk size must be a positive integer")
    total = len(items)
    start = 0
    while start < total:
        stop = start + chunk_size
        yield items[start:stop]
        start = stop
122
123
def as_flattened_list(iterable: Iterable[Iterable[T]]) -> list[T]:
    """
    Return a list with one level of nesting flattened, preserving order.

    >>> as_flattened_list((("blue", "red"), ("green", "yellow", "pink")))
    ['blue', 'red', 'green', 'yellow', 'pink']
    """
    return list(itertools.chain.from_iterable(iterable))
132
133
def parse_template_string(template_string: str) -> tuple[str, None] | tuple[None, jinja2.Template]:
    """
    Decide whether a template string is a plain ``str.format`` template or Jinja.

    A string containing ``{{`` is compiled into a ``jinja2.Template``; anything
    else is returned untouched for ``str.format``-style rendering.

    :param template_string: the raw template text
    :return: ``(template_string, None)`` for a plain template, or
        ``(None, jinja2.Template)`` for a Jinja template
    """
    if "{{" not in template_string:
        return template_string, None

    # Import lazily: jinja2 is only needed when the string actually uses Jinja syntax.
    import jinja2

    return None, jinja2.Template(template_string)
141
142
@cache
def log_filename_template_renderer() -> Callable[..., str]:
    """
    Return a renderer for the configured ``[logging] log_filename_template``.

    The result is cached by ``functools.cache``; since the function takes no
    arguments, the config value is read only on the first call.

    * If the template contains ``{{`` it is treated as Jinja and the bound
      ``jinja2.Template.render`` method is returned (called with keyword values).
    * Otherwise a function ``(ti, try_number=None)`` is returned that fills the
      ``str.format`` placeholders ``dag_id``, ``task_id``, ``logical_date``
      (``isoformat()``) and ``try_number``.
    """
    template = conf.get("logging", "log_filename_template")

    if "{{" not in template:

        def f_str_format(ti: TaskInstance, try_number: int | None = None):
            # NOTE(review): because of `or`, a try_number of 0 also falls back to ti.try_number.
            return template.format(
                dag_id=ti.dag_id,
                task_id=ti.task_id,
                logical_date=ti.logical_date.isoformat(),
                try_number=try_number or ti.try_number,
            )

        return f_str_format

    import jinja2

    return jinja2.Template(template).render
161
162
def _render_template_to_string(template: jinja2.Template, context: Context) -> str:
    """
    Render ``template`` with ``context`` and return the output as a string.

    Private helper used for log filename rendering. It always calls
    ``render_template`` with ``native=False``, so the result is the concatenated
    string output rather than a native Python object.
    """
    # render_template is annotated to take a MutableMapping; Context is passed through as-is.
    mapping = cast("MutableMapping[str, Any]", context)
    return render_template(template, mapping, native=False)
171
172
def render_log_filename(ti: TaskInstance, try_number, filename_template) -> str:
    """
    Render the log filename for one attempt of a task instance.

    A template containing ``{{`` is rendered as Jinja using the task instance's
    template context, with ``try_number`` overridden by the argument. Any other
    template is filled with ``str.format`` using ``dag_id``, ``task_id``,
    ``logical_date`` (``isoformat()``) and ``try_number``.

    :param ti: task instance
    :param try_number: try_number of the task
    :param filename_template: filename template, which can be jinja template or
        python string template
    :return: the rendered filename
    """
    plain_template, jinja_template = parse_template_string(filename_template)
    if not jinja_template:
        return plain_template.format(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            logical_date=ti.logical_date.isoformat(),
            try_number=try_number,
        )

    template_context = ti.get_template_context()
    template_context["try_number"] = try_number
    return _render_template_to_string(jinja_template, template_context)
194
195
def convert_camel_to_snake(camel_str: str) -> str:
    """
    Convert a CamelCase string to snake_case.

    An underscore is inserted before every run of uppercase letters (except one at
    the very start) and the whole result is lowercased.
    """
    underscored = CAMELCASE_TO_SNAKE_CASE_REGEX.sub(r"_\1", camel_str)
    return underscored.lower()
199
200
def merge_dicts(dict1: dict, dict2: dict) -> dict:
    """
    Merge two dicts recursively, returning a new dict (neither input is mutated).

    When both sides hold a dict for the same key, those dicts are merged
    recursively. Otherwise the value from ``dict2`` overwrites the one in
    ``dict1``; this includes a dict replacing a non-dict value, and lists, which
    are not concatenated.

    Values that are not merged recursively are not copied, so the result may share
    nested objects with the inputs.
    """
    merged = dict1.copy()
    for key, value in dict2.items():
        existing = merged.get(key)
        # Recurse only when both sides are dicts; recursing into a non-dict
        # (None, an int, a list, ...) would crash instead of overwriting it.
        if isinstance(value, dict) and isinstance(existing, dict):
            merged[key] = merge_dicts(existing, value)
        else:
            merged[key] = value
    return merged
214
215
def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
    """
    Split ``iterable`` into two lazy iterators according to ``pred``.

    :return: ``(falses, trues)``: the first iterator yields the items for which
        ``pred`` is falsy, the second those for which it is truthy. Each iterator
        calls ``pred`` on the items it consumes.
    """
    first_copy, second_copy = itertools.tee(iterable)
    falses = itertools.filterfalse(pred, first_copy)
    trues = filter(pred, second_copy)
    return falses, trues
220
221
def build_airflow_dagrun_url(dag_id: str, run_id: str) -> str:
    """
    Build the UI URL of a DAG run from the ``[api] base_url`` setting.

    If ``base_url`` is not configured, ``"/"`` is used, giving a root-relative URL.

    For example:
    http://localhost:8080/dags/hi/runs/manual__2025-02-23T18:27:39.051358+00:00_RZa1at4Q
    """
    base_url = conf.get("api", "base_url", fallback="/")
    # Force exactly one trailing slash so urljoin appends the relative path below
    # instead of replacing the last segment of a base URL such as ".../airflow".
    normalized_base = f"{base_url.rstrip('/')}/"
    relative_path = f"dags/{dag_id}/runs/{run_id}"
    return urljoin(normalized_base, relative_path)
231
232
# The 'template' argument is typed as Any because the jinja2.Template is too
# dynamic to be effectively type-checked.
def render_template(template: Any, context: MutableMapping[str, Any], *, native: bool) -> Any:
    """
    Render a Jinja2 template with given Airflow context.

    The default implementation of ``jinja2.Template.render()`` converts the
    input context into dict eagerly many times, which triggers deprecation
    messages in our custom context class. This takes the implementation apart
    and retains the context mapping without resolving it instead.

    :param template: A Jinja2 template to render.
    :param context: The Airflow task context to render the template with. It is
        shallow-copied, so the caller's mapping is not modified; template globals
        are added to the copy only for keys the context does not already define.
    :param native: If set to *True*, render the template into a native type. A
        DAG can enable this with ``render_template_as_native_obj=True``.
    :returns: The render result.
    :raises Exception: Errors raised while rendering are passed through
        ``template.environment.handle_exception()``, which re-raises them with the
        traceback rewritten to point to the template.
    """
    context = copy.copy(context)
    env = template.environment
    if template.globals:
        context.update((k, v) for k, v in template.globals.items() if k not in context)
    if native:
        import jinja2.nativetypes

        concat = jinja2.nativetypes.native_concat
    else:
        concat = "".join
    try:
        render_context = env.context_class(env, context, template.name, template.blocks)
        # root_render_func returns a lazy generator: the template body only runs
        # while its output is consumed. Concatenating inside the ``try`` (as
        # jinja2's own Template.render does) makes rendering errors actually reach
        # handle_exception() instead of escaping with an unrewritten traceback.
        return concat(template.root_render_func(render_context))
    except Exception:
        env.handle_exception()  # Rewrite traceback to point to the template.
        raise  # Safeguard only: handle_exception() is expected to always raise.
263
264
def exactly_one(*args) -> bool:
    """
    Return True if exactly one of args is "truthy", and False otherwise.

    Calling it with no arguments returns False (zero truthy values).

    If the first argument is an iterable (other than a string), ValueError is
    raised to force the caller to unpack it with ``*``. Only the first argument
    is checked.

    :raises ValueError: if the first argument is a non-string iterable
    """
    if not args:
        # Nothing is truthy, so "exactly one" is False; avoids IndexError on args[0].
        return False
    if is_container(args[0]):
        raise ValueError(
            "Not supported for iterable args. Use `*` to unpack your iterable in the function call."
        )
    return sum(map(bool, args)) == 1
276
277
def at_most_one(*args) -> bool:
    """
    Return True if at most one of args is "truthy", and False otherwise.

    NOTSET is treated the same as None, i.e. it never counts as set.

    Unlike ``exactly_one``, no check is made for iterable arguments: a list is
    simply counted as one value according to its truthiness.
    """
    return sum(is_arg_set(a) and bool(a) for a in args) in (0, 1)
287
288
def prune_dict(val: Any, mode="strict"):
    """
    Return a copy of ``val`` with "empty" elements removed, recursing into nested dicts and lists.

    What counts as "empty" is controlled by ``mode``:

    * ``"strict"``: only ``None`` elements are removed.
    * ``"truthy"``: any element ``x`` with ``bool(x) is False`` is removed
      (``None``, ``0``, ``""``, ``[]``, ``{}``, ...).

    Nested dicts and lists are pruned first. A nested container that is empty
    after pruning is removed only if it counts as empty under ``mode`` (so in
    ``"truthy"`` mode, but not in ``"strict"`` mode). Values that are neither a
    dict nor a list are returned unchanged.

    :param val: a dict, a list, or any other value
    :param mode: ``"strict"`` or ``"truthy"``
    :raises ValueError: if ``mode`` is not one of the allowed values
    """
    if mode == "strict":

        def is_empty(x) -> bool:
            return x is None

    elif mode == "truthy":

        def is_empty(x) -> bool:
            return bool(x) is False

    else:
        # Validate up front so a bad mode fails even when there is nothing to prune.
        raise ValueError("allowable values for `mode` include 'truthy' and 'strict'")

    skip = object()  # sentinel meaning "drop this element"

    def prune_element(element):
        """Return the pruned ``element``, or ``skip`` if it should be dropped."""
        if is_empty(element):
            return skip
        if isinstance(element, (list, dict)):
            element = prune_dict(element, mode=mode)
            if is_empty(element):
                return skip
        return element

    if isinstance(val, dict):
        pruned_items = ((key, prune_element(value)) for key, value in val.items())
        return {key: value for key, value in pruned_items if value is not skip}
    if isinstance(val, list):
        pruned_values = (prune_element(value) for value in val)
        return [value for value in pruned_values if value is not skip]
    return val
330
331
# Functions that moved out of this module, mapped to the module now providing them.
# Each is re-exported here with a DeprecationWarning by the module __getattr__ below.
_MOVED_ATTRIBUTES: dict[str, str] = {
    "render_template_as_native": "airflow.sdk.definitions.context",
    "prevent_duplicates": "airflow.sdk.definitions.mappedoperator",
    "render_template_to_string": "airflow.sdk.definitions.context",
}


def __getattr__(name: str):
    """
    Provide backward compatibility for moved functions in this module.

    Accessing a moved name imports the attribute of the *same name* from its new
    module and emits a ``DeprecationWarning`` pointing at the new location.

    :raises AttributeError: for any name that is not a known moved function
    """
    new_module = _MOVED_ATTRIBUTES.get(name)
    if new_module is None:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    import importlib
    import warnings

    attribute = getattr(importlib.import_module(new_module), name)
    # warn() is called directly here so stacklevel=2 points at the attribute access.
    warnings.warn(
        f"airflow.utils.helpers.{name} is deprecated. Use {new_module}.{name} instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return attribute