Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/utils/helpers.py: 30%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

187 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import copy 

21import itertools 

22import re 

23import signal 

24import warnings 

25from datetime import datetime 

26from functools import reduce 

27from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Mapping, MutableMapping, TypeVar, cast 

28 

29from lazy_object_proxy import Proxy 

30 

31from airflow.configuration import conf 

32from airflow.exceptions import AirflowException, RemovedInAirflow3Warning 

33from airflow.utils.module_loading import import_string 

34from airflow.utils.types import NOTSET 

35 

36if TYPE_CHECKING: 

37 import jinja2 

38 

39 from airflow.models.taskinstance import TaskInstance 

40 from airflow.utils.context import Context 

41 

# Keys: word characters (letters, digits, underscore; Unicode-aware for str
# patterns), dots and dashes, with at least one character.
KEY_REGEX = re.compile(r"^[\w.-]+$")
# Group keys are stricter than keys: dots are not allowed.
GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
# Matches every run of capital letters except one at the very start of the
# string, so an underscore can be inserted before it (CamelCase -> Camel_Case).
CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")

# Generic type variables used by the helper signatures in this module.
T = TypeVar("T")
S = TypeVar("S")

48 

49 

def validate_key(k: str, max_length: int = 250):
    """
    Check that ``k`` is acceptable as a key.

    :param k: The candidate key.
    :param max_length: Longest accepted length; a key of exactly this length passes.
    :raises TypeError: If ``k`` is not a ``str``.
    :raises AirflowException: If ``k`` is longer than ``max_length`` or contains
        anything other than word characters, dots and dashes.
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")

    too_long = len(k) > max_length
    if too_long:
        raise AirflowException(f"The key has to be less than {max_length} characters")

    # ``match`` returns either a (truthy) Match object or None.
    if KEY_REGEX.match(k) is None:
        raise AirflowException(
            f"The key {k!r} has to be made of alphanumeric characters, dashes, "
            f"dots and underscores exclusively"
        )

61 

62 

def validate_group_key(k: str, max_length: int = 200):
    """
    Check that ``k`` is acceptable as a group key.

    Group keys follow the same rules as regular keys except that dots are rejected.

    :param k: The candidate group key.
    :param max_length: Longest accepted length; a key of exactly this length passes.
    :raises TypeError: If ``k`` is not a ``str``.
    :raises AirflowException: If ``k`` is longer than ``max_length`` or contains
        anything other than word characters and dashes.
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")

    too_long = len(k) > max_length
    if too_long:
        raise AirflowException(f"The key has to be less than {max_length} characters")

    if GROUP_KEY_REGEX.match(k) is None:
        raise AirflowException(
            f"The key {k!r} has to be made of alphanumeric characters, dashes and underscores exclusively"
        )

73 

74 

def alchemy_to_dict(obj: Any) -> dict | None:
    """
    Build a plain dict from a SQLAlchemy model instance.

    One entry is produced per column of ``obj.__table__``, keyed by the column name.
    ``datetime`` values are converted to ISO-8601 strings with ``isoformat()``; all
    other values are copied as-is.

    :param obj: The model instance; a falsy value yields ``None``.
    :return: The column-name -> value mapping, or ``None`` for a falsy ``obj``.
    """
    if not obj:
        return None

    def _serialize(value: Any) -> Any:
        # Datetimes become strings; everything else passes through unchanged.
        return value.isoformat() if isinstance(value, datetime) else value

    return {column.name: _serialize(getattr(obj, column.name)) for column in obj.__table__.columns}

86 

87 

def ask_yesno(question: str, default: bool | None = None) -> bool:
    """
    Get a yes or no answer from the user.

    The question is printed once, then answers are read from standard input until a
    recognized one is given. Matching is case-insensitive and ignores surrounding
    whitespace.

    :param question: Text printed before reading the first answer.
    :param default: Value returned for an empty (or whitespace-only) answer. When
        ``None``, an empty answer is rejected and the user is asked again.
    :return: ``True`` for ``y``/``yes``, ``False`` for ``n``/``no``.
    """
    yes = {"yes", "y"}
    no = {"no", "n"}

    print(question)
    while True:
        # Strip so stray whitespace (e.g. "y " or a trailing carriage return) is still
        # recognized instead of sending the user round the loop again.
        choice = input().strip().lower()
        if choice == "" and default is not None:
            return default
        if choice in yes:
            return True
        if choice in no:
            return False
        print("Please respond with y/yes or n/no.")

103 

104 

