1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
18
19import contextlib
20from collections.abc import Callable, Generator
21from functools import wraps
22from inspect import signature
23from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast
24
25from airflow import settings
26
27if TYPE_CHECKING:
28 from sqlalchemy.orm import Session as SASession
29
30
@contextlib.contextmanager
def create_session(scoped: bool = True) -> Generator[SASession, None, None]:
    """
    Open a database session, commit it on success, roll it back on error, and always close it.

    :param scoped: when True the session comes from ``settings.Session``;
        otherwise it comes from ``settings.NonScopedSession``.
    :raises RuntimeError: if the selected session factory is not set on ``settings``.
    """
    factory_name = "Session" if scoped else "NonScopedSession"
    session_factory = getattr(settings, factory_name, None)
    if session_factory is None:
        raise RuntimeError("Session must be set before!")

    session = session_factory()
    try:
        yield session
        # The commit stays inside the try so that a failing commit is also rolled back.
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
49
50
@contextlib.asynccontextmanager
async def create_session_async():
    """
    Yield an async session from ``airflow.settings.AsyncSession``.

    The session is committed when the ``async with`` body finishes without error,
    and rolled back (then the error re-raised) when an exception escapes.

    :meta private:
    """
    from airflow.settings import AsyncSession

    async with AsyncSession() as async_session:
        try:
            yield async_session
            # The commit stays inside the try so that a failing commit is also rolled back.
            await async_session.commit()
        except Exception:
            await async_session.rollback()
            raise
67
68
# Type variables that let ``provide_session`` keep the wrapped function's
# signature: ``PS`` captures its parameters and ``RT`` its return type.
PS = ParamSpec("PS")
RT = TypeVar("RT")
71
72
def find_session_idx(func: Callable[PS, RT]) -> int:
    """
    Return the zero-based position of the ``session`` parameter in ``func``'s signature.

    Positions follow declaration order and count every declared parameter,
    including ``*args``/``**kwargs`` entries.

    :param func: callable whose signature is inspected.
    :raises ValueError: if ``func`` declares no parameter named ``session``.
    """
    for position, param_name in enumerate(signature(func).parameters):
        if param_name == "session":
            return position
    raise ValueError(f"Function {func.__qualname__} has no `session` argument")
83
84
def provide_session(func: Callable[PS, RT]) -> Callable[PS, RT]:
    """
    Provide a session if it isn't provided.

    If you want to reuse a session or run the function as part of a
    database transaction, you pass it to the function, if not this wrapper
    will create one and close it for you.

    A session counts as "provided" when it is passed as the ``session``
    keyword argument, or positionally in the slot that ``session`` occupies in
    ``func``'s signature. When ``session`` is keyword-only (for example because
    it follows ``*args``), only the keyword form counts, because extra
    positional arguments then land in the var-positional parameter instead.

    :param func: callable that declares a ``session`` parameter.
    :return: a wrapper with the same signature as ``func``.
    :raises ValueError: at decoration time, if ``func`` has no ``session`` parameter.
    """
    from inspect import Parameter

    session_args_idx = find_session_idx(func)
    # Counting positional arguments is only meaningful when ``session`` can be
    # bound positionally; a keyword-only ``session`` (e.g. after ``*args``) never is.
    session_is_keyword_only = signature(func).parameters["session"].kind is Parameter.KEYWORD_ONLY

    @wraps(func)
    def wrapper(*args, **kwargs) -> RT:
        session_given = "session" in kwargs or (
            not session_is_keyword_only and session_args_idx < len(args)
        )
        if session_given:
            return func(*args, **kwargs)
        # No caller-supplied session: open one for the duration of this call.
        # create_session commits on success, rolls back on error and always closes.
        with create_session() as session:
            return func(*args, session=session, **kwargs)  # type: ignore[arg-type]

    return wrapper
103
104
# Placeholder default for the ``session`` parameter of functions decorated with
# ``provide_session``. At runtime it is simply ``None``; the ``cast`` only tells
# type checkers to treat it as a ``SASession``. As a result, the parameter can be
# typed ``Session`` instead of ``Session | None``, and the function body need not
# handle a ``None`` session. That case does not arise when the caller omits
# ``session``, because the decorator then passes in a real one. Note that an
# explicit ``session=None`` is still forwarded unchanged.
# The cast target is written as a string because ``SASession`` is only imported
# under ``TYPE_CHECKING`` and does not exist at runtime.
NEW_SESSION: SASession = cast("SASession", None)