Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/utils/helpers.py: 30%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

155 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import copy 

21import itertools 

22import re 

23import signal 

24from collections.abc import Callable, Generator, Iterable, MutableMapping 

25from functools import cache 

26from typing import TYPE_CHECKING, Any, TypeVar, overload 

27from urllib.parse import urljoin 

28 

29from lazy_object_proxy import Proxy 

30 

31from airflow.configuration import conf 

32from airflow.exceptions import AirflowException 

33from airflow.serialization.definitions.notset import is_arg_set 

34 

35if TYPE_CHECKING: 

36 from datetime import datetime 

37 from typing import TypeGuard 

38 

39 import jinja2 

40 

41 from airflow.models.taskinstance import TaskInstance 

42 

43 CT = TypeVar("CT", str, datetime) 

44 

# Keys accepted by validate_key: only word characters (letters, digits, "_"),
# dots and dashes, and at least one character.
KEY_REGEX = re.compile(r"^[\w.-]+$")
# Like KEY_REGEX, but dots are not allowed.
GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
# Each run of upper-case letters that is not at the very start of the string;
# convert_camel_to_snake prefixes every such run with "_".
CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")

# Generic type variables for annotations. T is used by chunks, as_flattened_list
# and partition; S is not referenced in this module's visible code (presumably
# kept for importers -- confirm before removing).
T = TypeVar("T")
S = TypeVar("S")

51 

52 

def validate_key(k: str, max_length: int = 250):
    """
    Validate a value that is going to be used as a key.

    :param k: the candidate key.
    :param max_length: maximum number of characters allowed (a key of exactly
        ``max_length`` characters is accepted).
    :raises TypeError: if ``k`` is not a ``str``.
    :raises AirflowException: if ``k`` is longer than ``max_length`` or contains
        anything other than word characters, dashes and dots.
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")

    too_long = len(k) > max_length
    if too_long:
        raise AirflowException(f"The key: {k} has to be less than {max_length} characters")

    # A regex Match object is always truthy, so "is None" is the same test as "not match".
    if KEY_REGEX.match(k) is None:
        raise AirflowException(
            f"The key {k!r} has to be made of alphanumeric characters, dashes, "
            f"dots and underscores exclusively"
        )

64 

65 

66def ask_yesno(question: str, default: bool | None = None) -> bool: 

67 """Get a yes or no answer from the user.""" 

68 yes = {"yes", "y"} 

69 no = {"no", "n"} 

70 

71 print(question) 

72 while True: 

73 choice = input().lower() 

74 if choice == "" and default is not None: 

75 return default 

76 if choice in yes: 

77 return True 

78 if choice in no: 

79 return False 

80 print("Please respond with y/yes or n/no.") 

81 

82 

83def prompt_with_timeout(question: str, timeout: int, default: bool | None = None) -> bool: 

84 """Ask the user a question and timeout if they don't respond.""" 

85 

86 def handler(signum, frame): 

87 raise AirflowException(f"Timeout {timeout}s reached") 

88 

89 signal.signal(signal.SIGALRM, handler) 

90 signal.alarm(timeout) 

91 try: 

92 return ask_yesno(question, default) 

93 finally: 

94 signal.alarm(0) 

95 

96 

@overload
def is_container(obj: None | int | Iterable[int] | range) -> TypeGuard[Iterable[int]]: ...


@overload
def is_container(obj: None | CT | Iterable[CT]) -> TypeGuard[Iterable[CT]]: ...


def is_container(obj) -> bool:
    """
    Tell whether ``obj`` is an iterable other than a string.

    Anything that has an ``__iter__`` attribute counts as a container, except
    ``str``. A lazy ``Proxy`` is unwrapped first. The proxy type itself implements
    ``__iter__`` to forward to the lazily created object, so checking the proxy
    directly would always report a container. Unwrapping evaluates the proxied
    object.
    """
    candidate = obj.__wrapped__ if isinstance(obj, Proxy) else obj
    if isinstance(candidate, str):
        return False
    return hasattr(candidate, "__iter__")

113 

114 

def chunks(items: list[T], chunk_size: int) -> Generator[list[T], None, None]:
    """
    Lazily split ``items`` into consecutive slices of ``chunk_size`` elements.

    The last slice is shorter when ``len(items)`` is not a multiple of
    ``chunk_size``. Because this is a generator, the size check runs when
    iteration starts, not when ``chunks`` is called.

    :raises ValueError: if ``chunk_size`` is zero or negative.
    """
    if chunk_size <= 0:
        raise ValueError("Chunk size must be a positive integer")
    yield from (items[start : start + chunk_size] for start in range(0, len(items), chunk_size))

121 

122 

def as_flattened_list(iterable: Iterable[Iterable[T]]) -> list[T]:
    """
    Flatten exactly one level of nesting and return the items as a list.

    Inner iterables are consumed in order; anything nested deeper is left as is.

    >>> as_flattened_list((("blue", "red"), ("green", "yellow", "pink")))
    ['blue', 'red', 'green', 'yellow', 'pink']
    """
    return list(itertools.chain.from_iterable(iterable))