def prompt_with_timeout(question: str, timeout: int, default: bool | None = None) -> bool:
    """
    Ask the user a yes/no question and time out if they don't respond.

    The timeout is implemented with ``SIGALRM``, so this requires a platform that
    provides that signal. ``signal.signal`` may only be called from the main thread.

    :param question: Text shown to the user (forwarded to ``ask_yesno``).
    :param timeout: Seconds to wait for an answer.
    :param default: Answer used for an empty reply (forwarded to ``ask_yesno``).
    :return: The user's answer as a bool.
    :raises AirflowException: If no answer is given within ``timeout`` seconds.
    """

    def handler(signum, frame):
        raise AirflowException(f"Timeout {timeout}s reached")

    # Remember the handler we are replacing so it can be put back afterwards;
    # without this, any SIGALRM handler installed by the caller would be lost.
    previous_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout)
    try:
        return ask_yesno(question, default)
    finally:
        # Cancel the pending alarm before restoring, so it cannot fire in between.
        signal.alarm(0)
        # ``signal.signal`` returns None when the previous handler was not installed
        # from Python; it cannot be reinstalled then, so ours is left in place.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

117 

118 

def is_container(obj: Any) -> bool:
    """
    Tell whether ``obj`` is iterable but not a string.

    A ``lazy_object_proxy.Proxy`` is judged by the object it wraps rather than by
    the proxy itself.
    """
    # A Proxy defines __iter__ (to forward iteration to the lazily created object),
    # so every Proxy would look like a container; inspect the wrapped object instead.
    target = obj.__wrapped__ if isinstance(obj, Proxy) else obj
    has_iter = hasattr(target, "__iter__")
    return has_iter and not isinstance(target, str)

127 

128 

def as_tuple(obj: Any) -> tuple:
    """
    Coerce ``obj`` to a tuple.

    Containers (per ``is_container``) are converted element-wise; any other value,
    including a string, becomes a one-element tuple.
    """
    return tuple(obj) if is_container(obj) else (obj,)

135 

136 

def chunks(items: list[T], chunk_size: int) -> Generator[list[T], None, None]:
    """
    Lazily yield consecutive slices of ``items`` of length ``chunk_size``.

    The final slice may be shorter. Because this is a generator, the size check only
    runs when iteration starts.

    :raises ValueError: If ``chunk_size`` is zero or negative.
    """
    if chunk_size <= 0:
        raise ValueError("Chunk size must be a positive integer")
    offsets = range(0, len(items), chunk_size)
    for offset in offsets:
        yield items[offset : offset + chunk_size]

143 

144 

def reduce_in_chunks(fn: Callable[[S, list[T]], S], iterable: list[T], initializer: S, chunk_size: int = 0):
    """
    Fold ``fn`` over ``iterable`` one chunk at a time.

    :param fn: Reducer called as ``fn(accumulator, chunk)``.
    :param iterable: The list to split into chunks.
    :param initializer: Starting accumulator; returned unchanged for an empty list.
    :param chunk_size: Size of each chunk; ``0`` means a single chunk of everything.
    :return: The final accumulator.
    """
    if not iterable:
        return initializer
    size = len(iterable) if chunk_size == 0 else chunk_size
    return reduce(fn, chunks(iterable, size), initializer)

152 

153 

def as_flattened_list(iterable: Iterable[Iterable[T]]) -> list[T]:
    """
    Return a list with one level of nesting flattened.

    >>> as_flattened_list((("blue", "red"), ("green", "yellow", "pink")))
    ['blue', 'red', 'green', 'yellow', 'pink']
    """
    return list(itertools.chain.from_iterable(iterable))

162 

163 

def parse_template_string(template_string: str) -> tuple[str | None, jinja2.Template | None]:
    """
    Classify ``template_string`` as a plain string or a Jinja template.

    A string containing ``{{`` is treated as Jinja and compiled.

    :return: ``(template_string, None)`` for a plain string, or
        ``(None, jinja2.Template(template_string))`` for a Jinja template.
    """
    import jinja2

    is_jinja = "{{" in template_string
    if not is_jinja:
        return template_string, None
    return None, jinja2.Template(template_string)

172 

173 

def render_log_filename(ti: TaskInstance, try_number, filename_template) -> str:
    """
    Given task instance, try_number, filename_template, return the rendered log filename.

    A Jinja template is rendered with the task instance's template context plus
    ``try_number``. A plain template is filled with ``str.format`` using ``dag_id``,
    ``task_id``, ``execution_date`` (ISO-8601) and ``try_number``.

    :param ti: task instance
    :param try_number: try_number of the task
    :param filename_template: filename template, which can be jinja template or
        python string template
    """
    plain_template, jinja_template = parse_template_string(filename_template)

    if not jinja_template:
        return plain_template.format(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            execution_date=ti.execution_date.isoformat(),
            try_number=try_number,
        )

    context = ti.get_template_context()
    context["try_number"] = try_number
    return render_template_to_string(jinja_template, context)

195 

196 

