Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/utils/session.py: 58%

38 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19import contextlib 

20from functools import wraps 

21from inspect import signature 

22from typing import Callable, Generator, TypeVar, cast 

23 

24from airflow import settings 

25from airflow.typing_compat import ParamSpec 

26 

27 

@contextlib.contextmanager
def create_session() -> Generator[settings.SASession, None, None]:
    """
    Open a session, hand it to the ``with`` body, and tear it down afterwards.

    The session is committed when the body finishes normally. If the body
    raises an ``Exception`` the session is rolled back and the exception is
    re-raised. The session is closed in every case.

    :raises RuntimeError: if ``settings.Session`` is falsy, i.e. no session
        factory has been configured yet.
    """
    session_factory = settings.Session
    if not session_factory:
        raise RuntimeError("Session must be set before!")

    db_session = session_factory()
    try:
        yield db_session
        # Only reached when the caller's block completed without raising.
        db_session.commit()
    except Exception:
        db_session.rollback()
        raise
    finally:
        db_session.close()

42 

43 

# Generic placeholders used in this module's decorator annotations:
# PS stands for the wrapped callable's parameters, RT for its return type.
PS = ParamSpec("PS")
RT = TypeVar("RT")

46 

47 

def find_session_idx(func: Callable[PS, RT]) -> int:
    """
    Return the zero-based position of the ``session`` parameter of ``func``.

    :param func: callable whose signature is inspected.
    :return: index of ``session`` among the parameters, in declaration order.
    :raises ValueError: if ``func`` declares no parameter named ``session``.
    """
    # Parameters are kept in declaration order, so the list position of a name
    # is also its position in the signature.
    param_names = list(signature(func).parameters)
    if "session" not in param_names:
        raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None
    return param_names.index("session")

58 

59 

def provide_session(func: Callable[PS, RT]) -> Callable[PS, RT]:
    """
    Decorate ``func`` so that it receives a session when the caller omits one.

    To reuse an existing session, or to run the function as part of a larger
    database transaction, pass the session explicitly (positionally or as
    ``session=...``); the call is then forwarded unchanged. Otherwise the
    wrapper opens a session with ``create_session``, passes it as the
    ``session`` keyword argument, and lets that context manager finish it.

    :raises ValueError: at decoration time, if ``func`` has no ``session``
        parameter.
    """
    session_position = find_session_idx(func)

    @wraps(func)
    def wrapper(*args, **kwargs) -> RT:
        # A session counts as supplied if it was given by keyword, or if enough
        # positional arguments were passed to reach the ``session`` slot.
        session_supplied = "session" in kwargs or session_position < len(args)
        if session_supplied:
            return func(*args, **kwargs)

        with create_session() as session:
            return func(*args, session=session, **kwargs)

    return wrapper

78 

79 

# Placeholder default for the ``session`` argument of functions decorated with
# provide_session. At runtime it is simply ``None``, but it is typed as a
# session so that such functions can annotate ``session`` as Session rather
# than Session | None, and their bodies need not handle a None case that can
# never happen at runtime.
NEW_SESSION: settings.SASession = cast(settings.SASession, None)