Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/utils/helpers.py: 31%
184 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import copy
21import re
22import signal
23import warnings
24from datetime import datetime
25from functools import reduce
26from itertools import filterfalse, tee
27from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Mapping, MutableMapping, TypeVar, cast
29from airflow.configuration import conf
30from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
31from airflow.utils.context import Context
32from airflow.utils.module_loading import import_string
33from airflow.utils.types import NOTSET
35if TYPE_CHECKING:
36 import jinja2
38 from airflow.models.taskinstance import TaskInstance
# Valid keys consist solely of (Unicode) word characters, dots and dashes.
# NOTE: ``$`` also matches just before a trailing newline, so ``.match`` with this
# pattern accepts e.g. "abc\n"; use ``.fullmatch`` for strict whole-string validation.
KEY_REGEX = re.compile(r"^[\w.-]+$")
# Same character set as KEY_REGEX but without dots.
GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
# Matches each run of uppercase ASCII letters that does not start at position 0;
# convert_camel_to_snake prefixes every such run with an underscore.
CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")

# Generic type variables used in the helper signatures of this module.
T = TypeVar("T")
S = TypeVar("S")
def validate_key(k: str, max_length: int = 250):
    """Validate a value used as a key.

    :param k: The candidate key.
    :param max_length: Maximum accepted length; a key of exactly this length is accepted.
    :raises TypeError: If ``k`` is not a string.
    :raises AirflowException: If ``k`` is longer than ``max_length`` or contains anything
        other than word characters, dots and dashes (a trailing newline included).
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
    if len(k) > max_length:
        raise AirflowException(f"The key has to be less than {max_length} characters")
    # fullmatch (not match): with ``match`` the ``$`` anchor also succeeds right before a
    # trailing "\n", which would let keys such as "abc\n" through.
    if not KEY_REGEX.fullmatch(k):
        raise AirflowException(
            f"The key {k!r} has to be made of alphanumeric characters, dashes, "
            f"dots and underscores exclusively"
        )
def validate_group_key(k: str, max_length: int = 200):
    """Validate a value used as a group key.

    Group keys are stricter than plain keys: dots are not allowed.

    :param k: The candidate group key.
    :param max_length: Maximum accepted length; a key of exactly this length is accepted.
    :raises TypeError: If ``k`` is not a string.
    :raises AirflowException: If ``k`` is longer than ``max_length`` or contains anything
        other than word characters and dashes (a trailing newline included).
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
    if len(k) > max_length:
        raise AirflowException(f"The key has to be less than {max_length} characters")
    # fullmatch (not match): ``$`` would otherwise also match before a trailing "\n".
    if not GROUP_KEY_REGEX.fullmatch(k):
        raise AirflowException(
            f"The key {k!r} has to be made of alphanumeric characters, dashes and underscores exclusively"
        )
def alchemy_to_dict(obj: Any) -> dict | None:
    """Convert a SQLAlchemy model instance into a plain ``{column name: value}`` dict.

    Columns are read from ``obj.__table__.columns``; ``datetime`` values are
    serialized with ``isoformat()``, all other values are copied as-is.

    :param obj: The model instance. A falsy value yields ``None``.
    :return: The column dictionary, or ``None`` when ``obj`` is falsy.
    """
    if not obj:
        return None

    def _serialize(value: Any) -> Any:
        return value.isoformat() if isinstance(value, datetime) else value

    return {column.name: _serialize(getattr(obj, column.name)) for column in obj.__table__.columns}
def ask_yesno(question: str, default: bool | None = None) -> bool:
    """Interactively ask the user a yes/no question on stdin.

    Accepted answers (case-insensitive) are ``y``/``yes`` and ``n``/``no``. An empty
    answer returns ``default`` when one is given; any other input re-prompts.

    :param question: Text printed once before reading answers.
    :param default: Value returned for an empty answer, or ``None`` to require an explicit one.
    :return: ``True`` for yes, ``False`` for no.
    """
    answers = {"yes": True, "y": True, "no": False, "n": False}

    print(question)
    while True:
        reply = input().lower()
        if reply == "" and default is not None:
            return default
        if reply in answers:
            return answers[reply]
        print("Please respond with y/yes or n/no.")
