Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/utils/helpers.py: 31%

184 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import copy 

21import re 

22import signal 

23import warnings 

24from datetime import datetime 

25from functools import reduce 

26from itertools import filterfalse, tee 

27from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Mapping, MutableMapping, TypeVar, cast 

28 

29from airflow.configuration import conf 

30from airflow.exceptions import AirflowException, RemovedInAirflow3Warning 

31from airflow.utils.context import Context 

32from airflow.utils.module_loading import import_string 

33from airflow.utils.types import NOTSET 

34 

35if TYPE_CHECKING: 

36 import jinja2 

37 

38 from airflow.models.taskinstance import TaskInstance 

39 

# Pattern a key must match in validate_key: one or more word characters, dots or dashes.
KEY_REGEX = re.compile(r"^[\w.-]+$")
# Pattern a group key must match in validate_group_key: like KEY_REGEX, but dots are not allowed.
GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
# Captures each run of uppercase ASCII letters, except that a match never starts at position 0.
CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")

# Generic type variables used in the annotations of the helpers in this module.
T = TypeVar("T")
S = TypeVar("S")

46 

47 

def validate_key(k: str, max_length: int = 250):
    """
    Check that ``k`` is acceptable as a key.

    :param k: candidate key
    :param max_length: largest number of characters the key may have
    :raises TypeError: if ``k`` is not a string
    :raises AirflowException: if ``k`` has more than ``max_length`` characters or
        does not match ``KEY_REGEX``
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")

    key_length = len(k)
    if key_length > max_length:
        raise AirflowException(f"The key has to be less than {max_length} characters")

    if KEY_REGEX.match(k) is None:
        raise AirflowException(
            f"The key {k!r} has to be made of alphanumeric characters, dashes, "
            f"dots and underscores exclusively"
        )

59 

60 

def validate_group_key(k: str, max_length: int = 200):
    """
    Check that ``k`` is acceptable as a group key.

    :param k: candidate group key
    :param max_length: largest number of characters the key may have
    :raises TypeError: if ``k`` is not a string
    :raises AirflowException: if ``k`` has more than ``max_length`` characters or
        does not match ``GROUP_KEY_REGEX``
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")

    key_length = len(k)
    if key_length > max_length:
        raise AirflowException(f"The key has to be less than {max_length} characters")

    if GROUP_KEY_REGEX.match(k) is None:
        raise AirflowException(
            f"The key {k!r} has to be made of alphanumeric characters, dashes and underscores exclusively"
        )

71 

72 

def alchemy_to_dict(obj: Any) -> dict | None:
    """
    Turn a SQLAlchemy model instance into a plain dictionary.

    Every column listed in ``obj.__table__.columns`` becomes a key; ``datetime``
    values are replaced by their ``isoformat()`` string.

    :param obj: model instance, or any falsy value
    :return: mapping of column name to value, or ``None`` when ``obj`` is falsy
    """
    if not obj:
        return None

    def _plain(value):
        # Only datetimes are converted; every other value is passed through.
        return value.isoformat() if isinstance(value, datetime) else value

    return {column.name: _plain(getattr(obj, column.name)) for column in obj.__table__.columns}

84 

85 

def ask_yesno(question: str, default: bool | None = None) -> bool:
    """
    Print ``question`` and keep reading answers from standard input until one is understood.

    :param question: text printed once, before the first answer is read
    :param default: value returned for an empty answer; when ``None``, an empty
        answer is rejected like any other unrecognised input
    :return: ``True`` for "y"/"yes", ``False`` for "n"/"no" (compared case-insensitively)
    """
    recognised = {"yes": True, "y": True, "no": False, "n": False}

    print(question)
    while True:
        reply = input().lower()
        if not reply and default is not None:
            return default
        if reply in recognised:
            return recognised[reply]
        print("Please respond with y/yes or n/no.")

101 

102 

def prompt_with_timeout(question: str, timeout: int, default: bool | None = None) -> bool:
    """
    Ask the user a yes/no question and give up if they do not answer in time.

    The timeout is implemented with ``SIGALRM``. The handler that was installed
    before the call is put back afterwards, so prompting does not permanently
    replace an alarm handler owned by other code.

    :param question: text shown to the user
    :param timeout: number of seconds passed to ``signal.alarm``
    :param default: value returned when the user submits an empty answer
    :return: the user's answer as a boolean
    :raises AirflowException: if the alarm fires before an answer is given
    """

    def handler(signum, frame):
        raise AirflowException(f"Timeout {timeout}s reached")

    previous_handler = signal.signal(signal.SIGALRM, handler)
    try:
        signal.alarm(timeout)
        return ask_yesno(question, default)
    finally:
        # Cancel any pending alarm first so it cannot fire into the restored handler.
        signal.alarm(0)
        # ``signal.signal`` returns None when the previous handler was not installed
        # from Python; such a handler cannot be re-installed, so it is left alone.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