131 

132 

def parse_template_string(template_string: str) -> tuple[str, None] | tuple[None, jinja2.Template]:
    """
    Parse Jinja template string.

    :return: ``(template_string, None)`` when the string contains no ``{{`` and is
        therefore used verbatim, or ``(None, jinja2.Template(template_string))``
        when it does.
    """
    import jinja2

    if "{{" not in template_string:
        # No expression delimiters: the string is used as-is.
        return template_string, None
    return None, jinja2.Template(template_string)

140 

141 

@cache
def log_filename_template_renderer() -> Callable[..., str]:
    """
    Return a renderer for the ``[logging] log_filename_template`` option.

    The result is cached, so the option is read only on the first call. If the
    template contains ``{{`` it is compiled as a Jinja template and its bound
    ``render`` method (which takes keyword arguments) is returned. Otherwise the
    template is treated as a ``str.format`` pattern, and the returned callable
    takes a task instance and an optional try number.
    """
    template = conf.get("logging", "log_filename_template")

    if "{{" not in template:

        def f_str_format(ti: TaskInstance, try_number: int | None = None):
            # NOTE(review): because of ``or``, an explicit try_number of 0 also
            # falls back to ti.try_number -- confirm this is intended.
            return template.format(
                dag_id=ti.dag_id,
                task_id=ti.task_id,
                logical_date=ti.logical_date.isoformat(),
                try_number=try_number or ti.try_number,
            )

        return f_str_format

    import jinja2

    return jinja2.Template(template).render

160 

161 

def convert_camel_to_snake(camel_str: str) -> str:
    """
    Convert CamelCase to snake_case.

    An underscore is inserted before every run of upper-case letters that does
    not start the string, and the result is lower-cased; for example
    ``"CamelCase"`` becomes ``"camel_case"``.
    """
    with_separators = CAMELCASE_TO_SNAKE_CASE_REGEX.sub(r"_\1", camel_str)
    return with_separators.lower()

165 

166 

def merge_dicts(dict1: dict, dict2: dict) -> dict:
    """
    Merge two dicts recursively, returning a new dict (neither input is mutated).

    Values from ``dict2`` win. When both sides hold a dict under the same key,
    the two are merged recursively. In every other case, including when only one
    side holds a dict, the value from ``dict2`` replaces the one from ``dict1``.
    Lists are not concatenated. Values that are not merged are not copied, so
    they are shared between the result and the inputs.
    """
    merged = dict1.copy()
    for key, value in dict2.items():
        existing = merged.get(key)
        # Recurse only when both sides are dicts. Previously a non-dict value in
        # dict1 (None, str, list, ...) was passed to the recursive call and crashed.
        if isinstance(value, dict) and isinstance(existing, dict):
            merged[key] = merge_dicts(existing, value)
        else:
            merged[key] = value
    return merged

180 

181 

def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
    """
    Split ``iterable`` into two lazy iterators according to ``pred``.

    :return: a pair ``(false_entries, true_entries)``. The first yields the items
        for which ``pred`` is falsy, the second those for which it is truthy. Both
        are lazy views over one shared ``itertools.tee`` of the input.
    """
    source_for_false, source_for_true = itertools.tee(iterable)
    false_entries = itertools.filterfalse(pred, source_for_false)
    true_entries = filter(pred, source_for_true)
    return false_entries, true_entries

186 

187 

def build_airflow_dagrun_url(dag_id: str, run_id: str) -> str:
    """
    Build the URL of a DAG run from ``[api] base_url``, ``dag_id`` and ``run_id``.

    If ``base_url`` is not configured, ``"/"`` is used, which yields a relative
    path. For example:
    http://localhost:8080/dags/hi/runs/manual__2025-02-23T18:27:39.051358+00:00_RZa1at4Q
    """
    base_url = conf.get("api", "base_url", fallback="/")
    # Force exactly one trailing slash. Without it, urljoin would replace the last
    # path segment of base_url (e.g. a sub-path prefix) instead of appending to it.
    normalized_base = base_url.rstrip("/") + "/"
    relative_path = f"dags/{dag_id}/runs/{run_id}"
    return urljoin(normalized_base, relative_path)

197 

198 

