#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Manages all providers."""

from __future__ import annotations

import contextlib
import functools
import inspect
import json
import logging
import traceback
import warnings
from collections.abc import Callable, MutableMapping
from dataclasses import dataclass
from functools import wraps
from importlib.resources import files as resource_files
from time import perf_counter
from typing import TYPE_CHECKING, Any, NamedTuple, ParamSpec, TypeVar

from packaging.utils import canonicalize_name

from airflow._shared.module_loading import import_string
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.singleton import Singleton

log = logging.getLogger(__name__)


PS = ParamSpec("PS")
RT = TypeVar("RT")

MIN_PROVIDER_VERSIONS = {
    "apache-airflow-providers-celery": "2.1.0",
}


def _ensure_prefix_for_placeholders(field_behaviors: dict[str, Any], conn_type: str):
    """
    Verify the correct placeholder prefix.

    If the given field_behaviors dict contains a placeholders node, and there
    are placeholders for extra fields (i.e. anything other than the built-in conn
    attrs), and if those extra fields are unprefixed, then add the prefix.

    The reason we need to do this is that all custom conn fields live in the same dictionary,
    so we need to namespace them with a prefix internally. But for user convenience,
    and for consistency between the `get_ui_field_behaviour` method and the extra dict itself,
    we allow users to supply the unprefixed name.
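
    For illustration only (the connection type below is hypothetical): a placeholder supplied
    under the short name ``token`` is rewritten to the prefixed form, while built-in connection
    attributes such as ``login`` are left untouched::

        behaviours = {"placeholders": {"login": "user", "token": "abc"}}
        _ensure_prefix_for_placeholders(behaviours, "my_conn_type")
        # -> {"placeholders": {"login": "user", "extra__my_conn_type__token": "abc"}}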
    """
    conn_attrs = {"host", "schema", "login", "password", "port", "extra"}

    def ensure_prefix(field):
        if field not in conn_attrs and not field.startswith("extra__"):
            return f"extra__{conn_type}__{field}"
        return field

    if "placeholders" in field_behaviors:
        placeholders = field_behaviors["placeholders"]
        field_behaviors["placeholders"] = {ensure_prefix(k): v for k, v in placeholders.items()}

    return field_behaviors


if TYPE_CHECKING:
    from urllib.parse import SplitResult

    from airflow.sdk import BaseHook
    from airflow.sdk.bases.decorator import TaskDecorator
    from airflow.sdk.definitions.asset import Asset


class LazyDictWithCache(MutableMapping):
    """
    Lazy-loaded cached dictionary.

    A dictionary which, if a stored value is callable, calls it (without arguments) on first
    access of the key, then caches and returns the result in place of the callable.
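
    A minimal usage sketch (illustrative only; ``compute_something`` is a hypothetical callable)::

        d = LazyDictWithCache()
        d["expensive"] = compute_something   # not called yet
        value = d["expensive"]               # callable invoked on first access, result cached
        value_again = d["expensive"]         # cached result returned, no second call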
    """

    __slots__ = ["_resolved", "_raw_dict"]

    def __init__(self, *args, **kw):
        self._resolved = set()
        self._raw_dict = dict(*args, **kw)

    def __setitem__(self, key, value):
        self._raw_dict.__setitem__(key, value)

    def __getitem__(self, key):
        value = self._raw_dict.__getitem__(key)
        if key not in self._resolved and callable(value):
            # exchange callable with result of calling it -- but only once! allow resolver to return a
            # callable itself
            value = value()
            self._resolved.add(key)
            self._raw_dict.__setitem__(key, value)
        return value

    def __delitem__(self, key):
        with contextlib.suppress(KeyError):
            self._resolved.remove(key)
        self._raw_dict.__delitem__(key)

    def __iter__(self):
        return iter(self._raw_dict)

    def __len__(self):
        return len(self._raw_dict)

    def __contains__(self, key):
        return key in self._raw_dict

    def clear(self):
        self._resolved.clear()
        self._raw_dict.clear()


def _read_schema_from_resources_or_local_file(filename: str) -> dict:
    try:
        with resource_files("airflow").joinpath(filename).open("rb") as f:
            schema = json.load(f)
    except (TypeError, FileNotFoundError):
        import pathlib

        with (pathlib.Path(__file__).parent / filename).open("rb") as f:
            schema = json.load(f)
    return schema


def _create_provider_info_schema_validator():
    """Create JSON schema validator from the provider_info.schema.json."""
    import jsonschema

    schema = _read_schema_from_resources_or_local_file("provider_info.schema.json")
    cls = jsonschema.validators.validator_for(schema)
    validator = cls(schema)
    return validator


def _create_customized_form_field_behaviours_schema_validator():
    """Create JSON schema validator from the customized_form_field_behaviours.schema.json."""
    import jsonschema

    schema = _read_schema_from_resources_or_local_file("customized_form_field_behaviours.schema.json")
    cls = jsonschema.validators.validator_for(schema)
    validator = cls(schema)
    return validator


def _check_builtin_provider_prefix(provider_package: str, class_name: str) -> bool:
    if provider_package.startswith("apache-airflow"):
        provider_path = provider_package[len("apache-") :].replace("-", ".")
        if not class_name.startswith(provider_path):
            log.warning(
                "Coherence check failed when importing '%s' from '%s' package. It should start with '%s'",
                class_name,
                provider_package,
                provider_path,
            )
            return False
    return True


@dataclass
class ProviderInfo:
    """
    Provider information.

    :param version: version string
    :param data: dictionary with information about the provider
    """

    version: str
    data: dict


class HookClassProvider(NamedTuple):
    """Hook class and Provider it comes from."""

    hook_class_name: str
    package_name: str


class DialectInfo(NamedTuple):
    """Dialect class and Provider it comes from."""

    name: str
    dialect_class_name: str
    provider_name: str


class TriggerInfo(NamedTuple):
    """Trigger class and provider it comes from."""

    trigger_class_name: str
    package_name: str
    integration_name: str


class NotificationInfo(NamedTuple):
    """Notification class and provider it comes from."""

    notification_class_name: str
    package_name: str


class PluginInfo(NamedTuple):
    """Plugin class, name and provider it comes from."""

    name: str
    plugin_class: str
    provider_name: str


class HookInfo(NamedTuple):
    """Hook information."""

    hook_class_name: str
    connection_id_attribute_name: str
    package_name: str
    hook_name: str
    connection_type: str
    connection_testable: bool
    dialects: list[str] = []


class ConnectionFormWidgetInfo(NamedTuple):
    """Connection Form Widget information."""

    hook_class_name: str
    package_name: str
    field: Any
    field_name: str
    is_sensitive: bool


def log_optional_feature_disabled(class_name, e, provider_package):
    """Log optional feature disabled."""
    log.debug(
        "Optional feature disabled on exception when importing '%s' from '%s' package",
        class_name,
        provider_package,
        exc_info=e,
    )
    log.info(
        "Optional provider feature disabled when importing '%s' from '%s' package",
        class_name,
        provider_package,
    )


def log_import_warning(class_name, e, provider_package):
    """Log import warning."""
    log.warning(
        "Exception when importing '%s' from '%s' package",
        class_name,
        provider_package,
        exc_info=e,
    )


# This is a temporary measure until all community providers add AirflowOptionalProviderFeatureException
# where they have optional features. We are going to add tests in our CI to catch all such cases and
# fix them, but until then all "known unhandled optional feature errors" from community providers
# should be added here.
KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS = [("apache-airflow-providers-google", "No module named 'paramiko'")]


def _correctness_check(provider_package: str, class_name: str, provider_info: ProviderInfo) -> Any:
    """
    Perform coherence check on provider classes.

    For apache-airflow providers it checks whether the class name starts with the appropriate package.
    For all providers it tries to import the class, checking that there are no exceptions during import.
    It logs an appropriate warning in case it detects any problems.

    :param provider_package: name of the provider package
    :param class_name: name of the class to import

    :return: the class if the class is OK, None otherwise.
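
    For illustration only (the provider and class path below are hypothetical, and ``info`` stands for
    the corresponding ProviderInfo): a class registered by ``apache-airflow-providers-foo`` must live
    under ``airflow.providers.foo``::

        _correctness_check("apache-airflow-providers-foo", "airflow.providers.foo.hooks.foo.FooHook", info)
        # returns the imported class, or None if the prefix check or the import fails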
    """
    if not _check_builtin_provider_prefix(provider_package, class_name):
        return None
    try:
        imported_class = import_string(class_name)
    except AirflowOptionalProviderFeatureException as e:
        # When the provider class raises AirflowOptionalProviderFeatureException
        # this is an expected case when only some classes in provider are
        # available. We just log debug level here and print info message in logs so that
        # the user is aware of it
        log_optional_feature_disabled(class_name, e, provider_package)
        return None
    except ImportError as e:
        if "No module named 'airflow.providers." in e.msg:
            # handle cases where another provider is missing. This can only happen if
            # there is an optional feature, so we log debug and print information about it
            log_optional_feature_disabled(class_name, e, provider_package)
            return None
        for known_error in KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS:
            # Until we convert all providers to use AirflowOptionalProviderFeatureException
            # we assume any problem with importing another "provider" is because this is an
            # optional feature, so we log debug and print information about it
            if known_error[0] == provider_package and known_error[1] in e.msg:
                log_optional_feature_disabled(class_name, e, provider_package)
                return None
        # But when we have no idea - we print warning to logs
        log_import_warning(class_name, e, provider_package)
        return None
    except Exception as e:
        log_import_warning(class_name, e, provider_package)
        return None
    return imported_class


# We want to have better control over initialization of parameters and be able to debug and test it,
# so we add our own decorator.
def provider_info_cache(cache_name: str) -> Callable[[Callable[PS, None]], Callable[PS, None]]:
    """
    Decorate and cache provider info.

    Decorator factory that creates a decorator which caches initialization of the provider's parameters.

    :param cache_name: Name of the cache
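
    A usage sketch (illustrative only; the cache name and method below are hypothetical)::

        class ProvidersManager(LoggingMixin, metaclass=Singleton):
            @provider_info_cache("widgets")
            def initialize_providers_widgets(self):
                ...  # body runs once; later calls return immediately thanks to the cache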
    """

    def provider_info_cache_decorator(func: Callable[PS, None]) -> Callable[PS, None]:
        @wraps(func)
        def wrapped_function(*args: PS.args, **kwargs: PS.kwargs) -> None:
            providers_manager_instance = args[0]
            if TYPE_CHECKING:
                assert isinstance(providers_manager_instance, ProvidersManager)

            if cache_name in providers_manager_instance._initialized_cache:
                return
            start_time = perf_counter()
            log.debug("Initializing Providers Manager[%s]", cache_name)
            func(*args, **kwargs)
            providers_manager_instance._initialized_cache[cache_name] = True
            log.debug(
                "Initialization of Providers Manager[%s] took %.2f seconds",
                cache_name,
                perf_counter() - start_time,
            )

        return wrapped_function

    return provider_info_cache_decorator


class ProvidersManager(LoggingMixin, metaclass=Singleton):
    """
    Manages all provider distributions.

    This is a Singleton class. The first time it is
    instantiated, it discovers all available providers in installed packages.
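
    A short usage sketch (illustrative only)::

        manager = ProvidersManager()          # singleton - repeated calls return the same instance
        hook_info = manager.hooks.get("fs")   # lazily imports the hook registered for "fs" (may be None)
        providers = manager.providers         # mapping of provider package name -> ProviderInfo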
    """

    resource_version = "0"
    _initialized: bool = False
    _initialization_stack_trace = None

    @staticmethod
    def initialized() -> bool:
        return ProvidersManager._initialized

    @staticmethod
    def initialization_stack_trace() -> str | None:
        return ProvidersManager._initialization_stack_trace

    def __init__(self):
        """Initialize the manager."""
        super().__init__()
        ProvidersManager._initialized = True
        ProvidersManager._initialization_stack_trace = "".join(traceback.format_stack(inspect.currentframe()))
        self._initialized_cache: dict[str, bool] = {}
        # Keeps dict of providers keyed by package name
        self._provider_dict: dict[str, ProviderInfo] = {}
        self._fs_set: set[str] = set()
        self._asset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {}
        self._asset_factories: dict[str, Callable[..., Asset]] = {}
        self._asset_to_openlineage_converters: dict[str, Callable] = {}
        self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache()
        # Keeps mapping between connection_types and the hook class and package they come from
        self._hook_provider_dict: dict[str, HookClassProvider] = {}
        self._dialect_provider_dict: dict[str, DialectInfo] = {}
        # Keeps dict of hooks keyed by connection type. They are lazily evaluated at access time
        self._hooks_lazy_dict: LazyDictWithCache[str, HookInfo | Callable] = LazyDictWithCache()
        # Keeps ConnectionFormWidgetInfo tuples used to add custom widgets, keyed by the name of the extra field
        self._connection_form_widgets: dict[str, ConnectionFormWidgetInfo] = {}
        # Customizations for javascript fields are kept here
        self._field_behaviours: dict[str, dict] = {}
        self._extra_link_class_name_set: set[str] = set()
        self._logging_class_name_set: set[str] = set()
        self._auth_manager_class_name_set: set[str] = set()
        self._secrets_backend_class_name_set: set[str] = set()
        self._executor_class_name_set: set[str] = set()
        self._queue_class_name_set: set[str] = set()
        self._provider_configs: dict[str, dict[str, Any]] = {}
        self._trigger_info_set: set[TriggerInfo] = set()
        self._notification_info_set: set[NotificationInfo] = set()
        self._provider_schema_validator = _create_provider_info_schema_validator()
        self._customized_form_fields_schema_validator = (
            _create_customized_form_field_behaviours_schema_validator()
        )
        # Set of plugins contained in providers
        self._plugins_set: set[PluginInfo] = set()
        self._init_airflow_core_hooks()

    def _init_airflow_core_hooks(self):
        """Initialize the hooks dict with default hooks from Airflow core."""
        core_dummy_hooks = {
            "generic": "Generic",
            "email": "Email",
        }
        for key, display in core_dummy_hooks.items():
            self._hooks_lazy_dict[key] = HookInfo(
                hook_class_name=None,
                connection_id_attribute_name=None,
                package_name=None,
                hook_name=display,
                connection_type=None,
                connection_testable=False,
            )
        for conn_type, class_name in (
            ("fs", "airflow.providers.standard.hooks.filesystem.FSHook"),
            ("package_index", "airflow.providers.standard.hooks.package_index.PackageIndexHook"),
        ):
            self._hooks_lazy_dict[conn_type] = functools.partial(
                self._import_hook,
                connection_type=None,
                package_name="apache-airflow-providers-standard",
                hook_class_name=class_name,
                provider_info=None,
            )

    @provider_info_cache("list")
    def initialize_providers_list(self):
        """Lazy initialization of providers list."""
        # Local source folders are loaded first. They should take precedence over the package ones for
        # development purposes. In production, provider.yaml files are not present in the 'airflow'
        # directory, so there is no risk that we accidentally override a package provider. This can
        # only happen in case of local development.
        self._discover_all_providers_from_packages()
        self._verify_all_providers_all_compatible()
        self._provider_dict = dict(sorted(self._provider_dict.items()))

    def _verify_all_providers_all_compatible(self):
        from packaging import version as packaging_version

        for provider_id, info in self._provider_dict.items():
            min_version = MIN_PROVIDER_VERSIONS.get(provider_id)
            if min_version:
                if packaging_version.parse(min_version) > packaging_version.parse(info.version):
                    log.warning(
                        "The package %s is not compatible with this version of Airflow. "
                        "The package has version %s but the minimum supported version "
                        "of the package is %s",
                        provider_id,
                        info.version,
                        min_version,
                    )

    @provider_info_cache("hooks")
    def initialize_providers_hooks(self):
        """Lazy initialization of providers hooks."""
        self._init_airflow_core_hooks()
        self.initialize_providers_list()
        self._discover_hooks()
        self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))

    @provider_info_cache("filesystems")
    def initialize_providers_filesystems(self):
        """Lazy initialization of providers filesystems."""
        self.initialize_providers_list()
        self._discover_filesystems()

    @provider_info_cache("asset_uris")
    def initialize_providers_asset_uri_resources(self):
        """Lazy initialization of provider asset URI handlers, factories, converters etc."""
        self.initialize_providers_list()
        self._discover_asset_uri_resources()

    @provider_info_cache("hook_lineage_writers")
    @provider_info_cache("taskflow_decorators")
    def initialize_providers_taskflow_decorator(self):
        """Lazy initialization of providers taskflow decorators."""
        self.initialize_providers_list()
        self._discover_taskflow_decorators()

    @provider_info_cache("extra_links")
    def initialize_providers_extra_links(self):
        """Lazy initialization of providers extra links."""
        self.initialize_providers_list()
        self._discover_extra_links()

    @provider_info_cache("logging")
    def initialize_providers_logging(self):
        """Lazy initialization of providers logging information."""
        self.initialize_providers_list()
        self._discover_logging()

    @provider_info_cache("secrets_backends")
    def initialize_providers_secrets_backends(self):
        """Lazy initialization of providers secrets_backends information."""
        self.initialize_providers_list()
        self._discover_secrets_backends()

    @provider_info_cache("executors")
    def initialize_providers_executors(self):
        """Lazy initialization of providers executors information."""
        self.initialize_providers_list()
        self._discover_executors()

    @provider_info_cache("queues")
    def initialize_providers_queues(self):
        """Lazy initialization of providers queue information."""
        self.initialize_providers_list()
        self._discover_queues()

    @provider_info_cache("notifications")
    def initialize_providers_notifications(self):
        """Lazy initialization of providers notifications information."""
        self.initialize_providers_list()
        self._discover_notifications()

    @provider_info_cache("auth_managers")
    def initialize_providers_auth_managers(self):
        """Lazy initialization of providers auth managers information."""
        self.initialize_providers_list()
        self._discover_auth_managers()

    @provider_info_cache("config")
    def initialize_providers_configuration(self):
        """Lazy initialization of providers configuration information."""
        self._initialize_providers_configuration()

    def _initialize_providers_configuration(self):
        """
        Initialize providers configuration information.

        Should be used if we do not want to trigger caching for the ``initialize_providers_configuration``
        method. In some cases we might want to make sure that the configuration is initialized, but we do
        not want to cache the initialization method - for example when we just want to write configuration
        with providers, but it is used in a context where no providers are loaded yet. We will eventually
        restore the original configuration, and we want the subsequent ``initialize_providers_configuration``
        call to be run in order to load the configuration for providers again.
        """
        self.initialize_providers_list()
        self._discover_config()
        # Now update conf with the new provider configuration from providers
        from airflow.configuration import conf

        conf.load_providers_configuration()

    @provider_info_cache("plugins")
    def initialize_providers_plugins(self):
        self.initialize_providers_list()
        self._discover_plugins()

    def _discover_all_providers_from_packages(self) -> None:
        """
        Discover all providers by scanning installed packages.

        The list of providers should be returned via the 'apache_airflow_provider'
        entrypoint as a dictionary conforming to the 'airflow/provider_info.schema.json'
        schema. Note that the schema is different at runtime than provider.yaml.schema.json.
        The development version of the provider schema is more strict and changes together with
        the code. The runtime version is more relaxed (allows for additional properties)
        and verifies only the subset of fields that are needed at runtime.
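
        For illustration only (a hypothetical provider, showing just the fields this module reads),
        the entrypoint callable returns a dictionary such as::

            {
                "package-name": "apache-airflow-providers-foo",
                "connection-types": [
                    {
                        "connection-type": "foo",
                        "hook-class-name": "airflow.providers.foo.hooks.foo.FooHook",
                    },
                ],
            }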
        """
        for entry_point, dist in entry_points_with_dist("apache_airflow_provider"):
            if not dist.metadata:
                continue
            package_name = canonicalize_name(dist.metadata["name"])
            if package_name in self._provider_dict:
                continue
            log.debug("Loading %s from package %s", entry_point, package_name)
            version = dist.version
            provider_info = entry_point.load()()
            self._provider_schema_validator.validate(provider_info)
            provider_info_package_name = provider_info["package-name"]
            if package_name != provider_info_package_name:
                raise ValueError(
                    f"The package name '{package_name}' from the package metadata and the provider-info "
                    f"package name '{provider_info_package_name}' do not match. Please make sure they are aligned"
                )

            # issue-59576: Retrieve the project.urls.documentation from dist.metadata
            project_urls = dist.metadata.get_all("Project-URL")
            documentation_url: str | None = None

            if project_urls:
                for entry in project_urls:
                    if "," in entry:
                        name, url = entry.split(",", 1)
                        if name.strip().lower() == "documentation":
                            documentation_url = url.strip()
                            break

            provider_info["documentation-url"] = documentation_url

            if package_name not in self._provider_dict:
                self._provider_dict[package_name] = ProviderInfo(version, provider_info)
            else:
                log.warning(
                    "The provider for package '%s' could not be registered because providers for that "
                    "package name have already been registered",
                    package_name,
                )

    def _discover_hooks_from_connection_types(
        self,
        hook_class_names_registered: set[str],
        already_registered_warning_connection_types: set[str],
        package_name: str,
        provider: ProviderInfo,
    ):
        """
        Discover hooks from the "connection-types" property.

        This is the new, better method, which replaces discovery from hook-class-names, as it
        allows lazy importing of individual Hook classes when they are accessed.
        The "connection-types" property keeps information about both the connection type and the class
        name, so we can discover all connection types without importing the classes.
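
        For illustration only (a hypothetical connection type ``foo``), each discovered entry results in::

            self._hook_provider_dict["foo"] = HookClassProvider(
                hook_class_name="airflow.providers.foo.hooks.foo.FooHook",
                package_name="apache-airflow-providers-foo",
            )
            # plus a functools.partial(self._import_hook, ...) stored in self._hooks_lazy_dict["foo"],
            # so the hook class itself is only imported on first access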

        :param hook_class_names_registered: set of hook class names that are already registered for
            this provider
        :param already_registered_warning_connection_types: set of connection types for which a warning
            should be printed in the logs, as they were already registered before
        :param package_name: name of the provider package
        :param provider: provider information
        :return: True if the provider uses the new "connection-types" property
        """
        provider_uses_connection_types = False
        connection_types = provider.data.get("connection-types")
        if connection_types:
            for connection_type_dict in connection_types:
                connection_type = connection_type_dict["connection-type"]
                hook_class_name = connection_type_dict["hook-class-name"]
                hook_class_names_registered.add(hook_class_name)
                already_registered = self._hook_provider_dict.get(connection_type)
                if already_registered:
                    if already_registered.package_name != package_name:
                        already_registered_warning_connection_types.add(connection_type)
                    else:
                        log.warning(
                            "The connection type '%s' is already registered in the"
                            " package '%s' with different class names: '%s' and '%s'. ",
                            connection_type,
                            package_name,
                            already_registered.hook_class_name,
                            hook_class_name,
                        )
                else:
                    self._hook_provider_dict[connection_type] = HookClassProvider(
                        hook_class_name=hook_class_name, package_name=package_name
                    )
                    # Defer importing hook to access time by setting import hook method as dict value
                    self._hooks_lazy_dict[connection_type] = functools.partial(
                        self._import_hook,
                        connection_type=connection_type,
                        provider_info=provider,
                    )
            provider_uses_connection_types = True
        return provider_uses_connection_types

    def _discover_hooks_from_hook_class_names(
        self,
        hook_class_names_registered: set[str],
        already_registered_warning_connection_types: set[str],
        package_name: str,
        provider: ProviderInfo,
        provider_uses_connection_types: bool,
    ):
        """
        Discover hooks from the 'hook-class-names' property.

        This property is deprecated, but we should support it in Airflow 2.
        The hook-class-names array contained just hook class names without connection types,
        therefore we need to import all those classes immediately to know which connection types
        are supported. This makes it impossible to selectively import only those hooks that are used.

        :param hook_class_names_registered: set of hook class names that are already registered for
            this provider
        :param already_registered_warning_connection_types: set of connection types that we should warn
            about when discovery is finished
        :param package_name: name of the provider package
        :param provider: class that keeps information about version and details of the provider
        :param provider_uses_connection_types: determines whether the provider uses the new
            "connection-types" form of passing connection types
        """
        hook_class_names = provider.data.get("hook-class-names")
        if hook_class_names:
            for hook_class_name in hook_class_names:
                if hook_class_name in hook_class_names_registered:
                    # Silently ignore the hook class - it's already marked for lazy-import by
                    # connection-types discovery
                    continue
                hook_info = self._import_hook(
                    connection_type=None,
                    provider_info=provider,
                    hook_class_name=hook_class_name,
                    package_name=package_name,
                )
                if not hook_info:
                    # Problem with importing the class - we ignore it. A log is written at import time
                    continue
                already_registered = self._hook_provider_dict.get(hook_info.connection_type)
                if already_registered:
                    if already_registered.package_name != package_name:
                        already_registered_warning_connection_types.add(hook_info.connection_type)
                    else:
                        if already_registered.hook_class_name != hook_class_name:
                            log.warning(
                                "The hook connection type '%s' is registered twice in the"
                                " package '%s' with different class names: '%s' and '%s'. "
                                " Please fix it!",
                                hook_info.connection_type,
                                package_name,
                                already_registered.hook_class_name,
                                hook_class_name,
                            )
                else:
                    self._hook_provider_dict[hook_info.connection_type] = HookClassProvider(
                        hook_class_name=hook_class_name, package_name=package_name
                    )
                    self._hooks_lazy_dict[hook_info.connection_type] = hook_info

            if not provider_uses_connection_types:
                warnings.warn(
                    f"The provider {package_name} uses `hook-class-names` "
                    "property in provider-info and has no `connection-types` one. "
                    "The 'hook-class-names' property has been deprecated in favour "
                    "of 'connection-types' in Airflow 2.2. Use **both** in case you want to "
                    "have backwards compatibility with Airflow < 2.2",
                    DeprecationWarning,
                    stacklevel=1,
                )
        for already_registered_connection_type in already_registered_warning_connection_types:
            log.warning(
                "The connection_type '%s' has already been registered by provider '%s'.",
                already_registered_connection_type,
                self._hook_provider_dict[already_registered_connection_type].package_name,
            )

    def _discover_hooks(self) -> None:
        """Retrieve all connections defined in the providers via Hooks."""
        for package_name, provider in self._provider_dict.items():
            duplicated_connection_types: set[str] = set()
            hook_class_names_registered: set[str] = set()
            self._discover_provider_dialects(package_name, provider)
            provider_uses_connection_types = self._discover_hooks_from_connection_types(
                hook_class_names_registered, duplicated_connection_types, package_name, provider
            )
            self._discover_hooks_from_hook_class_names(
                hook_class_names_registered,
                duplicated_connection_types,
                package_name,
                provider,
                provider_uses_connection_types,
            )
        self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))

    def _discover_provider_dialects(self, provider_name: str, provider: ProviderInfo):
        dialects = provider.data.get("dialects", [])
        if dialects:
            self._dialect_provider_dict.update(
                {
                    item["dialect-type"]: DialectInfo(
                        name=item["dialect-type"],
                        dialect_class_name=item["dialect-class-name"],
                        provider_name=provider_name,
                    )
                    for item in dialects
                }
            )

    @provider_info_cache("import_all_hooks")
    def _import_info_from_all_hooks(self):
        """Force-import all hooks and initialize the connections/fields."""
        # Retrieve all hooks to make sure that all of them are imported
        _ = list(self._hooks_lazy_dict.values())
        self._field_behaviours = dict(sorted(self._field_behaviours.items()))

        # Widgets for connection forms are currently used in two places:
        # 1. In the UI connection form, where the same order as defined in the hook is expected.
        # 2. In the CLI command `airflow providers widgets`, where alphabetical order is expected.
        # It is not possible to recover the original ordering after sorting, which is the main reason
        # why the original sorting was moved to the CLI part:
        # self._connection_form_widgets = dict(sorted(self._connection_form_widgets.items()))

    def _discover_filesystems(self) -> None:
        """Retrieve all filesystems defined in the providers."""
        for provider_package, provider in self._provider_dict.items():
            for fs_module_name in provider.data.get("filesystems", []):
                if _correctness_check(provider_package, f"{fs_module_name}.get_fs", provider):
                    self._fs_set.add(fs_module_name)
        self._fs_set = set(sorted(self._fs_set))

    def _discover_asset_uri_resources(self) -> None:
        """Discover and register asset URI handlers, factories, and converters for all providers."""
        from airflow.sdk.definitions.asset import normalize_noop

        def _safe_register_resource(
            provider_package_name: str,
            schemes_list: list[str],
            resource_path: str | None,
            resource_registry: dict,
            default_resource: Any = None,
        ):
            """
            Register a specific resource (handler, factory, or converter) for the given schemes.

            If the resolved resource (either from the path or the default) is valid, it updates
            the resource registry with the appropriate resource for each scheme.
            """
            resource = (
                _correctness_check(provider_package_name, resource_path, provider)
                if resource_path is not None
                else default_resource
            )
            if resource:
                resource_registry.update((scheme, resource) for scheme in schemes_list)

        for provider_name, provider in self._provider_dict.items():
            for uri_info in provider.data.get("asset-uris", []):
                if "schemes" not in uri_info or "handler" not in uri_info:
                    continue  # Both schemes and handler must be explicitly set; handler can be set to null
                common_args = {"schemes_list": uri_info["schemes"], "provider_package_name": provider_name}
                _safe_register_resource(
                    resource_path=uri_info["handler"],
                    resource_registry=self._asset_uri_handlers,
                    default_resource=normalize_noop,
                    **common_args,
                )
                _safe_register_resource(
                    resource_path=uri_info.get("factory"),
                    resource_registry=self._asset_factories,
                    **common_args,
                )
                _safe_register_resource(
                    resource_path=uri_info.get("to_openlineage_converter"),
                    resource_registry=self._asset_to_openlineage_converters,
                    **common_args,
                )

    def _discover_taskflow_decorators(self) -> None:
        for name, info in self._provider_dict.items():
            for taskflow_decorator in info.data.get("task-decorators", []):
                self._add_taskflow_decorator(
                    taskflow_decorator["name"], taskflow_decorator["class-name"], name
                )

    def _add_taskflow_decorator(self, name, decorator_class_name: str, provider_package: str) -> None:
        if not _check_builtin_provider_prefix(provider_package, decorator_class_name):
            return

        if name in self._taskflow_decorators:
            try:
                existing = self._taskflow_decorators[name]
                other_name = f"{existing.__module__}.{existing.__name__}"
            except Exception:
                # If there is a problem importing, get the value from the functools.partial instead
                other_name = self._taskflow_decorators._raw_dict[name].args[0]  # type: ignore[attr-defined]

            log.warning(
                "The taskflow decorator '%s' has already been registered (by %s).",
                name,
                other_name,
            )
            return

        self._taskflow_decorators[name] = functools.partial(import_string, decorator_class_name)

    @staticmethod
    def _get_attr(obj: Any, attr_name: str):
        """Retrieve attributes of an object, or warn if not found."""
        if not hasattr(obj, attr_name):
            log.warning("The object '%s' is missing %s attribute and cannot be registered", obj, attr_name)
            return None
        return getattr(obj, attr_name)

    def _import_hook(
        self,
        connection_type: str | None,
        provider_info: ProviderInfo,
        hook_class_name: str | None = None,
        package_name: str | None = None,
    ) -> HookInfo | None:
        """
        Import hook and retrieve hook information.

        Either connection_type (for lazy loading) or hook_class_name must be set - but not both.
        package_name is only needed if hook_class_name is passed; for lazy loading via connection_type,
        package_name is retrieved from _hook_provider_dict together with hook_class_name.

        :param connection_type: type of the connection
        :param hook_class_name: name of the hook class
        :param package_name: provider package - only needed in case connection_type is missing
        :return: HookInfo for the hook, or None if the hook could not be imported
        """
        if connection_type is None and hook_class_name is None:
            raise ValueError("Either connection_type or hook_class_name must be set")
        if connection_type is not None and hook_class_name is not None:
            raise ValueError(
                f"Both connection_type ({connection_type}) and "
                f"hook_class_name ({hook_class_name}) are set. Only one should be set!"
            )
        if connection_type is not None:
            class_provider = self._hook_provider_dict[connection_type]
            package_name = class_provider.package_name
            hook_class_name = class_provider.hook_class_name
        else:
            if not hook_class_name:
                raise ValueError("Either connection_type or hook_class_name must be set")
            if not package_name:
                raise ValueError(
                    f"Provider package name is not set when hook_class_name ({hook_class_name}) is used"
                )
        hook_class: type[BaseHook] | None = _correctness_check(package_name, hook_class_name, provider_info)
        if hook_class is None:
            return None
        try:
            from wtforms import BooleanField, IntegerField, PasswordField, StringField

            allowed_field_classes = [IntegerField, PasswordField, StringField, BooleanField]
            module, class_name = hook_class_name.rsplit(".", maxsplit=1)
            # Do not use attr here. We want to check only direct class fields, not those
            # inherited from the parent hook. This way we add form fields only once for the whole
            # hierarchy, and we add them only from the parent hook that provides them!
            if "get_connection_form_widgets" in hook_class.__dict__:
                widgets = hook_class.get_connection_form_widgets()
                if widgets:
                    for widget in widgets.values():
                        if widget.field_class not in allowed_field_classes:
                            log.warning(
                                "The hook_class '%s' uses field of unsupported class '%s'. "
                                "Only '%s' field classes are supported",
                                hook_class_name,
                                widget.field_class,
                                allowed_field_classes,
                            )
                            return None
                    self._add_widgets(package_name, hook_class, widgets)
            if "get_ui_field_behaviour" in hook_class.__dict__:
                field_behaviours = hook_class.get_ui_field_behaviour()
                if field_behaviours:
                    self._add_customized_fields(package_name, hook_class, field_behaviours)
        except ImportError as e:
            if e.name in ["flask_appbuilder", "wtforms"]:
                log.info(
                    "The hook_class '%s' is not fully initialized (UI widgets will be missing), because "
                    "the 'flask_appbuilder' or 'wtforms' package is not installed; however, it is not "
                    "required for Airflow components to work",
                    hook_class_name,
                )
        except Exception as e:
            log.warning(
                "Exception when importing '%s' from '%s' package: %s",
                hook_class_name,
                package_name,
                e,
            )
            return None
        hook_connection_type = self._get_attr(hook_class, "conn_type")
        if connection_type:
            if hook_connection_type != connection_type:
                log.warning(
                    "Inconsistency! The hook class '%s' declares connection type '%s'"
                    " but it is added by provider '%s' as connection_type '%s' in provider info. "
                    "This should be fixed!",
                    hook_class,
                    hook_connection_type,
                    package_name,
                    connection_type,
                )
        connection_type = hook_connection_type
        connection_id_attribute_name: str = self._get_attr(hook_class, "conn_name_attr")
        hook_name: str = self._get_attr(hook_class, "hook_name")

        if not connection_type or not connection_id_attribute_name or not hook_name:
            log.warning(
                "The hook is missing one of the key attributes: "
                "conn_type: %s, conn_id_attribute_name: %s, hook_name: %s",
                connection_type,
                connection_id_attribute_name,
                hook_name,
            )
            return None

        return HookInfo(
            hook_class_name=hook_class_name,
            connection_id_attribute_name=connection_id_attribute_name,
            package_name=package_name,
            hook_name=hook_name,
            connection_type=connection_type,
            connection_testable=hasattr(hook_class, "test_connection"),
        )

    def _add_widgets(self, package_name: str, hook_class: type, widgets: dict[str, Any]):
        conn_type = hook_class.conn_type  # type: ignore
        for field_identifier, field in widgets.items():
            if field_identifier.startswith("extra__"):
                prefixed_field_name = field_identifier
            else:
                prefixed_field_name = f"extra__{conn_type}__{field_identifier}"
            if prefixed_field_name in self._connection_form_widgets:
                log.warning(
                    "The field %s from class %s has already been added by another provider. Ignoring it.",
                    field_identifier,
                    hook_class.__name__,
                )
                # In case of inherited hooks this may happen several times
            else:
                self._connection_form_widgets[prefixed_field_name] = ConnectionFormWidgetInfo(
                    hook_class.__name__,
                    package_name,
                    field,
                    field_identifier,
                    hasattr(field.field_class.widget, "input_type")
                    and field.field_class.widget.input_type == "password",
                )

    def _add_customized_fields(self, package_name: str, hook_class: type, customized_fields: dict):
        try:
            connection_type = getattr(hook_class, "conn_type")

            self._customized_form_fields_schema_validator.validate(customized_fields)

            if connection_type:
                customized_fields = _ensure_prefix_for_placeholders(customized_fields, connection_type)

            if connection_type in self._field_behaviours:
                log.warning(
                    "The connection_type %s from package %s and class %s has already been added "
                    "by another provider. Ignoring it.",
                    connection_type,
                    package_name,
                    hook_class.__name__,
                )
                return
            self._field_behaviours[connection_type] = customized_fields
        except Exception as e:
            log.warning(
                "Error when loading customized fields from package '%s' hook class '%s': %s",
                package_name,
                hook_class.__name__,
                e,
            )

    def _discover_auth_managers(self) -> None:
        """Retrieve all auth managers defined in the providers."""
        for provider_package, provider in self._provider_dict.items():
            if provider.data.get("auth-managers"):
                for auth_manager_class_name in provider.data["auth-managers"]:
                    if _correctness_check(provider_package, auth_manager_class_name, provider):
                        self._auth_manager_class_name_set.add(auth_manager_class_name)

    def _discover_notifications(self) -> None:
        """Retrieve all notifications defined in the providers."""
        for provider_package, provider in self._provider_dict.items():
            if provider.data.get("notifications"):
                for notification_class_name in provider.data["notifications"]:
                    if _correctness_check(provider_package, notification_class_name, provider):
                        self._notification_info_set.add(notification_class_name)

    def _discover_extra_links(self) -> None:
        """Retrieve all extra links defined in the providers."""
        for provider_package, provider in self._provider_dict.items():
            if provider.data.get("extra-links"):
                for extra_link_class_name in provider.data["extra-links"]:
                    if _correctness_check(provider_package, extra_link_class_name, provider):
                        self._extra_link_class_name_set.add(extra_link_class_name)

    def _discover_logging(self) -> None:
        """Retrieve all logging defined in the providers."""
        for provider_package, provider in self._provider_dict.items():
            if provider.data.get("logging"):
                for logging_class_name in provider.data["logging"]:
                    if _correctness_check(provider_package, logging_class_name, provider):
                        self._logging_class_name_set.add(logging_class_name)

    def _discover_secrets_backends(self) -> None:
        """Retrieve all secrets backends defined in the providers."""
        for provider_package, provider in self._provider_dict.items():
            if provider.data.get("secrets-backends"):
                for secrets_backends_class_name in provider.data["secrets-backends"]:
                    if _correctness_check(provider_package, secrets_backends_class_name, provider):
                        self._secrets_backend_class_name_set.add(secrets_backends_class_name)

    def _discover_executors(self) -> None:
        """Retrieve all executors defined in the providers."""
        for provider_package, provider in self._provider_dict.items():
            if provider.data.get("executors"):
                for executors_class_name in provider.data["executors"]:
                    if _correctness_check(provider_package, executors_class_name, provider):
                        self._executor_class_name_set.add(executors_class_name)

    def _discover_queues(self) -> None:
        """Retrieve all queues defined in the providers."""
        for provider_package, provider in self._provider_dict.items():
            if provider.data.get("queues"):
                for queue_class_name in provider.data["queues"]:
                    if _correctness_check(provider_package, queue_class_name, provider):
                        self._queue_class_name_set.add(queue_class_name)

    def _discover_config(self) -> None:
        """Retrieve all configs defined in the providers."""
        for provider_package, provider in self._provider_dict.items():
            if provider.data.get("config"):
                self._provider_configs[provider_package] = provider.data.get("config")  # type: ignore[assignment]

    def _discover_plugins(self) -> None:
        """Retrieve all plugins defined in the providers."""
        for provider_package, provider in self._provider_dict.items():
            for plugin_dict in provider.data.get("plugins", ()):
                if not _correctness_check(provider_package, plugin_dict["plugin-class"], provider):
                    log.warning("Plugin not loaded due to the correctness check problem logged above.")
                    continue
                self._plugins_set.add(
                    PluginInfo(
                        name=plugin_dict["name"],
                        plugin_class=plugin_dict["plugin-class"],
                        provider_name=provider_package,
                    )
                )

    @provider_info_cache("triggers")
    def initialize_providers_triggers(self):
        """Initialize providers triggers."""
        self.initialize_providers_list()
        for provider_package, provider in self._provider_dict.items():
            for trigger in provider.data.get("triggers", []):
                for trigger_class_name in trigger.get("python-modules"):
                    self._trigger_info_set.add(
                        TriggerInfo(
                            package_name=provider_package,
                            trigger_class_name=trigger_class_name,
                            integration_name=trigger.get("integration-name", ""),
                        )
                    )

    @property
    def auth_managers(self) -> list[str]:
        """Returns the list of available provider auth manager class names."""
        self.initialize_providers_auth_managers()
        return sorted(self._auth_manager_class_name_set)

    @property
    def notification(self) -> list[NotificationInfo]:
        """Returns information about available provider notification classes."""
        self.initialize_providers_notifications()
        return sorted(self._notification_info_set)

    @property
    def trigger(self) -> list[TriggerInfo]:
        """Returns information about available provider trigger classes."""
        self.initialize_providers_triggers()
        return sorted(self._trigger_info_set, key=lambda x: x.package_name)

    @property
    def providers(self) -> dict[str, ProviderInfo]:
        """Returns information about available providers."""
        self.initialize_providers_list()
        return self._provider_dict

    @property
    def hooks(self) -> MutableMapping[str, HookInfo | None]:
        """
        Return dictionary of connection_type-to-hook mapping.

        Note that the dict can contain None values if a discovered hook cannot be imported!
        """
        self.initialize_providers_hooks()
        # When we return hooks here it will only be used to retrieve hook information
        return self._hooks_lazy_dict

    @property
    def dialects(self) -> MutableMapping[str, DialectInfo]:
        """Return dictionary of connection_type-to-dialect mapping."""
        self.initialize_providers_hooks()
        # When we return dialects here it will only be used to retrieve dialect information
        return self._dialect_provider_dict

    @property
    def plugins(self) -> list[PluginInfo]:
        """Returns information about plugins available in providers."""
        self.initialize_providers_plugins()
        return sorted(self._plugins_set, key=lambda x: x.plugin_class)

    @property
    def taskflow_decorators(self) -> dict[str, TaskDecorator]:
        self.initialize_providers_taskflow_decorator()
        return self._taskflow_decorators  # type: ignore[return-value]

    @property
    def extra_links_class_names(self) -> list[str]:
        """Returns the sorted list of extra link class names."""
        self.initialize_providers_extra_links()
        return sorted(self._extra_link_class_name_set)

    @property
    def connection_form_widgets(self) -> dict[str, ConnectionFormWidgetInfo]:
        """
        Returns widgets for connection forms.

        Dictionary keys are in the same order as defined in the hook.
        """
        self.initialize_providers_hooks()
        self._import_info_from_all_hooks()
        return self._connection_form_widgets

    @property
    def field_behaviours(self) -> dict[str, dict]:
        """Returns dictionary with field behaviours for connection types."""
        self.initialize_providers_hooks()
        self._import_info_from_all_hooks()
        return self._field_behaviours

    @property
    def logging_class_names(self) -> list[str]:
        """Returns the sorted list of log task handler class names."""
        self.initialize_providers_logging()
        return sorted(self._logging_class_name_set)

    @property
    def secrets_backend_class_names(self) -> list[str]:
        """Returns the sorted list of secrets backend class names."""
        self.initialize_providers_secrets_backends()
        return sorted(self._secrets_backend_class_name_set)

    @property
    def executor_class_names(self) -> list[str]:
        self.initialize_providers_executors()
        return sorted(self._executor_class_name_set)

    @property
    def queue_class_names(self) -> list[str]:
        self.initialize_providers_queues()
        return sorted(self._queue_class_name_set)

    @property
    def filesystem_module_names(self) -> list[str]:
        self.initialize_providers_filesystems()
        return sorted(self._fs_set)

    @property
    def asset_factories(self) -> dict[str, Callable[..., Asset]]:
        self.initialize_providers_asset_uri_resources()
        return self._asset_factories

    @property
    def asset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]:
        self.initialize_providers_asset_uri_resources()
        return self._asset_uri_handlers

    @property
    def asset_to_openlineage_converters(
        self,
    ) -> dict[str, Callable]:
        self.initialize_providers_asset_uri_resources()
        return self._asset_to_openlineage_converters

    @property
    def provider_configs(self) -> list[tuple[str, dict[str, Any]]]:
        self.initialize_providers_configuration()
        return sorted(self._provider_configs.items(), key=lambda x: x[0])

    @property
    def already_initialized_provider_configs(self) -> list[tuple[str, dict[str, Any]]]:
        return sorted(self._provider_configs.items(), key=lambda x: x[0])

    def _cleanup(self):
        self._initialized_cache.clear()
        self._provider_dict.clear()
        self._fs_set.clear()
        self._taskflow_decorators.clear()
        self._hook_provider_dict.clear()
        self._dialect_provider_dict.clear()
        self._hooks_lazy_dict.clear()
        self._connection_form_widgets.clear()
        self._field_behaviours.clear()
        self._extra_link_class_name_set.clear()
        self._logging_class_name_set.clear()
        self._auth_manager_class_name_set.clear()
        self._secrets_backend_class_name_set.clear()
        self._executor_class_name_set.clear()
        self._queue_class_name_set.clear()
        self._provider_configs.clear()
        self._trigger_info_set.clear()
        self._notification_info_set.clear()
        self._plugins_set.clear()

        self._initialized = False
        self._initialization_stack_trace = None