def prompt_with_timeout(question: str, timeout: int, default: bool | None = None) -> bool:
    """Ask the user a yes/no question and fail if they don't respond in time.

    A temporary ``SIGALRM`` handler is installed and an alarm scheduled in ``timeout``
    seconds. On exit (answer, timeout or any other error) the alarm is cancelled and
    the previously installed ``SIGALRM`` handler is restored. Because it relies on
    ``signal.signal``/``signal.alarm``, this only works on platforms providing
    ``SIGALRM`` and from the main thread.

    :param question: The question shown to the user.
    :param timeout: Seconds to wait for an answer.
    :param default: Answer used when the user just presses enter (see ``ask_yesno``).
    :raises AirflowException: If ``timeout`` seconds elapse without an answer.
    """

    def handler(signum, frame):
        raise AirflowException(f"Timeout {timeout}s reached")

    previous_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout)
    try:
        return ask_yesno(question, default)
    finally:
        # Cancel the alarm first so it cannot fire between the two restore steps.
        signal.alarm(0)
        # signal.signal() returns None when the previous handler was not installed from
        # Python; it cannot be reinstated in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
def is_container(obj: Any) -> bool:
    """Return True if ``obj`` is iterable (has ``__iter__``) and is not a ``str``."""
    return not isinstance(obj, str) and hasattr(obj, "__iter__")
def as_tuple(obj: Any) -> tuple:
    """
    Return ``obj`` as a tuple.

    Containers (see ``is_container``) are converted element-wise; any other value,
    including a string, is wrapped in a one-element tuple.
    """
    return tuple(obj) if is_container(obj) else (obj,)
