Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/utils/helpers.py: 31%
184 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import copy
21import re
22import signal
23import warnings
24from datetime import datetime
25from functools import reduce
26from itertools import filterfalse, tee
27from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Mapping, MutableMapping, TypeVar, cast
29from airflow.configuration import conf
30from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
31from airflow.utils.context import Context
32from airflow.utils.module_loading import import_string
33from airflow.utils.types import NOTSET
35if TYPE_CHECKING:
36 import jinja2
38 from airflow.models.taskinstance import TaskInstance
# Accepts keys built only from word characters, dots and dashes.
KEY_REGEX = re.compile(r"^[\w.-]+$")

# Same as KEY_REGEX but without the dot: word characters and dashes only.
GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")

# Captures each run of upper-case letters that begins anywhere except position 0.
CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")

# Generic type variables for the annotated helpers in this module.
T = TypeVar("T")
S = TypeVar("S")
def validate_key(k: str, max_length: int = 250):
    """Check that ``k`` is acceptable as a key.

    :param k: candidate key.
    :param max_length: largest number of characters ``k`` may contain.
    :raises TypeError: if ``k`` is not a string.
    :raises AirflowException: if ``k`` is longer than ``max_length`` or does not
        match ``KEY_REGEX``.
    """
    if not isinstance(k, str):
        type_message = f"The key has to be a string and is {type(k)}:{k}"
        raise TypeError(type_message)

    if len(k) > max_length:
        length_message = f"The key has to be less than {max_length} characters"
        raise AirflowException(length_message)

    if KEY_REGEX.match(k):
        return

    raise AirflowException(
        f"The key {k!r} has to be made of alphanumeric characters, dashes, "
        f"dots and underscores exclusively"
    )
def validate_group_key(k: str, max_length: int = 200):
    """Check that ``k`` is acceptable as a group key.

    :param k: candidate group key.
    :param max_length: largest number of characters ``k`` may contain.
    :raises TypeError: if ``k`` is not a string.
    :raises AirflowException: if ``k`` is longer than ``max_length`` or does not
        match ``GROUP_KEY_REGEX``.
    """
    if not isinstance(k, str):
        type_message = f"The key has to be a string and is {type(k)}:{k}"
        raise TypeError(type_message)

    if len(k) > max_length:
        length_message = f"The key has to be less than {max_length} characters"
        raise AirflowException(length_message)

    if GROUP_KEY_REGEX.match(k):
        return

    raise AirflowException(
        f"The key {k!r} has to be made of alphanumeric characters, dashes and underscores exclusively"
    )
def alchemy_to_dict(obj: Any) -> dict | None:
    """Build a plain dictionary from the columns of a SQLAlchemy model instance.

    :param obj: object exposing ``__table__.columns``; each column's ``name``
        is read from ``obj`` as an attribute.
    :returns: mapping of column name to attribute value, with ``datetime``
        values replaced by their ISO-8601 string; ``None`` when ``obj`` is falsy.
    """
    if not obj:
        return None

    def _plain(value):
        # datetime values are emitted as ISO strings; everything else is passed through.
        return value.isoformat() if isinstance(value, datetime) else value

    return {column.name: _plain(getattr(obj, column.name)) for column in obj.__table__.columns}
def ask_yesno(question: str, default: bool | None = None) -> bool:
    """Print ``question`` and read answers from stdin until one is a yes or a no.

    :param question: text printed once before the first read.
    :param default: value returned when the user submits an empty line; with
        ``None`` an empty line is treated like any other unrecognised answer.
    :returns: ``True`` for "yes"/"y", ``False`` for "no"/"n" (case-insensitive).
    """
    answers = {"yes": True, "y": True, "no": False, "n": False}

    print(question)
    while True:
        reply = input().lower()
        if reply == "" and default is not None:
            return default
        if reply in answers:
            return answers[reply]
        print("Please respond with y/yes or n/no.")