def convert_camel_to_snake(camel_str: str) -> str:
    """
    Convert CamelCase to snake_case, e.g. ``"CamelCase"`` -> ``"camel_case"``.

    An underscore is inserted before each run of capitals (except at the very
    start), then the whole string is lower-cased.
    """
    underscored = CAMELCASE_TO_SNAKE_CASE_REGEX.sub(r"_\1", camel_str)
    return underscored.lower()

200 

201 

def merge_dicts(dict1: dict, dict2: dict) -> dict:
    """
    Merge two dicts recursively, returning new dict (input dict is not mutated).

    Lists are not concatenated. Items in dict2 overwrite those also found in dict1.
    A nested merge only happens when the value is a dict on *both* sides. If dict1
    holds a non-dict value (e.g. ``None`` or a scalar) under a key where dict2 has a
    dict, the dict from dict2 simply replaces it.

    Each level is copied shallowly: values that are not merged are shared with the
    inputs rather than deep-copied.

    :param dict1: Base mapping.
    :param dict2: Mapping whose entries take precedence.
    :return: A new dict with the merged entries.
    """
    merged = dict1.copy()
    for key, value in dict2.items():
        existing = merged.get(key)
        # Recurse only when both sides are dicts; recursing with a non-dict from
        # dict1 would fail on ``.copy()`` / ``.items()``.
        if isinstance(value, dict) and isinstance(existing, dict):
            merged[key] = merge_dicts(existing, value)
        else:
            merged[key] = value
    return merged

215 

216 

def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
    """
    Split ``iterable`` by ``pred`` into ``(false_entries, true_entries)``.

    Both results are lazy iterators over independent copies of the input stream.
    """
    for_false, for_true = itertools.tee(iterable)
    false_entries = itertools.filterfalse(pred, for_false)
    true_entries = filter(pred, for_true)
    return false_entries, true_entries

221 

222 

def chain(*args, **kwargs):
    """
    Deprecated alias: use `airflow.models.baseoperator.chain` instead.

    Emits a ``RemovedInAirflow3Warning`` attributed to the caller, then forwards all
    arguments to the real implementation and returns its result.
    """
    warnings.warn(
        "This function is deprecated. Please use `airflow.models.baseoperator.chain`.",
        RemovedInAirflow3Warning,
        stacklevel=2,
    )
    target = import_string("airflow.models.baseoperator.chain")
    return target(*args, **kwargs)

231 

232 

def cross_downstream(*args, **kwargs):
    """
    Deprecated alias: use `airflow.models.baseoperator.cross_downstream` instead.

    Emits a ``RemovedInAirflow3Warning`` attributed to the caller, then forwards all
    arguments to the real implementation and returns its result.
    """
    warnings.warn(
        "This function is deprecated. Please use `airflow.models.baseoperator.cross_downstream`.",
        RemovedInAirflow3Warning,
        stacklevel=2,
    )
    target = import_string("airflow.models.baseoperator.cross_downstream")
    return target(*args, **kwargs)

241 

242 

def build_airflow_url_with_query(query: dict[str, Any]) -> str:
    """
    Build an Airflow webserver URL for the configured default DAG view.

    The endpoint is ``Airflow.<view>``, where ``<view>`` is the lower-cased
    ``[webserver] dag_default_view`` setting; ``query`` is passed to
    ``flask.url_for`` as keyword arguments.

    For example:
    http://0.0.0.0:8000/base/graph?dag_id=my-task&root=&execution_date=2020-10-27T10%3A59%3A25.615587
    """
    import flask

    default_view = conf.get_mandatory_value("webserver", "dag_default_view")
    endpoint = f"Airflow.{default_view.lower()}"
    return flask.url_for(endpoint, **query)

254 

255 

# The 'template' argument is typed as Any because the jinja2.Template is too
# dynamic to be effectively type-checked.
def render_template(template: Any, context: MutableMapping[str, Any], *, native: bool) -> Any:
    """Render a Jinja2 template with given Airflow context.

    ``jinja2.Template.render()`` eagerly converts its input context into a dict
    several times, which triggers deprecation messages in our custom context class.
    This reimplements the render steps while keeping the context as a mapping.

    :param template: A Jinja2 template to render.
    :param context: The Airflow task context to render the template with. It is
        shallow-copied first, so the caller's mapping is not modified.
    :param native: If set to *True*, render the template into a native type. A
        DAG can enable this with ``render_template_as_native_obj=True``.
    :returns: The render result.
    """
    render_ctx = copy.copy(context)
    environment = template.environment
    if template.globals:
        # Template globals only fill gaps: keys already in the context win.
        render_ctx.update((name, value) for name, value in template.globals.items() if name not in render_ctx)
    try:
        nodes = template.root_render_func(
            environment.context_class(environment, render_ctx, template.name, template.blocks)
        )
    except Exception:
        # Rewrite the traceback so it points at the template source.
        environment.handle_exception()
    if not native:
        return "".join(nodes)

    import jinja2.nativetypes

    return jinja2.nativetypes.native_concat(nodes)