115 

116 

def is_container(obj: Any) -> bool:
    """
    Tell whether ``obj`` is iterable while not being a string.

    :param obj: any object
    :return: ``True`` if ``obj`` has an ``__iter__`` attribute and is not a ``str``
    """
    if isinstance(obj, str):
        return False
    return hasattr(obj, "__iter__")

120 

121 

def as_tuple(obj: Any) -> tuple:
    """
    Coerce ``obj`` into a tuple.

    :param obj: any object
    :return: ``tuple(obj)`` when ``is_container(obj)`` is true, otherwise a
        one-element tuple holding ``obj``
    """
    return tuple(obj) if is_container(obj) else (obj,)

131 

132 

def chunks(items: list[T], chunk_size: int) -> Generator[list[T], None, None]:
    """
    Lazily split ``items`` into consecutive slices of at most ``chunk_size`` elements.

    :param items: sequence to slice
    :param chunk_size: number of elements per slice; the last slice may be shorter
    :raises ValueError: if ``chunk_size`` is zero or negative (raised when the
        generator is first advanced)
    """
    if chunk_size <= 0:
        raise ValueError("Chunk size must be a positive integer")
    starts = range(0, len(items), chunk_size)
    yield from (items[start : start + chunk_size] for start in starts)

139 

140 

def reduce_in_chunks(fn: Callable[[S, list[T]], S], iterable: list[T], initializer: S, chunk_size: int = 0):
    """
    Fold ``iterable`` chunk by chunk.

    The items are split with ``chunks`` and ``fn`` is applied to the running
    result and each chunk in turn.

    :param fn: reducer called as ``fn(accumulated, chunk)``
    :param iterable: items to reduce
    :param initializer: starting value; returned as-is when ``iterable`` is empty
    :param chunk_size: size of each chunk; ``0`` means one chunk holding everything
    :return: the final accumulated value
    """
    if len(iterable) == 0:
        return initializer
    size = len(iterable) if chunk_size == 0 else chunk_size
    accumulated = initializer
    for chunk in chunks(iterable, size):
        accumulated = fn(accumulated, chunk)
    return accumulated

151 

152 

def as_flattened_list(iterable: Iterable[Iterable[T]]) -> list[T]:
    """
    Flatten exactly one level of nesting and return the result as a list.

    >>> as_flattened_list((('blue', 'red'), ('green', 'yellow', 'pink')))
    ['blue', 'red', 'green', 'yellow', 'pink']
    """
    flattened = []
    for inner in iterable:
        flattened.extend(inner)
    return flattened

161 

162 

def parse_template_string(template_string: str) -> tuple[str | None, jinja2.Template | None]:
    """
    Classify ``template_string`` as either a plain string or a Jinja template.

    :param template_string: raw template text
    :return: ``(template_string, None)`` when the text contains no ``{{``,
        otherwise ``(None, jinja2.Template(template_string))``
    """
    import jinja2

    if "{{" not in template_string:
        return template_string, None
    # jinja mode
    return None, jinja2.Template(template_string)

171 

172 

def render_log_filename(ti: TaskInstance, try_number, filename_template) -> str:
    """
    Render the log filename for a task instance.

    The template is first passed to ``parse_template_string``. A Jinja template
    is rendered with the task instance's template context plus ``try_number``;
    anything else is filled in with ``str.format``.

    :param ti: task instance
    :param try_number: try_number of the task
    :param filename_template: filename template, which can be jinja template or
        python string template
    :return: the rendered filename
    """
    plain_template, jinja_template = parse_template_string(filename_template)
    if not jinja_template:
        return plain_template.format(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            execution_date=ti.execution_date.isoformat(),
            try_number=try_number,
        )

    template_context = ti.get_template_context()
    template_context["try_number"] = try_number
    return render_template_to_string(jinja_template, template_context)

195 

196 