def chunks(items: list[T], chunk_size: int) -> Generator[list[T], None, None]:
    """Lazily yield consecutive slices of ``items`` of length ``chunk_size``.

    The final slice may be shorter. Being a generator, the ``chunk_size`` check
    only runs when iteration starts.

    :raises ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("Chunk size must be a positive integer")
    total = len(items)
    offset = 0
    while offset < total:
        yield items[offset : offset + chunk_size]
        offset += chunk_size
def reduce_in_chunks(fn: Callable[[S, list[T]], S], iterable: list[T], initializer: S, chunk_size: int = 0):
    """
    Fold ``iterable`` chunk by chunk with ``fn``, starting from ``initializer``.

    ``fn(accumulator, chunk)`` is called once per chunk produced by ``chunks``.
    A ``chunk_size`` of 0 means "one chunk containing everything". An empty
    ``iterable`` returns ``initializer`` without calling ``fn``.
    """
    if len(iterable) == 0:
        return initializer
    effective_size = len(iterable) if chunk_size == 0 else chunk_size
    return reduce(fn, chunks(iterable, effective_size), initializer)
def as_flattened_list(iterable: Iterable[Iterable[T]]) -> list[T]:
    """
    Flatten one level of nesting and return the elements as a list.

    >>> as_flattened_list((('blue', 'red'), ('green', 'yellow', 'pink')))
    ['blue', 'red', 'green', 'yellow', 'pink']
    """
    flattened: list[T] = []
    for inner in iterable:
        flattened.extend(inner)
    return flattened
def parse_template_string(template_string: str) -> tuple[str | None, jinja2.Template | None]:
    """Split a template string into its plain or Jinja form.

    :return: ``(None, jinja2.Template(template_string))`` when the string contains
        ``{{`` (Jinja mode), otherwise ``(template_string, None)``.
    """
    import jinja2

    if "{{" not in template_string:
        return template_string, None
    return None, jinja2.Template(template_string)
def render_log_filename(ti: TaskInstance, try_number, filename_template) -> str:
    """
    Render the log filename for one attempt of a task instance.

    The template is classified by ``parse_template_string``. A Jinja template is
    rendered with ``ti.get_template_context()`` plus ``try_number``. A plain string is
    treated as a ``str.format`` template with the fields ``dag_id``, ``task_id``,
    ``execution_date`` (``isoformat()`` string) and ``try_number``.

    :param ti: task instance
    :param try_number: try_number of the task
    :param filename_template: filename template, which can be jinja template or
        python string template
    """
    plain_template, jinja_template = parse_template_string(filename_template)
    if not jinja_template:
        return plain_template.format(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            execution_date=ti.execution_date.isoformat(),
            try_number=try_number,
        )
    template_context = ti.get_template_context()
    template_context["try_number"] = try_number
    return render_template_to_string(jinja_template, template_context)
def convert_camel_to_snake(camel_str: str) -> str:
    """Convert CamelCase to snake_case.

    Every run of uppercase letters, except one at the very start, gets an underscore
    prepended, and the result is lowercased: ``"CamelCase"`` -> ``"camel_case"``.
    Consecutive capitals stay together, e.g. ``"HTTPServer"`` -> ``"h_ttpserver"``.
    """
    underscored = CAMELCASE_TO_SNAKE_CASE_REGEX.sub(lambda match: f"_{match.group(1)}", camel_str)
    return underscored.lower()
def merge_dicts(dict1: dict, dict2: dict) -> dict:
    """
    Merge two dicts recursively, returning a new dict (the input dicts are not mutated).

    When both dicts map a key to a dict, those dicts are merged recursively. In every
    other case the value from ``dict2`` overwrites the one from ``dict1``. This includes
    a dict in ``dict2`` replacing a non-dict in ``dict1``. Lists are not concatenated.

    Note that the copy is shallow: nested values not touched by the merge are shared
    with the inputs.
    """
    merged = dict1.copy()
    for key, value in dict2.items():
        existing = merged.get(key)
        # Only recurse when *both* sides are dicts; recursing into a non-dict value
        # (e.g. an int or None from dict1) would fail on ``.copy()``.
        if isinstance(existing, dict) and isinstance(value, dict):
            merged[key] = merge_dicts(existing, value)
        else:
            merged[key] = value
    return merged
def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
    """Split ``iterable`` lazily into ``(entries where pred is false, entries where pred is true)``."""
    falsy_source, truthy_source = tee(iterable, 2)
    rejected = filterfalse(pred, falsy_source)
    accepted = filter(pred, truthy_source)
    return rejected, accepted
def chain(*args, **kwargs):
    """This function is deprecated. Please use `airflow.models.baseoperator.chain`.

    Emits a ``RemovedInAirflow3Warning`` and forwards all arguments to the replacement.
    """
    message = "This function is deprecated. Please use `airflow.models.baseoperator.chain`."
    warnings.warn(message, RemovedInAirflow3Warning, stacklevel=2)
    replacement = import_string("airflow.models.baseoperator.chain")
    return replacement(*args, **kwargs)
def cross_downstream(*args, **kwargs):
    """This function is deprecated. Please use `airflow.models.baseoperator.cross_downstream`.

    Emits a ``RemovedInAirflow3Warning`` and forwards all arguments to the replacement.
    """
    message = "This function is deprecated. Please use `airflow.models.baseoperator.cross_downstream`."
    warnings.warn(message, RemovedInAirflow3Warning, stacklevel=2)
    replacement = import_string("airflow.models.baseoperator.cross_downstream")
    return replacement(*args, **kwargs)
def build_airflow_url_with_query(query: dict[str, Any]) -> str:
    """
    Build an Airflow URL for the configured default DAG view with the given query.

    The endpoint is ``Airflow.<view>``, where ``<view>`` is the lowercased
    ``[webserver] dag_default_view`` setting. ``query`` is passed to ``flask.url_for``
    as keyword arguments. For example:
    'http://0.0.0.0:8000/base/graph?dag_id=my-task&root=&execution_date=2020-10-27T10%3A59%3A25.615587'
    """
    import flask

    default_view = conf.get_mandatory_value("webserver", "dag_default_view")
    endpoint = f"Airflow.{default_view.lower()}"
    return flask.url_for(endpoint, **query)
# The 'template' argument is typed as Any because the jinja2.Template is too
# dynamic to be effectively type-checked.
def render_template(template: Any, context: MutableMapping[str, Any], *, native: bool) -> Any:
    """Render a Jinja2 template with given Airflow context.

    The default implementation of ``jinja2.Template.render()`` converts the
    input context into dict eagerly many times, which triggers deprecation
    messages in our custom context class. This takes the implementation apart
    and retains the context mapping without resolving it instead.

    :param template: A Jinja2 template to render.
    :param context: The Airflow task context to render the template with. A
        ``copy.copy`` of it is used, so template globals are merged into the copy
        rather than into the caller's object.
    :param native: If set to *True*, render the template into a native type. A
        DAG can enable this with ``render_template_as_native_obj=True``.
    :returns: The render result: a string when ``native`` is False, otherwise
        whatever ``jinja2.nativetypes.native_concat`` produces.
    """
    # Shallow copy so the globals merged in below do not leak into the caller's mapping.
    context = copy.copy(context)
    env = template.environment
    if template.globals:
        # Template globals only fill in names the context does not define;
        # existing context entries take precedence.
        context.update((k, v) for k, v in template.globals.items() if k not in context)
    try:
        nodes = template.root_render_func(env.context_class(env, context, template.name, template.blocks))
    except Exception:
        # NOTE(review): relies on env.handle_exception() always re-raising; if it ever
        # returned normally, ``nodes`` would be unbound below. Confirm against Jinja2.
        env.handle_exception()  # Rewrite traceback to point to the template.
    if native:
        import jinja2.nativetypes

        return jinja2.nativetypes.native_concat(nodes)
    return "".join(nodes)
def render_template_to_string(template: jinja2.Template, context: Context) -> str:
    """Render ``template`` to a ``str``; typed shorthand for ``render_template(native=False)``."""
    mutable_context = cast(MutableMapping[str, Any], context)
    return render_template(template, mutable_context, native=False)
def render_template_as_native(template: jinja2.Template, context: Context) -> Any:
    """Render ``template`` to a native object; typed shorthand for ``render_template(native=True)``."""
    mutable_context = cast(MutableMapping[str, Any], context)
    return render_template(template, mutable_context, native=True)
def exactly_one(*args) -> bool:
    """
    Returns True if exactly one of *args is "truthy", and False otherwise.

    Calling it with no arguments returns False.

    If the first argument is an iterable (a non-string container), we raise
    ValueError and force the caller to unpack it.

    :raises ValueError: If ``args[0]`` is a container such as a list or tuple.
    """
    # Guard on ``args`` so a bare ``exactly_one()`` returns False instead of
    # raising IndexError on ``args[0]``.
    if args and is_container(args[0]):
        raise ValueError(
            "Not supported for iterable args. Use `*` to unpack your iterable in the function call."
        )
    return sum(map(bool, args)) == 1
def at_most_one(*args) -> bool:
    """
    Returns True if at most one of *args is "truthy", and False otherwise.

    NOTSET is treated the same as None. Calling it with no arguments returns True.

    Iterables are not unpacked or rejected: a container argument counts as a single
    value, truthy when non-empty.
    """

    def is_set(val) -> bool:
        # NOTSET is a sentinel object; treat it as "not provided" like None.
        return val is not NOTSET and bool(val)

    return sum(map(is_set, args)) <= 1
def prune_dict(val: Any, mode="strict"):
    """
    Return a pruned copy of ``val`` with "empty" elements removed, recursively.

    Dicts and lists are walked recursively; any other ``val`` is returned unchanged.

    What counts as "empty" is controlled by ``mode``:

    * ``'strict'``: only ``None`` elements are removed.
    * ``'truthy'``: element ``x`` is removed if ``bool(x) is False``.

    In both modes, a nested dict/list that ends up empty after pruning is dropped
    as well. An unknown ``mode`` raises ``ValueError``, but only once an element is
    actually inspected.
    """

    def _is_empty(item) -> bool:
        if mode == "strict":
            return item is None
        if mode == "truthy":
            return not item
        raise ValueError("allowable values for `mode` include 'truthy' and 'strict'")

    def _prune_element(item):
        # Returns (keep, pruned_item) for a single dict value or list element.
        if _is_empty(item):
            return False, None
        if isinstance(item, (list, dict)):
            nested = prune_dict(item, mode=mode)
            return bool(nested), nested
        return True, item

    if isinstance(val, dict):
        pruned_dict = {}
        for key, item in val.items():
            keep, pruned = _prune_element(item)
            if keep:
                pruned_dict[key] = pruned
        return pruned_dict
    if isinstance(val, list):
        pruned_list = []
        for item in val:
            keep, pruned = _prune_element(item)
            if keep:
                pruned_list.append(pruned)
        return pruned_list
    return val
def prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None:
    """Ensure *kwargs1* and *kwargs2* share no keys.

    :param fail_reason: Prefix of the error message, e.g. ``"Duplicate"``.
    :raises TypeError: If common keys are found; several keys are listed sorted
        and comma-separated.
    """
    common = set(kwargs1) & set(kwargs2)
    if len(common) == 1:
        (only_key,) = common
        raise TypeError(f"{fail_reason} argument: {only_key}")
    if common:
        listed = ", ".join(sorted(common))
        raise TypeError(f"{fail_reason} arguments: {listed}")