Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/utils/retries.py: 31%
49 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
19import functools
20import logging
21from inspect import signature
22from typing import Callable, TypeVar, overload
24from sqlalchemy.exc import DBAPIError, OperationalError
26from airflow.configuration import conf
# Type variable bound to any callable; used by the ``retry_db_transaction`` overloads so a
# decorated function keeps its own type for static type checkers.
F = TypeVar("F", bound=Callable)

# Default maximum number of attempts for DB retries, read from ``[database] max_db_retries``
# (falls back to 3 when the option is not configured).
MAX_DB_RETRIES = conf.getint("database", "max_db_retries", fallback=3)
def run_with_db_retries(max_retries: int = MAX_DB_RETRIES, logger: logging.Logger | None = None, **kwargs):
    """Return a Tenacity ``Retrying`` object configured with project-specific defaults.

    By default the returned object:

    * retries when ``OperationalError`` or ``DBAPIError`` is raised,
    * waits a random exponential back-off (multiplier 0.5, capped at 5 seconds) between attempts,
    * stops after ``max_retries`` attempts,
    * re-raises the last exception once attempts are exhausted (``reraise=True``).

    :param max_retries: Maximum number of attempts before giving up. Ignored if ``stop``
        is supplied in ``kwargs``.
    :param logger: If a :class:`logging.Logger` is given, a DEBUG record (with exception
        info) is emitted via ``tenacity.before_sleep_log`` before each sleep between attempts.
    :param kwargs: Extra keyword arguments forwarded to ``tenacity.Retrying``. They take
        precedence over the defaults listed above (``retry``, ``wait``, ``stop``, ``reraise``).
    :return: A configured ``tenacity.Retrying`` instance.
    """
    import tenacity

    # Project defaults. Caller kwargs are merged last so they override these defaults,
    # instead of raising "got multiple values for keyword argument" as dict(..., **kwargs) would.
    retry_kwargs = {
        "retry": tenacity.retry_if_exception_type(exception_types=(OperationalError, DBAPIError)),
        "wait": tenacity.wait_random_exponential(multiplier=0.5, max=5),
        "stop": tenacity.stop_after_attempt(max_retries),
        "reraise": True,
        **kwargs,
    }
    if isinstance(logger, logging.Logger):
        # Third positional argument is exc_info=True: include the exception in the log record.
        retry_kwargs["before_sleep"] = tenacity.before_sleep_log(logger, logging.DEBUG, True)

    return tenacity.Retrying(**retry_kwargs)
# Typing-only signature: called with keyword arguments only, e.g. ``@retry_db_transaction(retries=5)``,
# producing a decorator that preserves the decorated function's type.
@overload
def retry_db_transaction(*, retries: int = MAX_DB_RETRIES) -> Callable[[F], F]: ...


# Typing-only signature: applied directly as ``@retry_db_transaction`` to a function.
@overload
def retry_db_transaction(_func: F) -> F: ...
def retry_db_transaction(_func: Callable | None = None, *, retries: int = MAX_DB_RETRIES, **retry_kwargs):
    """Decorator to retry functions in case of ``OperationalError`` / ``DBAPIError`` from DB.

    Usable both bare (``@retry_db_transaction``) and with arguments
    (``@retry_db_transaction(retries=5)``). The decorated function must declare a
    ``session`` parameter. On each database error the session is rolled back before the
    exception propagates to the retry loop, so the next attempt starts from a clean
    transaction.

    It should not be used with ``@provide_session``.

    :param _func: The function being decorated, when used without parentheses.
    :param retries: Maximum number of attempts.
    :param retry_kwargs: Extra keyword arguments forwarded to ``run_with_db_retries``.
    :raises ValueError: At decoration time, if the function has no ``session`` parameter.
    """

    def retry_decorator(func: Callable) -> Callable:
        # Resolve the positional index of 'session' once, at decoration time.
        func_params = signature(func).parameters
        try:
            # func_params is an ordered dict -- this is the "recommended" way of getting the position
            session_args_idx = tuple(func_params).index("session")
        except ValueError:
            raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None
        # We don't need this anymore -- ensure we don't keep a reference to it by mistake
        del func_params

        @functools.wraps(func)
        def wrapped_function(*args, **kwargs):
            # Use the instance's ``log`` attribute when decorating a method of an object that has one.
            logger = args[0].log if args and hasattr(args[0], "log") else logging.getLogger(func.__module__)

            # Get session from args or kwargs
            if "session" in kwargs:
                session = kwargs["session"]
            elif len(args) > session_args_idx:
                session = args[session_args_idx]
            else:
                raise TypeError(f"session is a required argument for {func.__qualname__}")

            for attempt in run_with_db_retries(max_retries=retries, logger=logger, **retry_kwargs):
                with attempt:
                    logger.debug(
                        "Running %s with retries. Try %d of %d",
                        func.__qualname__,
                        attempt.retry_state.attempt_number,
                        retries,
                    )
                    try:
                        return func(*args, **kwargs)
                    except DBAPIError:
                        # The default retry predicate retries on any DBAPIError (OperationalError
                        # is a subclass), and a failed statement leaves the transaction unusable.
                        # Roll back for every such error so the next attempt does not hit the
                        # same broken transaction.
                        session.rollback()
                        raise

        return wrapped_function

    # Allow using decorator with and without arguments
    if _func is None:
        return retry_decorator
    return retry_decorator(_func)