Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/utils/session.py: 56%
39 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
19import contextlib
20from functools import wraps
21from inspect import signature
22from typing import Callable, Generator, TypeVar, cast
24from airflow import settings
25from airflow.typing_compat import ParamSpec
@contextlib.contextmanager
def create_session() -> Generator[settings.SASession, None, None]:
    """
    Yield a session built from ``settings.Session`` and tear it down afterwards.

    When the ``with`` body finishes without error the session is committed.
    If the body (or the commit) raises an ``Exception``, the session is rolled
    back and the exception is re-raised. The session is closed on every exit path.

    :raises RuntimeError: if ``settings.Session`` is unset or ``None``.
    """
    session_factory = getattr(settings, "Session", None)
    if session_factory is None:
        raise RuntimeError("Session must be set before!")
    # ``closing`` guarantees ``close()`` runs after the inner handler, whether
    # the block succeeded, failed, or the generator was closed early.
    with contextlib.closing(session_factory()) as active_session:
        try:
            yield active_session
            active_session.commit()
        except Exception:
            active_session.rollback()
            raise
# Type variables used in the decorator signatures below (``Callable[PS, RT]``):
# ``PS`` stands for the decorated function's parameters, ``RT`` for its return type.
PS = ParamSpec("PS")
RT = TypeVar("RT")
def find_session_idx(func: Callable[PS, RT]) -> int:
    """
    Return the zero-based position of the ``session`` parameter of *func*.

    :param func: the callable whose signature is inspected.
    :return: the index of ``session`` among *func*'s parameters, in declaration order.
    :raises ValueError: if *func* declares no parameter named ``session``.
    """
    # ``Signature.parameters`` is ordered by declaration, so enumerating its
    # keys yields each parameter's position.
    for position, param_name in enumerate(signature(func).parameters):
        if param_name == "session":
            return position
    raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None
def provide_session(func: Callable[PS, RT]) -> Callable[PS, RT]:
    """
    Function decorator that provides a session if it isn't provided.

    If you want to reuse a session or run the function as part of a
    database transaction, you pass it to the function, if not this wrapper
    will create one and close it for you.

    A session counts as provided when it is passed as the ``session`` keyword
    argument, or when enough positional arguments are passed to reach the
    ``session`` parameter's position. The positional check only applies when
    ``session`` can actually be filled positionally. A keyword-only
    ``session`` (for example one declared after ``*args``) is treated as
    missing unless it is passed by keyword.

    :param func: the function to wrap; it must declare a ``session`` parameter.
    :return: a wrapper with the same call signature as *func*.
    :raises ValueError: if *func* has no ``session`` parameter.
    """
    from inspect import Parameter

    session_args_idx = find_session_idx(func)
    session_kind = signature(func).parameters["session"].kind
    # Only these kinds can receive a positional argument. For a keyword-only
    # ``session`` declared after ``*args``, ``len(args)`` counts arguments
    # swallowed by ``*args``. Comparing it with the index would wrongly
    # conclude a session was passed and leave ``session`` as ``None``.
    session_is_positional = session_kind in (
        Parameter.POSITIONAL_ONLY,
        Parameter.POSITIONAL_OR_KEYWORD,
        Parameter.VAR_POSITIONAL,
    )

    @wraps(func)
    def wrapper(*args, **kwargs) -> RT:
        session_passed = "session" in kwargs or (
            session_is_positional and session_args_idx < len(args)
        )
        if session_passed:
            return func(*args, **kwargs)
        with create_session() as session:
            return func(*args, session=session, **kwargs)

    return wrapper
# A fake session to use in functions decorated by provide_session, presumably
# as the ``session`` parameter's default (``session: Session = NEW_SESSION``).
# This allows the 'session' argument to be annotated as Session instead of
# Session | None. The function body can then be type hinted without handling
# the None case, which is not expected at runtime because the decorator
# supplies a real session when none is passed.
# NOTE: at runtime this value is simply ``None``; ``cast`` only changes what
# static type checkers see.
NEW_SESSION: settings.SASession = cast(settings.SASession, None)