def convert_camel_to_snake(camel_str: str) -> str:
    """
    Convert a CamelCase string to snake_case.

    An underscore is inserted before every match of
    ``CAMELCASE_TO_SNAKE_CASE_REGEX`` and the result is lower-cased.

    :param camel_str: string to convert
    :return: the converted string
    """
    with_underscores = CAMELCASE_TO_SNAKE_CASE_REGEX.sub(r"_\1", camel_str)
    return with_underscores.lower()

200 

201 

def merge_dicts(dict1: dict, dict2: dict) -> dict:
    """
    Merge two dicts recursively, returning new dict (input dict is not mutated).

    Lists are not concatenated. Items in dict2 overwrite those also found in dict1.
    Nested merging only happens when both sides hold a dict under the same key;
    in every other case the value from ``dict2`` replaces the one from ``dict1``.

    :param dict1: base mapping
    :param dict2: mapping whose entries take precedence
    :return: a new dict with the merged content
    """
    merged = dict1.copy()
    for key, value in dict2.items():
        existing = merged.get(key)
        # Recurse only when both values are dicts; recursing into a non-dict
        # (e.g. an int or a list) cannot work and must be a plain overwrite.
        if isinstance(existing, dict) and isinstance(value, dict):
            merged[key] = merge_dicts(existing, value)
        else:
            merged[key] = value
    return merged

215 

216 

def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
    """
    Split ``iterable`` into two lazy iterators according to ``pred``.

    :param pred: predicate applied to every entry
    :param iterable: entries to split
    :return: ``(false_entries, true_entries)``, i.e. the entries for which ``pred`` is
        false come first
    """
    for_rejected, for_accepted = tee(iterable)
    rejected = filterfalse(pred, for_rejected)
    accepted = filter(pred, for_accepted)
    return rejected, accepted

221 

222 

def chain(*args, **kwargs):
    """
    Deprecated alias that forwards to ``airflow.models.baseoperator.chain``.

    Emits a ``RemovedInAirflow3Warning`` and then calls the replacement with the
    same arguments, returning its result.
    """
    message = "This function is deprecated. Please use `airflow.models.baseoperator.chain`."
    warnings.warn(message, RemovedInAirflow3Warning, stacklevel=2)
    replacement = import_string("airflow.models.baseoperator.chain")
    return replacement(*args, **kwargs)

231 

232 

def cross_downstream(*args, **kwargs):
    """
    Deprecated alias that forwards to ``airflow.models.baseoperator.cross_downstream``.

    Emits a ``RemovedInAirflow3Warning`` and then calls the replacement with the
    same arguments, returning its result.
    """
    message = "This function is deprecated. Please use `airflow.models.baseoperator.cross_downstream`."
    warnings.warn(message, RemovedInAirflow3Warning, stacklevel=2)
    replacement = import_string("airflow.models.baseoperator.cross_downstream")
    return replacement(*args, **kwargs)

241 

242 

def build_airflow_url_with_query(query: dict[str, Any]) -> str:
    """
    Build the URL of the default DAG view with the provided query.

    The view name is read from the ``[webserver] dag_default_view`` setting,
    lower-cased, and resolved through ``flask.url_for`` with ``query`` passed as
    keyword arguments.

    For example:
    http://0.0.0.0:8000/base/graph?dag_id=my-task&root=&execution_date=2020-10-27T10%3A59%3A25.615587

    :param query: values forwarded to ``flask.url_for``
    :return: the URL produced by ``flask.url_for``
    """
    import flask

    default_view = conf.get_mandatory_value("webserver", "dag_default_view")
    endpoint = f"Airflow.{default_view.lower()}"
    return flask.url_for(endpoint, **query)

254 

255 

# The 'template' argument is typed as Any because the jinja2.Template is too
# dynamic to be effectively type-checked.
def render_template(template: Any, context: MutableMapping[str, Any], *, native: bool) -> Any:
    """Render a Jinja2 template with given Airflow context.

    The default implementation of ``jinja2.Template.render()`` converts the
    input context into dict eagerly many times, which triggers deprecation
    messages in our custom context class. This takes the implementation apart
    and retains the context mapping without resolving it instead.

    The caller's ``context`` is shallow-copied first, so the template globals
    merged in below never leak back into it.

    :param template: A Jinja2 template to render.
    :param context: The Airflow task context to render the template with.
    :param native: If set to *True*, render the template into a native type. A
        DAG can enable this with ``render_template_as_native_obj=True``.
    :returns: The render result.
    """
    render_context = copy.copy(context)
    environment = template.environment
    template_globals = template.globals
    if template_globals:
        # Globals only fill gaps; keys already in the context take precedence.
        missing = {name: value for name, value in template_globals.items() if name not in render_context}
        render_context.update(missing)
    try:
        jinja_context = environment.context_class(environment, render_context, template.name, template.blocks)
        nodes = template.root_render_func(jinja_context)
    except Exception:
        environment.handle_exception()  # Rewrite traceback to point to the template.
    if not native:
        return "".join(nodes)

    import jinja2.nativetypes

    return jinja2.nativetypes.native_concat(nodes)

