Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/utils/retries.py: 53%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
17from __future__ import annotations
19import functools
20import logging
21from inspect import signature
22from typing import Callable, TypeVar, overload
24from sqlalchemy.exc import DBAPIError, OperationalError
26from airflow.configuration import conf
# Type variable for the ``retry_db_transaction`` overloads: decorating a callable of
# type ``F`` yields a callable of the same type ``F`` for static type checkers.
F = TypeVar("F", bound=Callable)

# Default number of attempts used by the DB retry helpers below, read from the
# ``[database] max_db_retries`` option (``fallback=3`` when it is not set).
MAX_DB_RETRIES = conf.getint(
    "database",
    "max_db_retries",
    fallback=3,
)
def run_with_db_retries(max_retries: int = MAX_DB_RETRIES, logger: logging.Logger | None = None, **kwargs):
    """
    Return a Tenacity ``Retrying`` object with project-specific defaults.

    Defaults:

    * retry on ``OperationalError`` and ``DBAPIError``;
    * wait with randomized exponential backoff (``multiplier=0.5``, capped at ``max=5``);
    * stop after ``max_retries`` attempts;
    * re-raise the last exception once retries are exhausted (``reraise=True``).

    :param max_retries: number of attempts, passed to ``tenacity.stop_after_attempt``.
    :param logger: when a :class:`logging.Logger` is given, each retry is logged at
        DEBUG level (with exception info) before sleeping.
    :param kwargs: extra keyword arguments for ``tenacity.Retrying``. They take
        precedence over the defaults above, e.g. a custom ``wait`` or ``retry``.
    :return: a configured ``tenacity.Retrying`` instance.
    """
    import tenacity

    # Project defaults. Caller-supplied kwargs are merged last so they override these.
    # Previously ``dict(..., **kwargs)`` raised "got multiple values for keyword argument"
    # when a caller tried to override a default.
    retry_kwargs = {
        "retry": tenacity.retry_if_exception_type(exception_types=(OperationalError, DBAPIError)),
        "wait": tenacity.wait_random_exponential(multiplier=0.5, max=5),
        "stop": tenacity.stop_after_attempt(max_retries),
        "reraise": True,
        **kwargs,
    }
    # isinstance() already rejects None, so no separate truthiness check is needed.
    if isinstance(logger, logging.Logger):
        retry_kwargs["before_sleep"] = tenacity.before_sleep_log(logger, logging.DEBUG, True)

    return tenacity.Retrying(**retry_kwargs)
# Decorator-factory form: ``@retry_db_transaction(retries=...)`` returns a decorator that
# preserves the decorated callable's type. Extra keyword arguments are accepted as well,
# because the implementation forwards them to ``run_with_db_retries``.
@overload
def retry_db_transaction(*, retries: int = MAX_DB_RETRIES, **retry_kwargs: object) -> Callable[[F], F]: ...


# Bare decorator form: ``@retry_db_transaction`` applied directly to a function.
@overload
def retry_db_transaction(_func: F) -> F: ...
def retry_db_transaction(_func: Callable | None = None, *, retries: int = MAX_DB_RETRIES, **retry_kwargs):
    """
    Retry functions in case of ``OperationalError`` (or another ``DBAPIError``) from DB.

    Can be used bare (``@retry_db_transaction``) or with arguments
    (``@retry_db_transaction(retries=5)``). The decorated function must take a ``session``
    argument, passed either positionally or by keyword. When an attempt fails with a
    ``DBAPIError``, the session is rolled back before the error propagates to the retry
    loop, so the next attempt starts from a clean session.

    It should not be used with ``@provide_session``.

    :param _func: the function to decorate when used without arguments.
    :param retries: number of attempts, forwarded to ``run_with_db_retries``.
    :param retry_kwargs: extra keyword arguments forwarded to ``run_with_db_retries``.
    :raises ValueError: at decoration time, if the function has no ``session`` parameter.
    """

    def retry_decorator(func: Callable) -> Callable:
        # Position of ``session`` among the parameters, so it can also be found when it is
        # passed positionally. ``signature().parameters`` preserves declaration order.
        # The signature is not stored, so the closure keeps no reference to it.
        try:
            session_args_idx = tuple(signature(func).parameters).index("session")
        except ValueError:
            # ``from None``: the tuple.index() failure is an implementation detail.
            raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None

        @functools.wraps(func)
        def wrapped_function(*args, **kwargs):
            # Use the first argument's ``log`` attribute when present (e.g. a method's
            # instance); otherwise fall back to the decorated function's module logger.
            logger = args[0].log if args and hasattr(args[0], "log") else logging.getLogger(func.__module__)

            # Get session from args or kwargs
            if "session" in kwargs:
                session = kwargs["session"]
            elif len(args) > session_args_idx:
                session = args[session_args_idx]
            else:
                raise TypeError(f"session is a required argument for {func.__qualname__}")

            for attempt in run_with_db_retries(max_retries=retries, logger=logger, **retry_kwargs):
                with attempt:
                    logger.debug(
                        "Running %s with retries. Try %d of %d",
                        func.__qualname__,
                        attempt.retry_state.attempt_number,
                        retries,
                    )
                    try:
                        return func(*args, **kwargs)
                    except DBAPIError:
                        # The default retry policy retries on every DBAPIError, and
                        # OperationalError is a DBAPIError subclass. Roll back for all of
                        # them, not only OperationalError. Otherwise a retried attempt can
                        # run on a session/transaction that still needs a rollback.
                        session.rollback()
                        raise

        return wrapped_function

    # Allow using decorator with and without arguments
    if _func is None:
        return retry_decorator
    return retry_decorator(_func)