def prompt_with_timeout(question: str, timeout: int, default: bool | None = None) -> bool:
    """Ask the user a yes/no question and give up after ``timeout`` seconds.

    The timeout is implemented with ``signal.alarm`` and a ``SIGALRM`` handler.
    The handler that was installed before the call is put back afterwards, so
    prompting does not permanently replace the process-wide ``SIGALRM`` handler.

    :param question: text passed to ``ask_yesno``.
    :param timeout: number of seconds handed to ``signal.alarm``.
    :param default: value passed to ``ask_yesno`` for an empty answer.
    :returns: the answer returned by ``ask_yesno``.
    :raises AirflowException: when the alarm fires before an answer is given.
    """

    def handler(signum, frame):
        raise AirflowException(f"Timeout {timeout}s reached")

    previous_handler = signal.signal(signal.SIGALRM, handler)
    try:
        signal.alarm(timeout)
        return ask_yesno(question, default)
    finally:
        # Cancel any pending alarm first so it cannot fire into the restored handler.
        signal.alarm(0)
        # signal.signal() returns None when the earlier handler was not installed
        # from Python; None cannot be passed back to signal.signal().
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
def is_container(obj: Any) -> bool:
    """Tell whether ``obj`` exposes ``__iter__`` and is not a ``str``.

    :param obj: any object.
    :returns: ``True`` for non-string objects that have an ``__iter__`` attribute.
    """
    if not hasattr(obj, "__iter__"):
        return False
    return not isinstance(obj, str)
def as_tuple(obj: Any) -> tuple:
    """Coerce ``obj`` to a tuple.

    :param obj: any object.
    :returns: ``tuple(obj)`` when ``is_container(obj)`` is true, otherwise a
        one-element tuple holding ``obj`` itself.
    """
    return tuple(obj) if is_container(obj) else (obj,)
def chunks(items: list[T], chunk_size: int) -> Generator[list[T], None, None]:
    """Generate consecutive slices of ``items``, each at most ``chunk_size`` long.

    This is a generator, so the size check runs when iteration starts, not
    when the function is called.

    :param items: sequence to slice.
    :param chunk_size: length of every slice except possibly the last.
    :raises ValueError: if ``chunk_size`` is zero or negative.
    """
    if chunk_size <= 0:
        raise ValueError("Chunk size must be a positive integer")
    for offset in range(0, len(items), chunk_size):
        end = offset + chunk_size
        yield items[offset:end]
def reduce_in_chunks(fn: Callable[[S, list[T]], S], iterable: list[T], initializer: S, chunk_size: int = 0):
    """Fold ``iterable`` chunk by chunk through ``fn``.

    :param fn: reducer called as ``fn(accumulator, chunk)``.
    :param iterable: items to split with ``chunks``.
    :param initializer: starting accumulator; returned unchanged when
        ``iterable`` is empty.
    :param chunk_size: size of each chunk; ``0`` means one chunk holding
        every item.
    :returns: the final accumulator.
    """
    total = len(iterable)
    if total == 0:
        return initializer
    size = total if chunk_size == 0 else chunk_size
    return reduce(fn, chunks(iterable, size), initializer)
def as_flattened_list(iterable: Iterable[Iterable[T]]) -> list[T]:
    """
    Return a list with one level of nesting removed.

    >>> as_flattened_list((('blue', 'red'), ('green', 'yellow', 'pink')))
    ['blue', 'red', 'green', 'yellow', 'pink']
    """
    flattened = []
    for group in iterable:
        flattened.extend(group)
    return flattened
def parse_template_string(template_string: str) -> tuple[str | None, jinja2.Template | None]:
    """Classify ``template_string`` as a Jinja template or a plain string.

    :param template_string: text to inspect.
    :returns: ``(None, jinja2.Template(template_string))`` when the text
        contains ``{{``, otherwise ``(template_string, None)``.
    """
    import jinja2

    if "{{" not in template_string:
        # No Jinja expression marker: hand the string back untouched.
        return template_string, None
    return None, jinja2.Template(template_string)
def render_log_filename(ti: TaskInstance, try_number, filename_template) -> str:
    """
    Render the log filename for a task instance.

    :param ti: task instance
    :param try_number: try_number of the task
    :param filename_template: filename template, which can be jinja template or
        python string template
    :returns: the rendered filename
    """
    plain_template, jinja_template = parse_template_string(filename_template)

    if not jinja_template:
        # Plain ``str.format`` template: only these four fields are supplied.
        return plain_template.format(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            execution_date=ti.execution_date.isoformat(),
            try_number=try_number,
        )

    # Jinja template: render against the task's template context plus try_number.
    context = ti.get_template_context()
    context["try_number"] = try_number
    return render_template_to_string(jinja_template, context)