285 

286 

def render_template_to_string(template: jinja2.Template, context: Context) -> str:
    """
    Render ``template`` to a string.

    Typed convenience wrapper that calls ``render_template`` with ``native=False``.

    :param template: template to render
    :param context: context to render the template with
    :return: the value returned by ``render_template``
    """
    mapping = cast(MutableMapping[str, Any], context)
    return render_template(template, mapping, native=False)

290 

291 

def render_template_as_native(template: jinja2.Template, context: Context) -> Any:
    """
    Render ``template`` into a native Python object.

    Typed convenience wrapper that calls ``render_template`` with ``native=True``.

    :param template: template to render
    :param context: context to render the template with
    :return: the value returned by ``render_template``
    """
    mapping = cast(MutableMapping[str, Any], context)
    return render_template(template, mapping, native=True)

295 

296 

def exactly_one(*args) -> bool:
    """
    Returns True if exactly one of *args is "truthy", and False otherwise.

    Calling the function without any argument returns False, because there is
    no truthy value at all.

    If the first argument is an iterable (other than a string), we raise
    ValueError and force the caller to unpack it.

    :raises ValueError: if ``is_container(args[0])`` is true
    """
    if not args:
        # Guard the ``args[0]`` lookup below: zero arguments means zero truthy values.
        return False
    if is_container(args[0]):
        raise ValueError(
            "Not supported for iterable args. Use `*` to unpack your iterable in the function call."
        )
    return sum(map(bool, args)) == 1

308 

309 

def at_most_one(*args) -> bool:
    """
    Return True when no more than one of *args is "truthy", and False otherwise.

    NOTSET is treated the same as None, i.e. it never counts as truthy.

    Note that, unlike ``exactly_one``, no check for iterable arguments is made
    here: every argument is simply evaluated for truthiness as given.
    """
    truthy_count = 0
    for val in args:
        if val is not NOTSET and val:
            truthy_count += 1
    return truthy_count <= 1

326 

327 

def prune_dict(val: Any, mode="strict"):
    """
    Return a copy of ``val`` (a dict or a list) with all empty elements removed.

    What constitutes "empty" is controlled by the ``mode`` parameter. If mode is 'strict'
    then only ``None`` elements will be removed. If mode is ``truthy``, then element ``x``
    will be removed if ``bool(x) is False``.

    Nested lists and dicts are pruned recursively, and a nested container that ends
    up empty after pruning is dropped as well. Any other kind of ``val`` is
    returned unchanged.

    :raises ValueError: if an element has to be checked and ``mode`` is neither
        'strict' nor 'truthy'
    """

    def is_empty(x):
        if mode == "strict":
            return x is None
        if mode == "truthy":
            return not x
        raise ValueError("allowable values for `mode` include 'truthy' and 'strict'")

    def surviving(pairs):
        # Shared by the dict and list branches: yields the (key, item) pairs to keep.
        for key, item in pairs:
            if is_empty(item):
                continue
            if isinstance(item, (list, dict)):
                item = prune_dict(item, mode=mode)
                if not item:
                    continue
            yield key, item

    if isinstance(val, dict):
        return dict(surviving(val.items()))
    if isinstance(val, list):
        return [item for _, item in surviving(enumerate(val))]
    return val

371 

372 

def prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None:
    """Ensure *kwargs1* and *kwargs2* share no key.

    :param kwargs1: first set of keyword arguments
    :param kwargs2: second set of keyword arguments
    :param fail_reason: text used as the beginning of the error message
    :raises TypeError: If common keys are found.
    """
    common_keys = set(kwargs1) & set(kwargs2)
    if not common_keys:
        return
    if len(common_keys) > 1:
        listed = ", ".join(sorted(common_keys))
        raise TypeError(f"{fail_reason} arguments: {listed}")
    (only_key,) = common_keys
    raise TypeError(f"{fail_reason} argument: {only_key}")

384 raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}")