1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Manages runtime provider resources for task execution."""
19
20from __future__ import annotations
21
22import functools
23import inspect
24import traceback
25import warnings
26from collections.abc import Callable, MutableMapping
27from typing import TYPE_CHECKING, Any
28from urllib.parse import SplitResult
29
30import structlog
31
32from airflow.sdk._shared.module_loading import import_string
33from airflow.sdk._shared.providers_discovery import (
34 KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS,
35 HookClassProvider,
36 HookInfo,
37 LazyDictWithCache,
38 PluginInfo,
39 ProviderInfo,
40 _check_builtin_provider_prefix,
41 _create_provider_info_schema_validator,
42 discover_all_providers_from_packages,
43 log_import_warning,
44 log_optional_feature_disabled,
45 provider_info_cache,
46)
47from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin
48from airflow.sdk.exceptions import AirflowOptionalProviderFeatureException
49
50if TYPE_CHECKING:
51 from airflow.sdk import BaseHook
52 from airflow.sdk.bases.decorator import TaskDecorator
53 from airflow.sdk.definitions.asset import Asset
54
55log = structlog.getLogger(__name__)
56
57
def _correctness_check(provider_package: str, class_name: str, provider_info: ProviderInfo) -> Any:
    """
    Perform coherence check on provider classes.

    For apache-airflow providers - it checks if it starts with appropriate package. For all providers
    it tries to import the provider - checking that there are no exceptions during importing.
    It logs appropriate warning in case it detects any problems.

    :param provider_package: name of the provider package
    :param class_name: name of the class to import
    :param provider_info: information about the provider the class comes from (kept for
        interface stability; not used directly by the check itself)

    :return: the class if the class is OK, None otherwise.
    """
    if not _check_builtin_provider_prefix(provider_package, class_name):
        return None
    try:
        imported_class = import_string(class_name)
    except AirflowOptionalProviderFeatureException as e:
        # When the provider class raises AirflowOptionalProviderFeatureException
        # this is an expected case when only some classes in provider are
        # available. We just log debug level here and print info message in logs so that
        # the user is aware of it
        log_optional_feature_disabled(class_name, e, provider_package)
        return None
    except ImportError as e:
        # ImportError.msg is None when the exception is raised without a message
        # (e.g. ``raise ImportError()``); guard the substring checks below so we
        # don't raise a TypeError from inside the exception handler.
        import_error_msg = e.msg or ""
        if "No module named 'airflow.providers." in import_error_msg:
            # handle cases where another provider is missing. This can only happen if
            # there is an optional feature, so we log debug and print information about it
            log_optional_feature_disabled(class_name, e, provider_package)
            return None
        for known_error in KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS:
            # Until we convert all providers to use AirflowOptionalProviderFeatureException
            # we assume any problem with importing another "provider" is because this is an
            # optional feature, so we log debug and print information about it
            if known_error[0] == provider_package and known_error[1] in import_error_msg:
                log_optional_feature_disabled(class_name, e, provider_package)
                return None
        # But when we have no idea - we print warning to logs
        log_import_warning(class_name, e, provider_package)
        return None
    except Exception as e:
        log_import_warning(class_name, e, provider_package)
        return None
    return imported_class
102
103
class ProvidersManagerTaskRuntime(LoggingMixin):
    """
    Manages runtime provider resources for task execution.

    This is a Singleton class. The first time it is instantiated, it discovers all available
    runtime provider resources (hooks, taskflow decorators, filesystems, asset handlers).

    Discovery is lazy: each resource family is populated on first access via the
    ``initialize_providers_*`` methods, which are memoized with ``@provider_info_cache``.
    """

    resource_version = "0"
    # Singleton state lives on the class so it is shared across all instantiations.
    _initialized: bool = False
    # Formatted stack trace captured at first initialization - helps debugging
    # "who initialized the manager too early" issues.
    _initialization_stack_trace = None
    _instance: ProvidersManagerTaskRuntime | None = None

    def __new__(cls):
        # Classic singleton: create the instance only on first call, reuse it afterwards.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @staticmethod
    def initialized() -> bool:
        """Return True if the singleton manager has already been initialized."""
        return ProvidersManagerTaskRuntime._initialized

    @staticmethod
    def initialization_stack_trace() -> str | None:
        """Return the stack trace captured at first initialization, or None if not initialized."""
        return ProvidersManagerTaskRuntime._initialization_stack_trace

    def __init__(self):
        """Initialize the runtime manager."""
        # skip initialization if already initialized
        if self.initialized():
            return
        super().__init__()
        ProvidersManagerTaskRuntime._initialized = True
        # Record where the first initialization happened for later diagnostics.
        ProvidersManagerTaskRuntime._initialization_stack_trace = "".join(
            traceback.format_stack(inspect.currentframe())
        )
        # Tracks which @provider_info_cache keys ran already (used by the decorator).
        self._initialized_cache: dict[str, bool] = {}
        # Keeps dict of providers keyed by module name
        self._provider_dict: dict[str, ProviderInfo] = {}
        # Module names of provider-supplied filesystems (modules exposing ``get_fs``).
        self._fs_set: set[str] = set()
        # Asset URI resources keyed by URI scheme.
        self._asset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {}
        self._asset_factories: dict[str, Callable[..., Asset]] = {}
        self._asset_to_openlineage_converters: dict[str, Callable] = {}
        # Decorator name -> lazy import callable; resolved on first access.
        self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache()
        # keeps mapping between connection_types and hook class, package they come from
        self._hook_provider_dict: dict[str, HookClassProvider] = {}
        # Keeps dict of hooks keyed by connection type. They are lazy evaluated at access time
        self._hooks_lazy_dict: LazyDictWithCache[str, HookInfo | Callable] = LazyDictWithCache()
        self._plugins_set: set[PluginInfo] = set()
        self._provider_schema_validator = _create_provider_info_schema_validator()
        self._init_airflow_core_hooks()

    def _init_airflow_core_hooks(self):
        """Initialize the hooks dict with default hooks from Airflow core."""
        # "Dummy" hooks have no backing class - they only provide a display name
        # for connection types that core Airflow understands natively.
        core_dummy_hooks = {
            "generic": "Generic",
            "email": "Email",
        }
        for key, display in core_dummy_hooks.items():
            self._hooks_lazy_dict[key] = HookInfo(
                hook_class_name=None,
                connection_id_attribute_name=None,
                package_name=None,
                hook_name=display,
                connection_type=None,
                connection_testable=False,
            )
        # These two hooks live in the standard provider but are treated as core;
        # defer their import to first access via functools.partial.
        for conn_type, class_name in (
            ("fs", "airflow.providers.standard.hooks.filesystem.FSHook"),
            ("package_index", "airflow.providers.standard.hooks.package_index.PackageIndexHook"),
        ):
            self._hooks_lazy_dict[conn_type] = functools.partial(
                self._import_hook,
                connection_type=None,
                package_name="apache-airflow-providers-standard",
                hook_class_name=class_name,
                provider_info=None,
            )

    @provider_info_cache("list")
    def initialize_providers_list(self):
        """Lazy initialization of providers list."""
        discover_all_providers_from_packages(self._provider_dict, self._provider_schema_validator)
        # Keep the provider dict sorted by package name for deterministic iteration.
        self._provider_dict = dict(sorted(self._provider_dict.items()))

    @provider_info_cache("hooks")
    def initialize_providers_hooks(self):
        """Lazy initialization of providers hooks."""
        self._init_airflow_core_hooks()
        self.initialize_providers_list()
        self._discover_hooks()
        self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))

    @provider_info_cache("filesystems")
    def initialize_providers_filesystems(self):
        """Lazy initialization of providers filesystems."""
        self.initialize_providers_list()
        self._discover_filesystems()

    @provider_info_cache("asset_uris")
    def initialize_providers_asset_uri_resources(self):
        """Lazy initialization of provider asset URI handlers, factories, converters etc."""
        self.initialize_providers_list()
        self._discover_asset_uri_resources()

    @provider_info_cache("plugins")
    def initialize_providers_plugins(self):
        """Lazy initialization of providers plugins."""
        self.initialize_providers_list()
        self._discover_plugins()

    @provider_info_cache("taskflow_decorators")
    def initialize_providers_taskflow_decorator(self):
        """Lazy initialization of providers taskflow decorators."""
        self.initialize_providers_list()
        self._discover_taskflow_decorators()

    def _discover_hooks_from_connection_types(
        self,
        hook_class_names_registered: set[str],
        already_registered_warning_connection_types: set[str],
        package_name: str,
        provider: ProviderInfo,
    ):
        """
        Discover hooks from the "connection-types" property.

        This is new, better method that replaces discovery from hook-class-names as it
        allows to lazy import individual Hook classes when they are accessed.
        The "connection-types" keeps information about both - connection type and class
        name so we can discover all connection-types without importing the classes.
        :param hook_class_names_registered: set of registered hook class names for this provider
        :param already_registered_warning_connection_types: set of connections for which warning should be
            printed in logs as they were already registered before
        :param package_name: name of the provider package being processed
        :param provider: provider information (version and provider-info data)
        :return: True if the provider declared any "connection-types" entries
        """
        provider_uses_connection_types = False
        connection_types = provider.data.get("connection-types")
        if connection_types:
            for connection_type_dict in connection_types:
                connection_type = connection_type_dict["connection-type"]
                hook_class_name = connection_type_dict["hook-class-name"]
                hook_class_names_registered.add(hook_class_name)
                already_registered = self._hook_provider_dict.get(connection_type)
                if already_registered:
                    if already_registered.package_name != package_name:
                        # Cross-package duplicate: collected and warned about by the caller.
                        already_registered_warning_connection_types.add(connection_type)
                    else:
                        # Same-package duplicate: warn immediately with both class names.
                        log.warning(
                            "The connection type '%s' is already registered in the"
                            " package '%s' with different class names: '%s' and '%s'. ",
                            connection_type,
                            package_name,
                            already_registered.hook_class_name,
                            hook_class_name,
                        )
                else:
                    self._hook_provider_dict[connection_type] = HookClassProvider(
                        hook_class_name=hook_class_name, package_name=package_name
                    )
                    # Defer importing hook to access time by setting import hook method as dict value
                    self._hooks_lazy_dict[connection_type] = functools.partial(
                        self._import_hook,
                        connection_type=connection_type,
                        provider_info=provider,
                    )
            provider_uses_connection_types = True
        return provider_uses_connection_types

    def _discover_hooks_from_hook_class_names(
        self,
        hook_class_names_registered: set[str],
        already_registered_warning_connection_types: set[str],
        package_name: str,
        provider: ProviderInfo,
        provider_uses_connection_types: bool,
    ):
        """
        Discover hooks from "hook-class-names' property.

        This property is deprecated but we should support it in Airflow 2.
        The hook-class-names array contained just Hook names without connection type,
        therefore we need to import all those classes immediately to know which connection types
        are supported. This makes it impossible to selectively only import those hooks that are used.
        :param hook_class_names_registered: hook class names already registered via "connection-types"
            discovery for this provider - such classes are skipped here
        :param already_registered_warning_connection_types: list of connection hooks that we should warn
            about when finished discovery
        :param package_name: name of the provider package
        :param provider: class that keeps information about version and details of the provider
        :param provider_uses_connection_types: determines whether the provider uses "connection-types" new
            form of passing connection types
        :return:
        """
        hook_class_names = provider.data.get("hook-class-names")
        if hook_class_names:
            for hook_class_name in hook_class_names:
                if hook_class_name in hook_class_names_registered:
                    # Silently ignore the hook class - it's already marked for lazy-import by
                    # connection-types discovery
                    continue
                # Eager import is unavoidable here: the connection type is only
                # discoverable from the imported class attributes.
                hook_info = self._import_hook(
                    connection_type=None,
                    provider_info=provider,
                    hook_class_name=hook_class_name,
                    package_name=package_name,
                )
                if not hook_info:
                    # Problem why importing class - we ignore it. Log is written at import time
                    continue
                already_registered = self._hook_provider_dict.get(hook_info.connection_type)
                if already_registered:
                    if already_registered.package_name != package_name:
                        already_registered_warning_connection_types.add(hook_info.connection_type)
                    else:
                        if already_registered.hook_class_name != hook_class_name:
                            log.warning(
                                "The hook connection type '%s' is registered twice in the"
                                " package '%s' with different class names: '%s' and '%s'. "
                                " Please fix it!",
                                hook_info.connection_type,
                                package_name,
                                already_registered.hook_class_name,
                                hook_class_name,
                            )
                else:
                    self._hook_provider_dict[hook_info.connection_type] = HookClassProvider(
                        hook_class_name=hook_class_name, package_name=package_name
                    )
                    self._hooks_lazy_dict[hook_info.connection_type] = hook_info

            if not provider_uses_connection_types:
                warnings.warn(
                    f"The provider {package_name} uses `hook-class-names` "
                    "property in provider-info and has no `connection-types` one. "
                    "The 'hook-class-names' property has been deprecated in favour "
                    "of 'connection-types' in Airflow 2.2. Use **both** in case you want to "
                    "have backwards compatibility with Airflow < 2.2",
                    DeprecationWarning,
                    stacklevel=1,
                )
        # Emit the deferred cross-package duplicate warnings collected during discovery.
        for already_registered_connection_type in already_registered_warning_connection_types:
            log.warning(
                "The connection_type '%s' has been already registered by provider '%s.'",
                already_registered_connection_type,
                self._hook_provider_dict[already_registered_connection_type].package_name,
            )

    def _discover_hooks(self) -> None:
        """Retrieve all connections defined in the providers via Hooks."""
        for package_name, provider in self._provider_dict.items():
            duplicated_connection_types: set[str] = set()
            hook_class_names_registered: set[str] = set()
            # Prefer the modern "connection-types" form (lazy import) ...
            provider_uses_connection_types = self._discover_hooks_from_connection_types(
                hook_class_names_registered, duplicated_connection_types, package_name, provider
            )
            # ... then fall back to the deprecated "hook-class-names" form (eager import).
            self._discover_hooks_from_hook_class_names(
                hook_class_names_registered,
                duplicated_connection_types,
                package_name,
                provider,
                provider_uses_connection_types,
            )
        self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))

    @staticmethod
    def _get_attr(obj: Any, attr_name: str):
        """Retrieve attributes of an object, or warn if not found."""
        if not hasattr(obj, attr_name):
            log.warning("The object '%s' is missing %s attribute and cannot be registered", obj, attr_name)
            return None
        return getattr(obj, attr_name)

    def _import_hook(
        self,
        connection_type: str | None,
        provider_info: ProviderInfo,
        hook_class_name: str | None = None,
        package_name: str | None = None,
    ) -> HookInfo | None:
        """
        Import hook and retrieve hook information.

        Either connection_type (for lazy loading) or hook_class_name must be set - but not both).
        Only needs package_name if hook_class_name is passed (for lazy loading, package_name
        is retrieved from _connection_type_class_provider_dict together with hook_class_name).

        :param connection_type: type of the connection
        :param provider_info: information about the provider the hook comes from
        :param hook_class_name: name of the hook class
        :param package_name: provider package - only needed in case connection_type is missing
        :return: HookInfo for the imported hook, or None if import or validation failed
        """
        if connection_type is None and hook_class_name is None:
            raise ValueError("Either connection_type or hook_class_name must be set")
        if connection_type is not None and hook_class_name is not None:
            raise ValueError(
                f"Both connection_type ({connection_type} and "
                f"hook_class_name {hook_class_name} are set. Only one should be set!"
            )
        if connection_type is not None:
            # Lazy-loading path: look up class name and package from the registry.
            class_provider = self._hook_provider_dict[connection_type]
            package_name = class_provider.package_name
            hook_class_name = class_provider.hook_class_name
        else:
            if not hook_class_name:
                raise ValueError("Either connection_type or hook_class_name must be set")
            if not package_name:
                raise ValueError(
                    f"Provider package name is not set when hook_class_name ({hook_class_name}) is used"
                )
        hook_class: type[BaseHook] | None = _correctness_check(package_name, hook_class_name, provider_info)
        if hook_class is None:
            return None

        hook_connection_type = self._get_attr(hook_class, "conn_type")
        if connection_type:
            if hook_connection_type != connection_type:
                log.warning(
                    "Inconsistency! The hook class '%s' declares connection type '%s'"
                    " but it is added by provider '%s' as connection_type '%s' in provider info. "
                    "This should be fixed!",
                    hook_class,
                    hook_connection_type,
                    package_name,
                    connection_type,
                )
        # The class-declared conn_type wins over what provider-info declared.
        connection_type = hook_connection_type
        connection_id_attribute_name: str = self._get_attr(hook_class, "conn_name_attr")
        hook_name: str = self._get_attr(hook_class, "hook_name")

        if not connection_type or not connection_id_attribute_name or not hook_name:
            log.warning(
                "The hook misses one of the key attributes: "
                "conn_type: %s, conn_id_attribute_name: %s, hook_name: %s",
                connection_type,
                connection_id_attribute_name,
                hook_name,
            )
            return None

        return HookInfo(
            hook_class_name=hook_class_name,
            connection_id_attribute_name=connection_id_attribute_name,
            package_name=package_name,
            hook_name=hook_name,
            connection_type=connection_type,
            connection_testable=hasattr(hook_class, "test_connection"),
        )

    def _discover_filesystems(self) -> None:
        """Retrieve all filesystems defined in the providers."""
        for provider_package, provider in self._provider_dict.items():
            for fs_module_name in provider.data.get("filesystems", []):
                # A filesystem module must expose a ``get_fs`` callable to be registered.
                if _correctness_check(provider_package, f"{fs_module_name}.get_fs", provider):
                    self._fs_set.add(fs_module_name)
        self._fs_set = set(sorted(self._fs_set))

    def _discover_asset_uri_resources(self) -> None:
        """Discovers and registers asset URI handlers, factories, and converters for all providers."""
        from airflow.sdk.definitions.asset import normalize_noop

        def _safe_register_resource(
            provider_package_name: str,
            schemes_list: list[str],
            resource_path: str | None,
            resource_registry: dict,
            default_resource: Any = None,
        ):
            """
            Register a specific resource (handler, factory, or converter) for the given schemes.

            If the resolved resource (either from the path or the default) is valid, it updates
            the resource registry with the appropriate resource for each scheme.
            """
            # NOTE: ``provider`` is closed over from the enclosing loop; safe because this
            # helper is only called within the same loop iteration that bound it.
            resource = (
                _correctness_check(provider_package_name, resource_path, provider)
                if resource_path is not None
                else default_resource
            )
            if resource:
                resource_registry.update((scheme, resource) for scheme in schemes_list)

        for provider_name, provider in self._provider_dict.items():
            for uri_info in provider.data.get("asset-uris", []):
                if "schemes" not in uri_info or "handler" not in uri_info:
                    continue  # Both schemas and handler must be explicitly set, handler can be set to null
                common_args = {"schemes_list": uri_info["schemes"], "provider_package_name": provider_name}
                # A null handler means "no normalization" - fall back to normalize_noop.
                _safe_register_resource(
                    resource_path=uri_info["handler"],
                    resource_registry=self._asset_uri_handlers,
                    default_resource=normalize_noop,
                    **common_args,
                )
                _safe_register_resource(
                    resource_path=uri_info.get("factory"),
                    resource_registry=self._asset_factories,
                    **common_args,
                )
                _safe_register_resource(
                    resource_path=uri_info.get("to_openlineage_converter"),
                    resource_registry=self._asset_to_openlineage_converters,
                    **common_args,
                )

    def _discover_plugins(self) -> None:
        """Retrieve all plugins defined in the providers."""
        for provider_package, provider in self._provider_dict.items():
            for plugin_dict in provider.data.get("plugins", ()):
                if not _correctness_check(provider_package, plugin_dict["plugin-class"], provider):
                    log.warning("Plugin not loaded due to above correctness check problem.")
                    continue
                self._plugins_set.add(
                    PluginInfo(
                        name=plugin_dict["name"],
                        plugin_class=plugin_dict["plugin-class"],
                        provider_name=provider_package,
                    )
                )

    def _discover_taskflow_decorators(self) -> None:
        """Register all taskflow decorators declared in provider-info "task-decorators"."""
        for name, info in self._provider_dict.items():
            for taskflow_decorator in info.data.get("task-decorators", []):
                self._add_taskflow_decorator(
                    taskflow_decorator["name"], taskflow_decorator["class-name"], name
                )

    def _add_taskflow_decorator(self, name, decorator_class_name: str, provider_package: str) -> None:
        """Register a single taskflow decorator for lazy import, warning on duplicates."""
        if not _check_builtin_provider_prefix(provider_package, decorator_class_name):
            return

        if name in self._taskflow_decorators:
            try:
                existing = self._taskflow_decorators[name]
                other_name = f"{existing.__module__}.{existing.__name__}"
            except Exception:
                # If problem importing, then get the value from the functools.partial
                other_name = self._taskflow_decorators._raw_dict[name].args[0]  # type: ignore[attr-defined]

            log.warning(
                "The taskflow decorator '%s' has been already registered (by %s).",
                name,
                other_name,
            )
            return

        # Defer the actual import to first access.
        self._taskflow_decorators[name] = functools.partial(import_string, decorator_class_name)

    @property
    def providers(self) -> dict[str, ProviderInfo]:
        """Returns information about available providers."""
        self.initialize_providers_list()
        return self._provider_dict

    @property
    def hooks(self) -> MutableMapping[str, HookInfo | None]:
        """
        Return dictionary of connection_type-to-hook mapping.

        Note that the dict can contain None values if a hook discovered cannot be imported!
        """
        self.initialize_providers_hooks()
        return self._hooks_lazy_dict

    @property
    def taskflow_decorators(self) -> dict[str, TaskDecorator]:
        """Return mapping of decorator name to lazily-imported taskflow decorator."""
        self.initialize_providers_taskflow_decorator()
        return self._taskflow_decorators  # type: ignore[return-value]

    @property
    def filesystem_module_names(self) -> list[str]:
        """Return sorted module names of provider-supplied filesystems."""
        self.initialize_providers_filesystems()
        return sorted(self._fs_set)

    @property
    def asset_factories(self) -> dict[str, Callable[..., Asset]]:
        """Return mapping of URI scheme to asset factory callable."""
        self.initialize_providers_asset_uri_resources()
        return self._asset_factories

    @property
    def asset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]:
        """Return mapping of URI scheme to asset URI normalization handler."""
        self.initialize_providers_asset_uri_resources()
        return self._asset_uri_handlers

    @property
    def asset_to_openlineage_converters(
        self,
    ) -> dict[str, Callable]:
        """Return mapping of URI scheme to OpenLineage converter callable."""
        self.initialize_providers_asset_uri_resources()
        return self._asset_to_openlineage_converters

    @property
    def plugins(self) -> list[PluginInfo]:
        """Returns information about plugins available in providers."""
        self.initialize_providers_plugins()
        return sorted(self._plugins_set, key=lambda x: x.plugin_class)

    def _cleanup(self):
        """Reset all discovered state so the singleton can be re-initialized (used in tests)."""
        self._initialized_cache.clear()
        self._provider_dict.clear()
        self._fs_set.clear()
        self._taskflow_decorators.clear()
        self._hook_provider_dict.clear()
        self._hooks_lazy_dict.clear()
        self._plugins_set.clear()
        self._asset_uri_handlers.clear()
        self._asset_factories.clear()
        self._asset_to_openlineage_converters.clear()

        self._initialized = False
        self._initialization_stack_trace = None