285 

286 

def render_template_to_string(template: jinja2.Template, context: Context) -> str:
    """
    Render ``template`` with ``context`` into a plain string.

    Thin wrapper over ``render_template(native=False)`` whose signature is typed with
    the Airflow ``Context`` for better typing support.
    """
    as_mapping = cast(MutableMapping[str, Any], context)
    return render_template(template, as_mapping, native=False)

290 

291 

def render_template_as_native(template: jinja2.Template, context: Context) -> Any:
    """
    Render ``template`` with ``context`` into a native Python object.

    Thin wrapper over ``render_template(native=True)`` whose signature is typed with
    the Airflow ``Context`` for better typing support.
    """
    as_mapping = cast(MutableMapping[str, Any], context)
    return render_template(template, as_mapping, native=True)

295 

296 

def exactly_one(*args) -> bool:
    """
    Return True if exactly one of *args is "truthy", and False otherwise.

    Calling with no arguments returns False.

    If user supplies an iterable, we raise ValueError and force them to unpack.
    Only the first argument is checked for being an iterable.

    :raises ValueError: If the first argument is a (non-string) container.
    """
    # Guard on ``args`` so a call with no arguments returns False instead of
    # raising IndexError on ``args[0]``.
    if args and is_container(args[0]):
        raise ValueError(
            "Not supported for iterable args. Use `*` to unpack your iterable in the function call."
        )
    return sum(map(bool, args)) == 1

308 

309 

def at_most_one(*args) -> bool:
    """
    Return True if at most one of *args is "truthy", and False otherwise.

    NOTSET is treated the same as None, i.e. it never counts as set.

    No iterable check is performed: a container argument is not rejected and is
    simply counted by its own truthiness (an empty list counts as unset, a
    non-empty one as set). Calling with no arguments returns True.
    """

    def is_set(val) -> bool:
        # NOTSET is excluded explicitly; everything else is judged by truthiness.
        if val is NOTSET:
            return False
        return bool(val)

    # Each is_set result is 0 or 1 when summed, so the sum is the number of set args.
    return sum(map(is_set, args)) <= 1

326 

327 

def prune_dict(val: Any, mode="strict"):
    """
    Given dict ``val``, returns new dict based on ``val`` with all empty elements removed.

    What constitutes "empty" is controlled by the ``mode`` parameter. If mode is 'strict'
    then only ``None`` elements will be removed. If mode is ``truthy``, then element ``x``
    will be removed if ``bool(x) is False``.

    Lists are handled like dicts. Nested dicts and lists are pruned recursively, and a
    nested container is dropped if it is itself "empty" after pruning. In ``strict``
    mode a pruned container is never ``None``, so it is kept even if it ends up empty.
    Any other value (including tuples) is returned unchanged.

    :param val: The dict, list, or other value to prune.
    :param mode: Either ``"strict"`` or ``"truthy"``.
    :raises ValueError: If ``mode`` is not one of the allowed values. This is checked
        up front, regardless of the contents of ``val``.
    """
    if mode not in ("strict", "truthy"):
        raise ValueError("allowable values for `mode` include 'truthy' and 'strict'")

    def is_empty(x: Any) -> bool:
        if mode == "strict":
            return x is None
        return not x

    def pruned(v: Any) -> tuple[bool, Any]:
        """Return ``(keep, new_value)`` for one element of a dict or list."""
        if is_empty(v):
            return False, None
        if isinstance(v, (list, dict)):
            new_val = prune_dict(v, mode=mode)
            return not is_empty(new_val), new_val
        return True, v

    if isinstance(val, dict):
        new_dict = {}
        for k, v in val.items():
            keep, new_v = pruned(v)
            if keep:
                new_dict[k] = new_v
        return new_dict
    if isinstance(val, list):
        new_list = []
        for v in val:
            keep, new_v = pruned(v)
            if keep:
                new_list.append(new_v)
        return new_list
    return val

370 

371 

def prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None:
    """Ensure *kwargs1* and *kwargs2* do not contain common keys.

    The error message uses ``fail_reason`` as a prefix and lists the clashing keys,
    sorted and comma-separated, when there is more than one.

    :raises TypeError: If common keys are found.
    """
    common = {key for key in kwargs2 if key in kwargs1}
    if not common:
        return
    if len(common) == 1:
        (only_key,) = common
        raise TypeError(f"{fail_reason} argument: {only_key}")
    listing = ", ".join(sorted(common))
    raise TypeError(f"{fail_reason} arguments: {listing}")