1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20from typing import TYPE_CHECKING
21
22from airflow.providers.common.compat._compat_utils import create_module_getattr
23from airflow.providers.common.compat.version_compat import (
24 AIRFLOW_V_3_1_PLUS,
25 AIRFLOW_V_3_2_PLUS,
26)
27
# Lookup table consumed by ``create_module_getattr``: each exported name maps either to a
# single module path or to a tuple of module paths (presumably tried in order as
# fallbacks -- confirm against ``create_module_getattr``).
_IMPORT_MAP: dict[str, str | tuple[str, ...]] = {
    # Re-export from sdk (which handles Airflow 2.x/3.x fallbacks)
    **dict.fromkeys(
        ("BaseOperator", "BaseAsyncOperator", "get_current_context", "is_async_callable"),
        "airflow.providers.common.compat.sdk",
    ),
    # Standard provider items with direct fallbacks
    **dict.fromkeys(
        ("PythonOperator", "ShortCircuitOperator", "_SERIALIZERS"),
        ("airflow.providers.standard.operators.python", "airflow.operators.python"),
    ),
}
39
40if TYPE_CHECKING:
41 from airflow.sdk.bases.decorator import is_async_callable
42 from airflow.sdk.bases.operator import BaseAsyncOperator
43elif AIRFLOW_V_3_2_PLUS:
44 from airflow.sdk.bases.decorator import is_async_callable
45 from airflow.sdk.bases.operator import BaseAsyncOperator
46else:
47 if AIRFLOW_V_3_1_PLUS:
48 from airflow.sdk import BaseOperator
49 else:
50 from airflow.models import BaseOperator
51
52 def is_async_callable(func) -> bool:
53 """Detect if a callable is an async function."""
54 import inspect
55 from functools import partial
56
57 while isinstance(func, partial):
58 func = func.func
59 return inspect.iscoroutinefunction(func)
60
61 class BaseAsyncOperator(BaseOperator):
62 """Stub for Airflow < 3.2 that raises a clear error."""
63
64 @property
65 def is_async(self) -> bool:
66 return True
67
68 async def aexecute(self, context):
69 raise NotImplementedError()
70
71 def execute(self, context):
72 raise RuntimeError(
73 "Async operators require Airflow 3.2+. Upgrade Airflow or use a synchronous callable."
74 )
75
76
# The advertised public names are exactly the keys of the import map, in sorted order.
__all__ = sorted(_IMPORT_MAP)

# Module-level ``__getattr__`` hook (PEP 562) produced by the compat helper from the
# import map; presumably it resolves the mapped names on first access -- see
# ``create_module_getattr`` for the exact behavior.
__getattr__ = create_module_getattr(import_map=_IMPORT_MAP)