# The 'template' argument is typed as Any because the jinja2.Template is too
# dynamic to be effectively type-checked.
def render_template(template: Any, context: MutableMapping[str, Any], *, native: bool) -> Any:
    """
    Render a Jinja2 template with given Airflow context.

    The default implementation of ``jinja2.Template.render()`` converts the
    input context into a dict eagerly many times, which triggers deprecation
    messages in our custom context class. This function takes that
    implementation apart and keeps the context mapping unresolved instead.

    :param template: A Jinja2 template to render.
    :param context: The Airflow task context to render the template with. It is
        shallow-copied, so the caller's mapping is not modified. Template globals
        are added to the copy only for names the context does not already define.
    :param native: If set to *True*, render the template into a native type. A
        DAG can enable this with ``render_template_as_native_obj=True``.
    :returns: The render result.
    """
    context = copy.copy(context)
    env = template.environment
    if template.globals:
        context.update((k, v) for k, v in template.globals.items() if k not in context)
    if native:
        import jinja2.nativetypes

        concat = jinja2.nativetypes.native_concat
    else:
        concat = "".join
    try:
        # Jinja compiles root_render_func into a generator function, so template
        # expressions only run while its output is consumed. The concatenation must
        # therefore sit inside this try block (as in jinja2.Template.render);
        # otherwise rendering errors bypass handle_exception().
        return concat(
            template.root_render_func(env.context_class(env, context, template.name, template.blocks))
        )
    except Exception:
        env.handle_exception()  # Rewrite traceback to point to the template.

229 

230 

def exactly_one(*args) -> bool:
    """
    Return True if exactly one of args is "truthy", and False otherwise.

    Calling with no arguments returns False, since zero arguments are truthy.

    If the first argument is a container (any non-string iterable), a ValueError
    is raised to force the caller to unpack it with ``*``. This catches the
    mistake ``exactly_one([a, b])``. Only the first argument is checked.
    """
    # Guard with ``args`` so that a call with no arguments returns False instead
    # of raising IndexError.
    if args and is_container(args[0]):
        raise ValueError(
            "Not supported for iterable args. Use `*` to unpack your iterable in the function call."
        )
    return sum(map(bool, args)) == 1

242 

243 

def at_most_one(*args) -> bool:
    """
    Return True if at most one of args is "truthy", and False otherwise.

    NOTSET is treated the same as None: an argument only counts when
    ``is_arg_set`` reports it as set *and* it is truthy. Calling with no
    arguments returns True.

    Unlike ``exactly_one``, iterables are not rejected. Every argument, including
    a list or tuple, is simply tested for truthiness, so ``at_most_one([a, b])``
    treats the list as a single argument.
    """
    truthy_count = sum(is_arg_set(a) and bool(a) for a in args)
    return truthy_count <= 1

253 

254 

def prune_dict(val: Any, mode="strict"):
    """
    Return a copy of ``val`` with all "empty" members removed, recursively.

    ``val`` may be a dict or a list. Nested dicts and lists are pruned the same
    way, and a nested container is dropped when its pruned form is itself empty
    under the same mode. In ``"strict"`` mode an emptied container is therefore
    kept, because it is not ``None``. Any other value is returned unchanged.

    What constitutes "empty" is controlled by ``mode``. With ``"strict"``, only
    ``None`` is removed. With ``"truthy"``, an element ``x`` is removed if
    ``bool(x) is False``.

    :raises ValueError: if ``mode`` is neither ``"strict"`` nor ``"truthy"``.
        The check happens up front, so it applies even to empty or scalar input.
    """
    if mode not in ("strict", "truthy"):
        raise ValueError("allowable values for `mode` include 'truthy' and 'strict'")

    def is_empty(x) -> bool:
        if mode == "strict":
            return x is None
        return bool(x) is False

    dropped = object()  # marker returned by _prune_member for removed members

    def _prune_member(v):
        """Return the pruned form of ``v``, or ``dropped`` if it must be removed."""
        if is_empty(v):
            return dropped
        if isinstance(v, (list, dict)):
            v = prune_dict(v, mode=mode)
            if is_empty(v):
                return dropped
        return v

    if isinstance(val, dict):
        pruned_items = ((k, _prune_member(v)) for k, v in val.items())
        return {k: v for k, v in pruned_items if v is not dropped}
    if isinstance(val, list):
        pruned_values = (_prune_member(v) for v in val)
        return [v for v in pruned_values if v is not dropped]
    return val

296 

297 

# Names that used to live in this module, mapped to the dotted path of the
# module that now provides them. They are resolved lazily by __getattr__ below.
__deprecated_imports = {
    "render_template_as_native": "airflow.sdk.definitions.context",
    "render_template_to_string": "airflow.sdk.definitions.context",
    "prevent_duplicates": "airflow.sdk.definitions.mappedoperator",
}


def __getattr__(name: str):
    """
    Provide backward compatibility for moved functions in this module.

    Accessing a name listed in ``__deprecated_imports`` emits a
    DeprecationWarning and returns the attribute from its new module. Any other
    missing name raises AttributeError, as for a normal module.
    """
    try:
        modpath = __deprecated_imports[name]
    except KeyError:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'") from None

    import importlib
    import warnings

    warnings.warn(
        f"{__name__}.{name} is deprecated. Use {modpath}.{name} instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # importlib.import_module returns the leaf module of a dotted path. By contrast,
    # __import__("a.b.c") returns the top-level package "a", on which the moved
    # name does not exist.
    return getattr(importlib.import_module(modpath), name)