Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/utils/session.py: 56%

39 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

17from __future__ import annotations 

18 

19import contextlib 

20from functools import wraps 

21from inspect import signature 

22from typing import Callable, Generator, TypeVar, cast 

23 

24from airflow import settings 

25from airflow.typing_compat import ParamSpec 

26 

27 

@contextlib.contextmanager
def create_session() -> Generator[settings.SASession, None, None]:
    """
    Contextmanager that will create and teardown a session.

    A new session is built by calling ``settings.Session``. When the ``with``
    body finishes normally the session is committed; when it raises an
    ``Exception`` the session is rolled back and the exception is re-raised.
    The session is closed in every case.

    :raises RuntimeError: if ``settings.Session`` is missing or ``None``.
    """
    # Looked up at call time (not import time) so a missing or unset factory
    # is reported with a clear error instead of an AttributeError/TypeError.
    Session = getattr(settings, "Session", None)
    if Session is None:
        raise RuntimeError("Session must be set before!")
    session = Session()
    try:
        yield session
        # Only reached when the caller's ``with`` body completed without raising.
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        # Runs on success, on failure, and for non-``Exception`` exits alike.
        session.close()

43 

44 

# Type variables used to annotate the helpers in this module: ``PS`` stands for
# the parameters of a callable and ``RT`` for its return type.
PS = ParamSpec("PS")
RT = TypeVar("RT")

47 

48 

def find_session_idx(func: Callable[PS, RT]) -> int:
    """
    Find session index in function call parameter.

    :param func: callable whose signature is inspected.
    :return: zero-based position of the parameter named ``session`` among all
        parameters of ``func``, in declaration order.
    :raises ValueError: if ``func`` has no parameter named ``session``.
    """
    func_params = signature(func).parameters
    try:
        # func_params is an ordered dict -- this is the "recommended" way of getting the position
        session_args_idx = tuple(func_params).index("session")
    except ValueError:
        # ``from None`` hides the uninformative "tuple.index(x): x not in tuple" error.
        raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None

    return session_args_idx

59 

60 

def provide_session(func: Callable[PS, RT]) -> Callable[PS, RT]:
    """
    Function decorator that provides a session if it isn't provided.

    If you want to reuse a session or run the function as part of a
    database transaction, you pass it to the function, if not this wrapper
    will create one and close it for you.

    :param func: callable to wrap; it must declare a ``session`` parameter.
    :return: wrapper with the same call signature as ``func``.
    :raises ValueError: at decoration time, if ``func`` has no ``session``
        parameter.
    """
    # Resolved once at decoration time, so a missing ``session`` parameter
    # fails early and the wrapper does no signature inspection per call.
    session_args_idx = find_session_idx(func)

    @wraps(func)
    def wrapper(*args, **kwargs) -> RT:
        # The caller supplied a session, either by keyword or by passing enough
        # positional arguments to reach the ``session`` slot: use it untouched.
        if "session" in kwargs or session_args_idx < len(args):
            return func(*args, **kwargs)

        # Otherwise run the call inside a session managed by create_session().
        with create_session() as session:
            return func(*args, session=session, **kwargs)

    return wrapper

79 

80 

# A fake session to use in functions decorated by provide_session. This allows
# the 'session' argument to be of type Session instead of Session | None,
# making it easier to type hint the function body without dealing with the None
# case that can never happen at runtime.
NEW_SESSION: settings.SASession = cast(settings.SASession, None)