Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/utils/helpers.py: 27%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

173 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import copy 

21import itertools 

22import re 

23import signal 

24from collections.abc import Callable, Generator, Iterable, MutableMapping 

25from functools import cache 

26from typing import TYPE_CHECKING, Any, TypeVar, cast, overload 

27from urllib.parse import urljoin 

28 

29from lazy_object_proxy import Proxy 

30 

31from airflow.configuration import conf 

32from airflow.exceptions import AirflowException 

33from airflow.serialization.definitions.notset import is_arg_set 

34 

35if TYPE_CHECKING: 

36 from datetime import datetime 

37 from typing import TypeGuard 

38 

39 import jinja2 

40 

41 from airflow.models.taskinstance import TaskInstance 

42 from airflow.sdk.definitions.context import Context 

43 

44 CT = TypeVar("CT", str, datetime) 

45 

# Keys may contain word characters (letters, digits, underscore), dots and dashes.
KEY_REGEX = re.compile(r"^[\w.-]+$")
# Group keys follow the same rule as KEY_REGEX except that dots are not allowed.
GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
# A run of uppercase letters anywhere except at the very start of the string.
CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")

# Generic type variables shared by the helpers in this module.
T, S = TypeVar("T"), TypeVar("S")

52 

53 

def validate_key(k: str, max_length: int = 250):
    """
    Check that ``k`` is acceptable as a key.

    :param k: the candidate key; must be a ``str``.
    :param max_length: maximum allowed length of ``k`` (a key of exactly this
        length is accepted).
    :raises TypeError: if ``k`` is not a string.
    :raises AirflowException: if ``k`` is longer than ``max_length`` or contains
        characters other than word characters, dots and dashes.
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
    too_long = len(k) > max_length
    if too_long:
        raise AirflowException(f"The key: {k} has to be less than {max_length} characters")
    if KEY_REGEX.match(k) is None:
        raise AirflowException(
            f"The key {k!r} has to be made of alphanumeric characters, dashes, "
            f"dots and underscores exclusively"
        )

65 

66 

def ask_yesno(question: str, default: bool | None = None) -> bool:
    """
    Prompt on stdin until the user answers yes or no.

    Answers are matched case-insensitively after stripping surrounding
    whitespace, so ``" Y "`` counts as yes.

    :param question: text printed once before reading answers.
    :param default: value returned when the user submits an empty (or
        whitespace-only) answer; if ``None``, an empty answer is rejected and
        the user is asked again.
    :return: ``True`` for ``y``/``yes``, ``False`` for ``n``/``no``.
    """
    yes = {"yes", "y"}
    no = {"no", "n"}

    print(question)
    while True:
        # Strip so stray whitespace (e.g. "y ") is not treated as an invalid answer.
        choice = input().strip().lower()
        if choice == "" and default is not None:
            return default
        if choice in yes:
            return True
        if choice in no:
            return False
        print("Please respond with y/yes or n/no.")

82 

83 

def prompt_with_timeout(question: str, timeout: int, default: bool | None = None) -> bool:
    """
    Ask a yes/no question, giving up after ``timeout`` seconds.

    Relies on ``SIGALRM``, so it only works where that signal exists and only
    from the main thread (a ``signal.signal`` requirement). Whatever
    ``SIGALRM`` handler was installed from Python beforehand is restored
    afterwards, so the caller's own alarm handling is not clobbered.

    :param question: text shown to the user.
    :param timeout: seconds to wait for an answer.
    :param default: answer used when the user submits an empty response.
    :return: the user's answer.
    :raises AirflowException: if no answer arrives within ``timeout`` seconds.
    """

    def handler(signum, frame):
        raise AirflowException(f"Timeout {timeout}s reached")

    previous_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout)
    try:
        return ask_yesno(question, default)
    finally:
        # Cancel the pending alarm first so the restored handler never receives
        # an alarm that was scheduled for this prompt.
        signal.alarm(0)
        # signal.signal() returns None when the old handler was not installed
        # from Python; such a handler cannot be reinstated from here.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

96 

97 

@overload
def is_container(obj: None | int | Iterable[int] | range) -> TypeGuard[Iterable[int]]: ...


@overload
def is_container(obj: None | CT | Iterable[CT]) -> TypeGuard[Iterable[CT]]: ...


def is_container(obj) -> bool:
    """
    Tell whether ``obj`` is an iterable container other than a ``str``.

    A lazy ``Proxy`` always defines ``__iter__`` because it forwards iteration
    to the object it wraps. Proxies are therefore unwrapped first, and the
    wrapped object is what gets inspected.
    """
    candidate = obj.__wrapped__ if isinstance(obj, Proxy) else obj
    iterable = hasattr(candidate, "__iter__")
    return iterable and not isinstance(candidate, str)

114 

115 

def chunks(items: list[T], chunk_size: int) -> Generator[list[T], None, None]:
    """
    Lazily yield consecutive slices of ``items``, each ``chunk_size`` long.

    The last slice carries whatever remains and may be shorter. Because this
    is a generator, the size check only runs once iteration starts.

    :param items: list to split.
    :param chunk_size: number of elements per slice.
    :raises ValueError: if ``chunk_size`` is zero or negative.
    """
    if chunk_size <= 0:
        raise ValueError("Chunk size must be a positive integer")
    starts = range(0, len(items), chunk_size)
    for begin in starts:
        end = begin + chunk_size
        yield items[begin:end]

122 

123 

def as_flattened_list(iterable: Iterable[Iterable[T]]) -> list[T]:
    """
    Flatten exactly one level of nesting into a new list.

    >>> as_flattened_list((("blue", "red"), ("green", "yellow", "pink")))
    ['blue', 'red', 'green', 'yellow', 'pink']
    """
    return list(itertools.chain.from_iterable(iterable))

132 

133 

def parse_template_string(template_string: str) -> tuple[str, None] | tuple[None, jinja2.Template]:
    """
    Classify a template as either a plain format string or a Jinja template.

    A template is treated as Jinja when it contains ``{{``. Otherwise it is
    returned untouched, for ``str.format``-style rendering by the caller.

    :return: ``(None, jinja2.Template)`` for Jinja input, otherwise
        ``(template_string, None)``.
    """
    import jinja2

    is_jinja = "{{" in template_string
    if not is_jinja:
        return template_string, None
    return None, jinja2.Template(template_string)

141 

142 

@cache
def log_filename_template_renderer() -> Callable[..., str]:
    """
    Build, once, a renderer for the ``[logging] log_filename_template`` setting.

    The result is cached, so config changes made after the first call are not
    picked up.

    If the template contains ``{{`` it is treated as Jinja, and the compiled
    template's ``render`` method is returned (call it with keyword context).
    Otherwise the returned callable is ``f(ti, try_number=None)``. It fills the
    template through ``str.format`` with ``dag_id``, ``task_id``,
    ``logical_date`` (ISO format) and ``try_number``, falling back to
    ``ti.try_number`` only when ``try_number`` is ``None``.
    """
    template = conf.get("logging", "log_filename_template")

    if "{{" in template:
        import jinja2

        return jinja2.Template(template).render

    def f_str_format(ti: TaskInstance, try_number: int | None = None):
        # Fall back to the TI's own try number only when none was passed; an
        # explicit 0 is kept as given instead of being treated as "missing".
        if try_number is None:
            try_number = ti.try_number
        return template.format(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            logical_date=ti.logical_date.isoformat(),
            try_number=try_number,
        )

    return f_str_format

161 

162 

def _render_template_to_string(template: jinja2.Template, context: Context) -> str:
    """
    Render ``template`` with ``context`` and return plain text.

    This is a private helper for log filename rendering. It always calls
    ``render_template`` with ``native=False``, so the result is a string
    rather than a native Python object.
    """
    mapping = cast("MutableMapping[str, Any]", context)
    return render_template(template, mapping, native=False)

171 

172 

def render_log_filename(ti: TaskInstance, try_number, filename_template) -> str:
    """
    Render the log filename for ``ti`` at ``try_number``.

    :param ti: task instance whose identifiers fill the template.
    :param try_number: try number substituted into the template.
    :param filename_template: either a Jinja template (anything containing
        ``{{``), rendered with the task instance's template context, or a
        ``str.format`` template with ``dag_id``, ``task_id``,
        ``logical_date`` and ``try_number`` fields.
    :return: the rendered filename.
    """
    plain_template, jinja_template = parse_template_string(filename_template)
    if not jinja_template:
        return plain_template.format(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            logical_date=ti.logical_date.isoformat(),
            try_number=try_number,
        )

    template_context = ti.get_template_context()
    # The explicitly requested try number overrides the one in the context.
    template_context["try_number"] = try_number
    return _render_template_to_string(jinja_template, template_context)

194 

195 

def convert_camel_to_snake(camel_str: str) -> str:
    """
    Turn a ``CamelCase`` identifier into ``snake_case``.

    An underscore is inserted before every run of uppercase letters that is
    not at the start of the string, and then the whole result is lowercased.
    """
    underscored = CAMELCASE_TO_SNAKE_CASE_REGEX.sub(r"_\1", camel_str)
    return underscored.lower()

199 

200 

def merge_dicts(dict1: dict, dict2: dict) -> dict:
    """
    Merge two dicts recursively, returning a new dict (neither input is mutated).

    Values from ``dict2`` win over those in ``dict1``. When both sides hold a
    dict for the same key, the two are merged recursively. In every other case,
    including a dict in ``dict2`` meeting a non-dict (e.g. ``None``) in
    ``dict1``, the value from ``dict2`` replaces the old one. Lists are not
    concatenated.

    The copy is shallow: values that are not merged recursively are shared
    with the inputs rather than copied.
    """
    merged = dict1.copy()
    for key, value in dict2.items():
        existing = merged.get(key)
        # Recurse only when both sides are dicts; anything else is a plain overwrite.
        if isinstance(value, dict) and isinstance(existing, dict):
            merged[key] = merge_dicts(existing, value)
        else:
            merged[key] = value
    return merged

214 

215 

def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
    """
    Split ``iterable`` lazily into ``(false_items, true_items)`` according to ``pred``.

    Both results are iterators fed from ``itertools.tee`` copies of the input,
    so the underlying iterable is only advanced once per element. ``pred`` is
    evaluated separately for each of the two outputs.
    """
    rejected_src, accepted_src = itertools.tee(iterable)
    rejected = itertools.filterfalse(pred, rejected_src)
    accepted = filter(pred, accepted_src)
    return rejected, accepted

220 

221 

def build_airflow_dagrun_url(dag_id: str, run_id: str) -> str:
    """
    Return the UI URL of a DAG run, based on the ``[api] base_url`` setting.

    ``base_url`` falls back to ``"/"`` when it is not configured. Its trailing
    slashes are normalised to exactly one, so the DAG run path is appended
    rather than replacing the last segment. ``dag_id`` and ``run_id`` are
    inserted verbatim (they are not URL-quoted). Example::

        http://localhost:8080/dags/hi/runs/manual__2025-02-23T18:27:39.051358+00:00_RZa1at4Q
    """
    base = conf.get("api", "base_url", fallback="/")
    relative = f"dags/{dag_id}/runs/{run_id}"
    return urljoin(f"{base.rstrip('/')}/", relative)

231 

232 

# The 'template' argument is typed as Any because the jinja2.Template is too
# dynamic to be effectively type-checked.
def render_template(template: Any, context: MutableMapping[str, Any], *, native: bool) -> Any:
    """
    Render a Jinja2 template with given Airflow context.

    The default implementation of ``jinja2.Template.render()`` converts the
    input context into a dict eagerly, many times over, which triggers
    deprecation messages in our custom context class. This function takes that
    implementation apart and keeps the context mapping as-is instead.

    :param template: A Jinja2 template to render.
    :param context: The Airflow task context to render the template with. It is
        shallow-copied, so the caller's mapping is not modified.
    :param native: If set to *True*, render the template into a native type. A
        DAG can enable this with ``render_template_as_native_obj=True``.
    :returns: The render result.
    """
    context = copy.copy(context)
    env = template.environment
    if template.globals:
        # Template globals only fill in missing names; context values win.
        context.update((k, v) for k, v in template.globals.items() if k not in context)

    if native:
        import jinja2.nativetypes

        concat = jinja2.nativetypes.native_concat
    else:
        concat = "".join

    try:
        # root_render_func returns a lazy generator: the template body only runs
        # while ``concat`` consumes it. Consumption must therefore happen inside
        # the ``try`` for render errors to get a template-aware traceback.
        nodes = template.root_render_func(env.context_class(env, context, template.name, template.blocks))
        return concat(nodes)
    except Exception:
        env.handle_exception()  # Rewrite traceback to point to the template.
        raise  # handle_exception() always raises; this keeps the flow explicit.

263 

264 

def exactly_one(*args) -> bool:
    """
    Return True if exactly one of args is "truthy", and False otherwise.

    Calling with no arguments returns False, since then no value is truthy.

    If the first argument is an iterable (other than a string), ValueError is
    raised to force the caller to unpack it with ``*``.
    """
    # Guard on ``args`` so a call with no arguments does not hit IndexError.
    if args and is_container(args[0]):
        raise ValueError(
            "Not supported for iterable args. Use `*` to unpack your iterable in the function call."
        )
    return sum(map(bool, args)) == 1

276 

277 

def at_most_one(*args) -> bool:
    """
    Return True if at most one of args is "truthy", and False otherwise.

    NOTSET is treated the same as None, i.e. it never counts as set.

    Iterables are not rejected: each argument is judged only by its own
    truthiness. To count the elements of a collection, pass them unpacked
    (``*items``).
    """
    truthy_count = sum(1 for arg in args if is_arg_set(arg) and bool(arg))
    return truthy_count <= 1

287 

288 

def prune_dict(val: Any, mode="strict"):
    """
    Return a copy of ``val`` with "empty" elements removed, recursing into nested dicts and lists.

    ``mode`` decides what counts as empty:

    * ``"strict"`` -- only ``None``.
    * ``"truthy"`` -- anything falsy (``None``, ``0``, ``""``, ``[]``, ``{}``, ...).

    Nested dicts/lists are pruned first, and dropped if the pruned result is
    itself empty under the same rule. Values that are neither dicts nor lists
    are returned unchanged. ``mode`` is only checked when an element is
    actually tested, so an invalid mode does not raise for a non-container or
    an empty container.

    :raises ValueError: if ``mode`` is not ``"strict"`` or ``"truthy"``.
    """
    drop = object()  # marker meaning "leave this element out"

    def is_empty(x):
        if mode == "strict":
            return x is None
        if mode == "truthy":
            return not x
        raise ValueError("allowable values for `mode` include 'truthy' and 'strict'")

    def pruned_or_drop(item):
        if is_empty(item):
            return drop
        if not isinstance(item, (list, dict)):
            return item
        nested = prune_dict(item, mode=mode)
        return drop if is_empty(nested) else nested

    if isinstance(val, dict):
        kept = {}
        for key, item in val.items():
            result = pruned_or_drop(item)
            if result is not drop:
                kept[key] = result
        return kept
    if isinstance(val, list):
        return [result for result in map(pruned_or_drop, val) if result is not drop]
    return val

330 

331 

def __getattr__(name: str):
    """
    Provide backward compatibility for functions moved out of this module.

    Accessing one of the names listed below returns the object of the same
    name from its new module and emits a ``DeprecationWarning`` pointing at the
    new location. Any other name raises ``AttributeError``, as usual for a
    module.
    """
    # Old name in this module -> module that now provides an object of the same name.
    moved_to = {
        "render_template_as_native": "airflow.sdk.definitions.context",
        "prevent_duplicates": "airflow.sdk.definitions.mappedoperator",
        "render_template_to_string": "airflow.sdk.definitions.context",
    }
    new_module = moved_to.get(name)
    if new_module is None:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    import importlib
    import warnings

    # Resolve by the requested name, so each alias returns exactly the function
    # it names (e.g. render_template_to_string, not the native renderer).
    target = getattr(importlib.import_module(new_module), name)
    warnings.warn(
        f"airflow.utils.helpers.{name} is deprecated. Use {new_module}.{name} instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return target