def convert_camel_to_snake(camel_str: str) -> str:
    """Convert a CamelCase string to snake_case.

    An underscore is inserted before every match of
    ``CAMELCASE_TO_SNAKE_CASE_REGEX`` and the result is lower-cased.

    :param camel_str: text to convert.
    :returns: the converted text.
    """
    underscored = CAMELCASE_TO_SNAKE_CASE_REGEX.sub(r"_\1", camel_str)
    return underscored.lower()
def merge_dicts(dict1: dict, dict2: dict) -> dict:
    """
    Merge two dicts recursively, returning new dict (input dict is not mutated).

    Lists are not concatenated. Items in dict2 overwrite those also found in dict1.
    Values are merged recursively only when both sides hold a dict for the same
    key; in every other case the value from ``dict2`` replaces the one from ``dict1``.

    :param dict1: base mapping.
    :param dict2: mapping whose items take precedence.
    :returns: a new dict holding the merged content.
    """
    merged = dict1.copy()
    for key, value in dict2.items():
        existing = merged.get(key)
        # Recurse only when BOTH values are dicts. Recursing into a non-dict
        # left-hand value would fail on ``.copy()`` instead of overwriting it.
        if isinstance(value, dict) and isinstance(existing, dict):
            merged[key] = merge_dicts(existing, value)
        else:
            merged[key] = value
    return merged
def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
    """Split ``iterable`` by ``pred`` into two lazy iterators.

    :param pred: predicate applied to each entry.
    :param iterable: entries to split; consumed through ``itertools.tee``.
    :returns: ``(rejected, accepted)`` - entries for which ``pred`` is false,
        then entries for which it is true.
    """
    for_rejected, for_accepted = tee(iterable)
    rejected = filterfalse(pred, for_rejected)
    accepted = filter(pred, for_accepted)
    return rejected, accepted
def chain(*args, **kwargs):
    """This function is deprecated. Please use `airflow.models.baseoperator.chain`.

    Emits a ``RemovedInAirflow3Warning`` and forwards all arguments to
    ``airflow.models.baseoperator.chain``, which is imported when called.
    """
    message = "This function is deprecated. Please use `airflow.models.baseoperator.chain`."
    warnings.warn(message, RemovedInAirflow3Warning, stacklevel=2)
    replacement = import_string("airflow.models.baseoperator.chain")
    return replacement(*args, **kwargs)
def cross_downstream(*args, **kwargs):
    """This function is deprecated. Please use `airflow.models.baseoperator.cross_downstream`.

    Emits a ``RemovedInAirflow3Warning`` and forwards all arguments to
    ``airflow.models.baseoperator.cross_downstream``, which is imported when called.
    """
    message = "This function is deprecated. Please use `airflow.models.baseoperator.cross_downstream`."
    warnings.warn(message, RemovedInAirflow3Warning, stacklevel=2)
    replacement = import_string("airflow.models.baseoperator.cross_downstream")
    return replacement(*args, **kwargs)
def build_airflow_url_with_query(query: dict[str, Any]) -> str:
    """
    Build the URL of the configured default DAG view with ``query`` as arguments.

    The view name is read from the ``[webserver] dag_default_view`` setting,
    lower-cased, and resolved through ``flask.url_for``.

    For example:
    http://0.0.0.0:8000/base/graph?dag_id=my-task&root=&execution_date=2020-10-27T10%3A59%3A25.615587

    :param query: keyword arguments forwarded to ``flask.url_for``.
    :returns: the URL produced by ``flask.url_for``.
    """
    import flask

    configured_view = conf.get_mandatory_value("webserver", "dag_default_view")
    view = configured_view.lower()
    endpoint = f"Airflow.{view}"
    return flask.url_for(endpoint, **query)
# ``template`` is annotated as Any: jinja2.Template is too dynamic for a type
# checker to say anything useful about it.
def render_template(template: Any, context: MutableMapping[str, Any], *, native: bool) -> Any:
    """Render a Jinja2 template against an Airflow context.

    ``jinja2.Template.render()`` eagerly converts its context into a dict
    several times, which triggers deprecation messages in Airflow's custom
    context class. This function performs the same steps by hand so the
    context mapping is passed along without being resolved.

    :param template: A Jinja2 template to render.
    :param context: The Airflow task context to render the template with. A
        shallow copy is used, so the caller's mapping is not modified.
    :param native: If set to *True*, render the template into a native type. A
        DAG can enable this with ``render_template_as_native_obj=True``.
    :returns: The render result.
    """
    context = copy.copy(context)
    env = template.environment

    if template.globals:
        # Template globals only fill in names the context does not already define.
        missing_globals = {name: value for name, value in template.globals.items() if name not in context}
        context.update(missing_globals)

    try:
        render_context = env.context_class(env, context, template.name, template.blocks)
        nodes = template.root_render_func(render_context)
    except Exception:
        # Rewrite traceback to point to the template. This call is expected to
        # re-raise; if it returned, ``nodes`` would be unbound below.
        env.handle_exception()

    if not native:
        return "".join(nodes)

    import jinja2.nativetypes

    return jinja2.nativetypes.native_concat(nodes)
def render_template_to_string(template: jinja2.Template, context: Context) -> str:
    """Call ``render_template`` with ``native=False``; exists for its narrower type hints.

    :param template: template to render.
    :param context: Airflow context, passed on as a ``MutableMapping``.
    :returns: the value returned by ``render_template``.
    """
    mapping = cast(MutableMapping[str, Any], context)
    return render_template(template, mapping, native=False)
def render_template_as_native(template: jinja2.Template, context: Context) -> Any:
    """Call ``render_template`` with ``native=True``; exists for its narrower type hints.

    :param template: template to render.
    :param context: Airflow context, passed on as a ``MutableMapping``.
    :returns: the value returned by ``render_template``.
    """
    mapping = cast(MutableMapping[str, Any], context)
    return render_template(template, mapping, native=True)
def exactly_one(*args) -> bool:
    """
    Returns True if exactly one of *args is "truthy", and False otherwise.

    Called with no arguments at all, the answer is False (nothing is truthy).

    If user supplies an iterable as the first argument, we raise ValueError and
    force them to unpack.

    :raises ValueError: if the first argument is a container as defined by
        ``is_container``.
    """
    # Guard on ``args`` first: indexing an empty tuple would raise IndexError.
    if args and is_container(args[0]):
        raise ValueError(
            "Not supported for iterable args. Use `*` to unpack your iterable in the function call."
        )
    return sum(map(bool, args)) == 1
def at_most_one(*args) -> bool:
    """
    Returns True if at most one of *args is "truthy", and False otherwise.

    NOTSET is treated the same as None, i.e. it never counts as truthy.

    Arguments are judged by ``bool()`` only; an iterable passed as an argument
    is not rejected and simply counts as one value.
    """
    truthy_count = sum(1 for value in args if value is not NOTSET and bool(value))
    return truthy_count in (0, 1)
def prune_dict(val: Any, mode="strict"):
    """
    Return a copy of dict or list ``val`` with empty elements removed, recursively.

    What constitutes "empty" is controlled by the ``mode`` parameter. If mode is 'strict'
    then only ``None`` elements will be removed. If mode is ``truthy``, then element ``x``
    will be removed if ``bool(x) is False``.

    Nested dicts and lists are pruned the same way, and a nested container that
    ends up empty after pruning is dropped as well. Any ``val`` that is neither
    a dict nor a list is returned unchanged.

    :raises ValueError: if ``mode`` is not 'strict' or 'truthy' (checked when the
        first element is examined).
    """

    def is_empty(x):
        if mode == "strict":
            return x is None
        if mode == "truthy":
            return bool(x) is False
        raise ValueError("allowable values for `mode` include 'truthy' and 'strict'")

    def surviving(pairs):
        # Yields (key, pruned value) for every entry that is kept.
        for key, item in pairs:
            if is_empty(item):
                continue
            if isinstance(item, (list, dict)):
                item = prune_dict(item, mode=mode)
                if not item:
                    continue
            yield key, item

    if isinstance(val, dict):
        return dict(surviving(val.items()))
    if isinstance(val, list):
        return [item for _, item in surviving(enumerate(val))]
    return val
def prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None:
    """Verify that *kwargs1* and *kwargs2* have no key in common.

    :param kwargs1: first set of keyword arguments.
    :param kwargs2: second set of keyword arguments.
    :param fail_reason: text placed at the start of the error message.
    :raises TypeError: If common keys are found; the message lists them,
        sorted when there is more than one.
    """
    shared_keys = sorted(set(kwargs1) & set(kwargs2))
    if not shared_keys:
        return
    if len(shared_keys) == 1:
        only_key = shared_keys[0]
        raise TypeError(f"{fail_reason} argument: {only_key}")
    shared_keys_display = ", ".join(shared_keys)
    raise TypeError(f"{fail_reason} arguments: {shared_keys_display}")