Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/providers_manager.py: 29%
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Manages all providers."""
20from __future__ import annotations
22import fnmatch
23import functools
24import inspect
25import json
26import logging
27import os
28import sys
29import traceback
30import warnings
31from dataclasses import dataclass
32from functools import wraps
33from time import perf_counter
34from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, TypeVar, cast
36from packaging.utils import canonicalize_name
38from airflow.exceptions import AirflowOptionalProviderFeatureException
39from airflow.hooks.filesystem import FSHook
40from airflow.hooks.package_index import PackageIndexHook
41from airflow.utils import yaml
42from airflow.utils.entry_points import entry_points_with_dist
43from airflow.utils.log.logging_mixin import LoggingMixin
44from airflow.utils.module_loading import import_string
45from airflow.utils.singleton import Singleton
47log = logging.getLogger(__name__)
49if sys.version_info >= (3, 9):
50 from importlib.resources import files as resource_files
51else:
52 from importlib_resources import files as resource_files
54MIN_PROVIDER_VERSIONS = {
55 "apache-airflow-providers-celery": "2.1.0",
56}
59def _ensure_prefix_for_placeholders(field_behaviors: dict[str, Any], conn_type: str):
60 """
61 Verify the correct placeholder prefix.
63 If the given field_behaviors dict contains a placeholders node, and there
64 are placeholders for extra fields (i.e. anything other than the built-in conn
65 attrs), and if those extra fields are unprefixed, then add the prefix.
67 The reason we need to do this is that all custom conn fields live in the same dictionary,
68 so we need to namespace them with a prefix internally. But for user convenience,
69 and consistency between the `get_ui_field_behaviour` method and the extra dict itself,
70 we allow users to supply the unprefixed name.
71 """
72 conn_attrs = {"host", "schema", "login", "password", "port", "extra"}
74 def ensure_prefix(field):
75 if field not in conn_attrs and not field.startswith("extra__"):
76 return f"extra__{conn_type}__{field}"
77 else:
78 return field
80 if "placeholders" in field_behaviors:
81 placeholders = field_behaviors["placeholders"]
82 field_behaviors["placeholders"] = {ensure_prefix(k): v for k, v in placeholders.items()}
84 return field_behaviors
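# A minimal sketch (not part of the module) of how the placeholder prefixing above behaves,
# assuming a hypothetical "jdbc" connection type with a custom "drv_path" extra field:
#
#   behaviours = {"placeholders": {"host": "example.com", "drv_path": "/opt/driver.jar"}}
#   _ensure_prefix_for_placeholders(behaviours, "jdbc")
#   # -> {"placeholders": {"host": "example.com", "extra__jdbc__drv_path": "/opt/driver.jar"}}
#
# Built-in connection attributes ("host" here) keep their names; only unprefixed extra fields
# gain the "extra__<conn_type>__" prefix.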
87if TYPE_CHECKING:
88 from urllib.parse import SplitResult
90 from airflow.decorators.base import TaskDecorator
91 from airflow.hooks.base import BaseHook
92 from airflow.typing_compat import Literal
95class LazyDictWithCache(MutableMapping):
96 """
97 Lazy-loaded cached dictionary.
99 Dictionary which, when a value is set to a callable, executes that callable on first
100 access of the key, then returns and caches the result.
101 """
103 __slots__ = ["_resolved", "_raw_dict"]
105 def __init__(self, *args, **kw):
106 self._resolved = set()
107 self._raw_dict = dict(*args, **kw)
109 def __setitem__(self, key, value):
110 self._raw_dict.__setitem__(key, value)
112 def __getitem__(self, key):
113 value = self._raw_dict.__getitem__(key)
114 if key not in self._resolved and callable(value):
115 # exchange callable with result of calling it -- but only once! allow resolver to return a
116 # callable itself
117 value = value()
118 self._resolved.add(key)
119 self._raw_dict.__setitem__(key, value)
120 return value
122 def __delitem__(self, key):
123 try:
124 self._resolved.remove(key)
125 except KeyError:
126 pass
127 self._raw_dict.__delitem__(key)
129 def __iter__(self):
130 return iter(self._raw_dict)
132 def __len__(self):
133 return len(self._raw_dict)
135 def __contains__(self, key):
136 return key in self._raw_dict
138 def clear(self):
139 self._resolved.clear()
140 self._raw_dict.clear()
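# Illustrative usage of LazyDictWithCache (a sketch, not part of the module): values that are
# callables are invoked once on first access and the result is cached in place.
#
#   lazy = LazyDictWithCache()
#   lazy["slow"] = lambda: expensive_load()   # 'expensive_load' is a hypothetical loader
#   lazy["slow"]                              # calls expensive_load() and caches the result
#   lazy["slow"]                              # returns the cached result, no second call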
143def _read_schema_from_resources_or_local_file(filename: str) -> dict:
144 try:
145 with resource_files("airflow").joinpath(filename).open("rb") as f:
146 schema = json.load(f)
147 except (TypeError, FileNotFoundError):
148 import pathlib
150 with (pathlib.Path(__file__).parent / filename).open("rb") as f:
151 schema = json.load(f)
152 return schema
155def _create_provider_info_schema_validator():
156 """Create JSON schema validator from the provider_info.schema.json."""
157 import jsonschema
159 schema = _read_schema_from_resources_or_local_file("provider_info.schema.json")
160 cls = jsonschema.validators.validator_for(schema)
161 validator = cls(schema)
162 return validator
165def _create_customized_form_field_behaviours_schema_validator():
166 """Create JSON schema validator from the customized_form_field_behaviours.schema.json."""
167 import jsonschema
169 schema = _read_schema_from_resources_or_local_file("customized_form_field_behaviours.schema.json")
170 cls = jsonschema.validators.validator_for(schema)
171 validator = cls(schema)
172 return validator
175def _check_builtin_provider_prefix(provider_package: str, class_name: str) -> bool:
176 if provider_package.startswith("apache-airflow"):
177 provider_path = provider_package[len("apache-") :].replace("-", ".")
178 if not class_name.startswith(provider_path):
179 log.warning(
180 "Coherence check failed when importing '%s' from '%s' package. It should start with '%s'",
181 class_name,
182 provider_package,
183 provider_path,
184 )
185 return False
186 return True
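# Example of the prefix check above (a sketch using an illustrative class path):
#
#   _check_builtin_provider_prefix(
#       "apache-airflow-providers-sftp", "airflow.providers.sftp.hooks.sftp.SFTPHook"
#   )
#   # "apache-airflow-providers-sftp" -> expected prefix "airflow.providers.sftp" -> True
#
# A class outside "airflow.providers.sftp" declared by that package would log a warning and
# return False; packages not starting with "apache-airflow" are not checked and return True.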
189@dataclass
190class ProviderInfo:
191 """
192 Provider information.
194 :param version: version string
195 :param data: dictionary with information about the provider
196 :param package_or_source: whether the provider comes from source files or a PyPI package. When
197 installed from sources we suppress provider import errors.
198 """
200 version: str
201 data: dict
202 package_or_source: Literal["source"] | Literal["package"]
204 def __post_init__(self):
205 if self.package_or_source not in ("source", "package"):
206 raise ValueError(
207 f"Received {self.package_or_source!r} for `package_or_source`. "
208 "Must be either 'package' or 'source'."
209 )
210 self.is_source = self.package_or_source == "source"
213class HookClassProvider(NamedTuple):
214 """Hook class and Provider it comes from."""
216 hook_class_name: str
217 package_name: str
220class TriggerInfo(NamedTuple):
221 """Trigger class and provider it comes from."""
223 trigger_class_name: str
224 package_name: str
225 integration_name: str
228class NotificationInfo(NamedTuple):
229 """Notification class and provider it comes from."""
231 notification_class_name: str
232 package_name: str
235class PluginInfo(NamedTuple):
236 """Plugin class, name and provider it comes from."""
238 name: str
239 plugin_class: str
240 provider_name: str
243class HookInfo(NamedTuple):
244 """Hook information."""
246 hook_class_name: str
247 connection_id_attribute_name: str
248 package_name: str
249 hook_name: str
250 connection_type: str
251 connection_testable: bool
254class ConnectionFormWidgetInfo(NamedTuple):
255 """Connection Form Widget information."""
257 hook_class_name: str
258 package_name: str
259 field: Any
260 field_name: str
261 is_sensitive: bool
264T = TypeVar("T", bound=Callable)
266logger = logging.getLogger(__name__)
269def log_debug_import_from_sources(class_name, e, provider_package):
270 """Log debug imports from sources."""
271 log.debug(
272 "Optional feature disabled on exception when importing '%s' from '%s' package",
273 class_name,
274 provider_package,
275 exc_info=e,
276 )
279def log_optional_feature_disabled(class_name, e, provider_package):
280 """Log optional feature disabled."""
281 log.debug(
282 "Optional feature disabled on exception when importing '%s' from '%s' package",
283 class_name,
284 provider_package,
285 exc_info=e,
286 )
287 log.info(
288 "Optional provider feature disabled when importing '%s' from '%s' package",
289 class_name,
290 provider_package,
291 )
294def log_import_warning(class_name, e, provider_package):
295 """Log import warning."""
296 log.warning(
297 "Exception when importing '%s' from '%s' package",
298 class_name,
299 provider_package,
300 exc_info=e,
301 )
304# This is a temporary measure until all community providers add AirflowOptionalProviderFeatureException
305# where they have optional features. We are going to add tests in our CI to catch all such cases and will
306# fix them, but until then all "known unhandled optional feature errors" from community providers
307# should be added here
308KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS = [("apache-airflow-providers-google", "No module named 'paramiko'")]
311def _correctness_check(provider_package: str, class_name: str, provider_info: ProviderInfo) -> Any:
312 """
313 Perform coherence check on provider classes.
315 For apache-airflow providers it checks whether the class name starts with the appropriate package. For all
316 providers it tries to import the class, checking that there are no exceptions during importing.
317 It logs an appropriate warning in case it detects any problems.
319 :param provider_package: name of the provider package
320 :param class_name: name of the class to import
322 :return: the class if the class is OK, None otherwise.
323 """
324 if not _check_builtin_provider_prefix(provider_package, class_name):
325 return None
326 try:
327 imported_class = import_string(class_name)
328 except AirflowOptionalProviderFeatureException as e:
329 # When the provider class raises AirflowOptionalProviderFeatureException
330 # this is an expected case when only some classes in provider are
331 # available. We just log debug level here and print info message in logs so that
332 # the user is aware of it
333 log_optional_feature_disabled(class_name, e, provider_package)
334 return None
335 except ImportError as e:
336 if provider_info.is_source:
337 # When we have providers from sources, then we just turn all import logs to debug logs
338 # As this is pretty expected that you have a number of dependencies not installed
339 # (we always have all providers from sources until we split providers to separate repo)
340 log_debug_import_from_sources(class_name, e, provider_package)
341 return None
342 if "No module named 'airflow.providers." in e.msg:
343 # handle cases where another provider is missing. This can only happen if
344 # there is an optional feature, so we log debug and print information about it
345 log_optional_feature_disabled(class_name, e, provider_package)
346 return None
347 for known_error in KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS:
348 # Until we convert all providers to use AirflowOptionalProviderFeatureException
349 # we assume any problem with importing another "provider" is because this is an
350 # optional feature, so we log debug and print information about it
351 if known_error[0] == provider_package and known_error[1] in e.msg:
352 log_optional_feature_disabled(class_name, e, provider_package)
353 return None
354 # But when we have no idea - we print warning to logs
355 log_import_warning(class_name, e, provider_package)
356 return None
357 except Exception as e:
358 log_import_warning(class_name, e, provider_package)
359 return None
360 return imported_class
363# We want to have better control over initialization of parameters and be able to debug and test it
364# So we add our own decorator
365def provider_info_cache(cache_name: str) -> Callable[[T], T]:
366 """
367 Decorate and cache provider info.
369 Decorator factory that creates a decorator that caches initialization of the provider's parameters.
370 :param cache_name: Name of the cache
371 """
373 def provider_info_cache_decorator(func: T):
374 @wraps(func)
375 def wrapped_function(*args, **kwargs):
376 providers_manager_instance = args[0]
377 if cache_name in providers_manager_instance._initialized_cache:
378 return
379 start_time = perf_counter()
380 logger.debug("Initializing Providers Manager[%s]", cache_name)
381 func(*args, **kwargs)
382 providers_manager_instance._initialized_cache[cache_name] = True
383 logger.debug(
384 "Initialization of Providers Manager[%s] took %.2f seconds",
385 cache_name,
386 perf_counter() - start_time,
387 )
389 return cast(T, wrapped_function)
391 return provider_info_cache_decorator
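# The factory above is applied to ProvidersManager initialization methods further below,
# for example (taken from this module):
#
#   @provider_info_cache("hooks")
#   def initialize_providers_hooks(self):
#       ...
#
# After the first successful run the cache name is recorded in _initialized_cache, so
# repeated calls return immediately without re-discovering providers.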
394class ProvidersManager(LoggingMixin, metaclass=Singleton):
395 """
396 Manages all provider packages.
398 This is a Singleton class. The first time it is
399 instantiated, it discovers all available providers in installed packages and
400 local source folders (if airflow is run from sources).
401 """
403 resource_version = "0"
404 _initialized: bool = False
405 _initialization_stack_trace = None
407 @staticmethod
408 def initialized() -> bool:
409 return ProvidersManager._initialized
411 @staticmethod
412 def initialization_stack_trace() -> str | None:
413 return ProvidersManager._initialization_stack_trace
415 def __init__(self):
416 """Initialize the manager."""
417 super().__init__()
418 ProvidersManager._initialized = True
419 ProvidersManager._initialization_stack_trace = "".join(traceback.format_stack(inspect.currentframe()))
420 self._initialized_cache: dict[str, bool] = {}
421 # Keeps dict of providers keyed by module name
422 self._provider_dict: dict[str, ProviderInfo] = {}
423 # Keeps dict of hooks keyed by connection type
424 self._hooks_dict: dict[str, HookInfo] = {}
425 self._fs_set: set[str] = set()
426 self._dataset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {}
427 self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache() # type: ignore[assignment]
428 # keeps mapping between connection_types and hook class, package they come from
429 self._hook_provider_dict: dict[str, HookClassProvider] = {}
430 # Keeps dict of hooks keyed by connection type. They are lazy evaluated at access time
431 self._hooks_lazy_dict: LazyDictWithCache[str, HookInfo | Callable] = LazyDictWithCache()
432 # Keeps the connection form widget info used to add custom widgets, keyed by the name of the extra field
433 self._connection_form_widgets: dict[str, ConnectionFormWidgetInfo] = {}
434 # Customizations for javascript fields are kept here
435 self._field_behaviours: dict[str, dict] = {}
436 self._extra_link_class_name_set: set[str] = set()
437 self._logging_class_name_set: set[str] = set()
438 self._auth_manager_class_name_set: set[str] = set()
439 self._secrets_backend_class_name_set: set[str] = set()
440 self._executor_class_name_set: set[str] = set()
441 self._provider_configs: dict[str, dict[str, Any]] = {}
442 self._api_auth_backend_module_names: set[str] = set()
443 self._trigger_info_set: set[TriggerInfo] = set()
444 self._notification_info_set: set[NotificationInfo] = set()
445 self._provider_schema_validator = _create_provider_info_schema_validator()
446 self._customized_form_fields_schema_validator = (
447 _create_customized_form_field_behaviours_schema_validator()
448 )
449 # Set of plugins contained in providers
450 self._plugins_set: set[PluginInfo] = set()
451 self._init_airflow_core_hooks()
453 def _init_airflow_core_hooks(self):
454 """Initialize the hooks dict with default hooks from Airflow core."""
455 core_dummy_hooks = {
456 "generic": "Generic",
457 "email": "Email",
458 }
459 for key, display in core_dummy_hooks.items():
460 self._hooks_lazy_dict[key] = HookInfo(
461 hook_class_name=None,
462 connection_id_attribute_name=None,
463 package_name=None,
464 hook_name=display,
465 connection_type=None,
466 connection_testable=False,
467 )
468 for cls in [FSHook, PackageIndexHook]:
469 package_name = cls.__module__
470 hook_class_name = f"{cls.__module__}.{cls.__name__}"
471 hook_info = self._import_hook(
472 connection_type=None,
473 provider_info=None,
474 hook_class_name=hook_class_name,
475 package_name=package_name,
476 )
477 self._hook_provider_dict[hook_info.connection_type] = HookClassProvider(
478 hook_class_name=hook_class_name, package_name=package_name
479 )
480 self._hooks_lazy_dict[hook_info.connection_type] = hook_info
482 @provider_info_cache("list")
483 def initialize_providers_list(self):
484 """Lazy initialization of providers list."""
485 # Local source folders are loaded first. They should take precedence over the package ones for
486 # development purposes. In production, provider.yaml files are not present in the 'airflow' directory,
487 # so there is no risk we are going to override a package provider accidentally. This can only happen
488 # in case of local development
489 self._discover_all_airflow_builtin_providers_from_local_sources()
490 self._discover_all_providers_from_packages()
491 self._verify_all_providers_all_compatible()
492 self._provider_dict = dict(sorted(self._provider_dict.items()))
494 def _verify_all_providers_all_compatible(self):
495 from packaging import version as packaging_version
497 for provider_id, info in self._provider_dict.items():
498 min_version = MIN_PROVIDER_VERSIONS.get(provider_id)
499 if min_version:
500 if packaging_version.parse(min_version) > packaging_version.parse(info.version):
501 log.warning(
502 "The package %s is not compatible with this version of Airflow. "
503 "The package has version %s but the minimum supported version "
504 "of the package is %s",
505 provider_id,
506 info.version,
507 min_version,
508 )
510 @provider_info_cache("hooks")
511 def initialize_providers_hooks(self):
512 """Lazy initialization of providers hooks."""
513 self.initialize_providers_list()
514 self._discover_hooks()
515 self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))
517 @provider_info_cache("filesystems")
518 def initialize_providers_filesystems(self):
519 """Lazy initialization of providers filesystems."""
520 self.initialize_providers_list()
521 self._discover_filesystems()
523 @provider_info_cache("dataset_uris")
524 def initialize_providers_dataset_uri_handlers(self):
525 """Lazy initialization of provider dataset URI handlers."""
526 self.initialize_providers_list()
527 self._discover_dataset_uri_handlers()
529 @provider_info_cache("taskflow_decorators")
530 def initialize_providers_taskflow_decorator(self):
531 """Lazy initialization of providers hooks."""
532 self.initialize_providers_list()
533 self._discover_taskflow_decorators()
535 @provider_info_cache("extra_links")
536 def initialize_providers_extra_links(self):
537 """Lazy initialization of providers extra links."""
538 self.initialize_providers_list()
539 self._discover_extra_links()
541 @provider_info_cache("logging")
542 def initialize_providers_logging(self):
543 """Lazy initialization of providers logging information."""
544 self.initialize_providers_list()
545 self._discover_logging()
547 @provider_info_cache("secrets_backends")
548 def initialize_providers_secrets_backends(self):
549 """Lazy initialization of providers secrets_backends information."""
550 self.initialize_providers_list()
551 self._discover_secrets_backends()
553 @provider_info_cache("executors")
554 def initialize_providers_executors(self):
555 """Lazy initialization of providers executors information."""
556 self.initialize_providers_list()
557 self._discover_executors()
559 @provider_info_cache("notifications")
560 def initialize_providers_notifications(self):
561 """Lazy initialization of providers notifications information."""
562 self.initialize_providers_list()
563 self._discover_notifications()
565 @provider_info_cache("auth_managers")
566 def initialize_providers_auth_managers(self):
567 """Lazy initialization of providers notifications information."""
568 self.initialize_providers_list()
569 self._discover_auth_managers()
571 @provider_info_cache("config")
572 def initialize_providers_configuration(self):
573 """Lazy initialization of providers configuration information."""
574 self._initialize_providers_configuration()
576 def _initialize_providers_configuration(self):
577 """
578 Initialize providers configuration information.
580 Should be used if we do not want to trigger caching for the ``initialize_providers_configuration`` method.
581 In some cases we might want to make sure that the configuration is initialized, but we do not want
582 to cache the initialization method - for example when we just want to write configuration with
583 providers, but it is used in a context where no providers are loaded yet. In that case we will eventually
584 restore the original configuration and we want the subsequent ``initialize_providers_configuration``
585 method to be run in order to load the configuration for providers again.
586 """
587 self.initialize_providers_list()
588 self._discover_config()
589 # Now update conf with the new provider configuration from providers
590 from airflow.configuration import conf
592 conf.load_providers_configuration()
594 @provider_info_cache("auth_backends")
595 def initialize_providers_auth_backends(self):
596 """Lazy initialization of providers API auth_backends information."""
597 self.initialize_providers_list()
598 self._discover_auth_backends()
600 @provider_info_cache("plugins")
601 def initialize_providers_plugins(self):
602 self.initialize_providers_list()
603 self._discover_plugins()
605 def _discover_all_providers_from_packages(self) -> None:
606 """
607 Discover all providers by scanning installed packages.
609 The list of providers should be returned via the 'apache_airflow_provider'
610 entrypoint as a dictionary conforming to the 'airflow/provider_info.schema.json'
611 schema. Note that the schema is different at runtime than provider.yaml.schema.json.
612 The development version of provider schema is more strict and changes together with
613 the code. The runtime version is more relaxed (allows for additional properties)
614 and verifies only the subset of fields that are needed at runtime.
615 """
616 for entry_point, dist in entry_points_with_dist("apache_airflow_provider"):
617 package_name = canonicalize_name(dist.metadata["name"])
618 if package_name in self._provider_dict:
619 continue
620 log.debug("Loading %s from package %s", entry_point, package_name)
621 version = dist.version
622 provider_info = entry_point.load()()
623 self._provider_schema_validator.validate(provider_info)
624 provider_info_package_name = provider_info["package-name"]
625 if package_name != provider_info_package_name:
626 raise ValueError(
627 f"The package '{package_name}' from setuptools and "
628 f"{provider_info_package_name} do not match. Please make sure they are aligned"
629 )
630 if package_name not in self._provider_dict:
631 self._provider_dict[package_name] = ProviderInfo(version, provider_info, "package")
632 else:
633 log.warning(
634 "The provider for package '%s' could not be registered from because providers for that "
635 "package name have already been registered",
636 package_name,
637 )
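# For reference, provider distributions typically declare the "apache_airflow_provider"
# entry point scanned above roughly as follows (a hedged sketch; exact metadata varies):
#
#   [project.entry-points."apache_airflow_provider"]
#   provider_info = "airflow.providers.sftp.get_provider_info:get_provider_info"
#
# The referenced callable returns the provider_info dictionary, which is then validated
# against provider_info.schema.json before the provider is registered.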
639 def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None:
640 """
641 Find all built-in airflow providers if airflow is run from the local sources.
643 It finds `provider.yaml` files for all such providers and registers the providers using those.
645 This 'provider.yaml' scanning takes precedence over scanning installed packages:
646 in case you have both sources and packages installed, the providers will be loaded from
647 the "airflow" sources rather than from the packages.
648 """
649 try:
650 import airflow.providers
651 except ImportError:
652 log.info("You have no providers installed.")
653 return
655 seen = set()
656 for path in airflow.providers.__path__: # type: ignore[attr-defined]
657 try:
658 # The same path can appear in the __path__ twice, under non-normalized paths (ie.
659 # /path/to/repo/airflow/providers and /path/to/repo/./airflow/providers)
660 path = os.path.realpath(path)
661 if path not in seen:
662 seen.add(path)
663 self._add_provider_info_from_local_source_files_on_path(path)
664 except Exception as e:
665 log.warning("Error when loading 'provider.yaml' files from %s airflow sources: %s", path, e)
667 def _add_provider_info_from_local_source_files_on_path(self, path) -> None:
668 """
669 Find all the provider.yaml files in the directory specified.
671 :param path: path where to look for provider.yaml files
672 """
673 root_path = path
674 for folder, subdirs, files in os.walk(path, topdown=True):
675 for filename in fnmatch.filter(files, "provider.yaml"):
676 try:
677 package_name = "apache-airflow-providers" + folder[len(root_path) :].replace(os.sep, "-")
678 self._add_provider_info_from_local_source_file(
679 os.path.join(folder, filename), package_name
680 )
681 subdirs[:] = []
682 except Exception as e:
683 log.warning("Error when loading 'provider.yaml' file from %s %e", folder, e)
685 def _add_provider_info_from_local_source_file(self, path, package_name) -> None:
686 """
687 Parse the found provider.yaml file and add the found provider to the dictionary.
689 :param path: full file path of the provider.yaml file
690 :param package_name: name of the package
691 """
692 try:
693 log.debug("Loading %s from %s", package_name, path)
694 with open(path) as provider_yaml_file:
695 provider_info = yaml.safe_load(provider_yaml_file)
696 self._provider_schema_validator.validate(provider_info)
697 version = provider_info["versions"][0]
698 if package_name not in self._provider_dict:
699 self._provider_dict[package_name] = ProviderInfo(version, provider_info, "source")
700 else:
701 log.warning(
702 "The providers for package '%s' could not be registered because providers for that "
703 "package name have already been registered",
704 package_name,
705 )
706 except Exception as e:
707 log.warning("Error when loading '%s'", path, exc_info=e)
709 def _discover_hooks_from_connection_types(
710 self,
711 hook_class_names_registered: set[str],
712 already_registered_warning_connection_types: set[str],
713 package_name: str,
714 provider: ProviderInfo,
715 ):
716 """
717 Discover hooks from the "connection-types" property.
719 This is the new, better method that replaces discovery from hook-class-names, as it
720 allows lazy importing of individual Hook classes when they are accessed.
721 The "connection-types" property keeps information about both the connection type and the class
722 name, so we can discover all connection types without importing the classes.
723 :param hook_class_names_registered: set of registered hook class names for this provider
724 :param already_registered_warning_connection_types: set of connections for which warning should be
725 printed in logs as they were already registered before
726 :param package_name: name of the provider package
727 :param provider: class that keeps information about version and details of the provider
728 :return: True if the provider uses the new "connection-types" form, False otherwise
729 """
730 provider_uses_connection_types = False
731 connection_types = provider.data.get("connection-types")
732 if connection_types:
733 for connection_type_dict in connection_types:
734 connection_type = connection_type_dict["connection-type"]
735 hook_class_name = connection_type_dict["hook-class-name"]
736 hook_class_names_registered.add(hook_class_name)
737 already_registered = self._hook_provider_dict.get(connection_type)
738 if already_registered:
739 if already_registered.package_name != package_name:
740 already_registered_warning_connection_types.add(connection_type)
741 else:
742 log.warning(
743 "The connection type '%s' is already registered in the"
744 " package '%s' with different class names: '%s' and '%s'. ",
745 connection_type,
746 package_name,
747 already_registered.hook_class_name,
748 hook_class_name,
749 )
750 else:
751 self._hook_provider_dict[connection_type] = HookClassProvider(
752 hook_class_name=hook_class_name, package_name=package_name
753 )
754 # Defer importing hook to access time by setting import hook method as dict value
755 self._hooks_lazy_dict[connection_type] = functools.partial(
756 self._import_hook,
757 connection_type=connection_type,
758 provider_info=provider,
759 )
760 provider_uses_connection_types = True
761 return provider_uses_connection_types
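# A hedged sketch of the "connection-types" provider metadata consumed above (the exact
# entries vary per provider, but each item pairs a connection type with a hook class):
#
#   connection_types_example = [
#       {
#           "connection-type": "sftp",
#           "hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook",
#       },
#   ]
#
# Only the type and class name are read here; the hook class itself is imported lazily
# via functools.partial(self._import_hook, ...) on first access of the connection type.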
763 def _discover_hooks_from_hook_class_names(
764 self,
765 hook_class_names_registered: set[str],
766 already_registered_warning_connection_types: set[str],
767 package_name: str,
768 provider: ProviderInfo,
769 provider_uses_connection_types: bool,
770 ):
771 """
772 Discover hooks from the "hook-class-names" property.
774 This property is deprecated but we should support it in Airflow 2.
775 The hook-class-names array contained just Hook names without connection type,
776 therefore we need to import all those classes immediately to know which connection types
777 are supported. This makes it impossible to selectively only import those hooks that are used.
778 :param already_registered_warning_connection_types: list of connection hooks that we should warn
779 about when finished discovery
780 :param package_name: name of the provider package
781 :param provider: class that keeps information about version and details of the provider
782 :param provider_uses_connection_types: determines whether the provider uses "connection-types" new
783 form of passing connection types
784 :return:
785 """
786 hook_class_names = provider.data.get("hook-class-names")
787 if hook_class_names:
788 for hook_class_name in hook_class_names:
789 if hook_class_name in hook_class_names_registered:
790 # Silently ignore the hook class - it's already marked for lazy-import by
791 # connection-types discovery
792 continue
793 hook_info = self._import_hook(
794 connection_type=None,
795 provider_info=provider,
796 hook_class_name=hook_class_name,
797 package_name=package_name,
798 )
799 if not hook_info:
800 # Problem importing the class - we ignore it. Log is written at import time
801 continue
802 already_registered = self._hook_provider_dict.get(hook_info.connection_type)
803 if already_registered:
804 if already_registered.package_name != package_name:
805 already_registered_warning_connection_types.add(hook_info.connection_type)
806 else:
807 if already_registered.hook_class_name != hook_class_name:
808 log.warning(
809 "The hook connection type '%s' is registered twice in the"
810 " package '%s' with different class names: '%s' and '%s'. "
811 " Please fix it!",
812 hook_info.connection_type,
813 package_name,
814 already_registered.hook_class_name,
815 hook_class_name,
816 )
817 else:
818 self._hook_provider_dict[hook_info.connection_type] = HookClassProvider(
819 hook_class_name=hook_class_name, package_name=package_name
820 )
821 self._hooks_lazy_dict[hook_info.connection_type] = hook_info
823 if not provider_uses_connection_types:
824 warnings.warn(
825 f"The provider {package_name} uses `hook-class-names` "
826 "property in provider-info and has no `connection-types` one. "
827 "The 'hook-class-names' property has been deprecated in favour "
828 "of 'connection-types' in Airflow 2.2. Use **both** in case you want to "
829 "have backwards compatibility with Airflow < 2.2",
830 DeprecationWarning,
831 stacklevel=1,
832 )
833 for already_registered_connection_type in already_registered_warning_connection_types:
834 log.warning(
835 "The connection_type '%s' has been already registered by provider '%s.'",
836 already_registered_connection_type,
837 self._hook_provider_dict[already_registered_connection_type].package_name,
838 )
840 def _discover_hooks(self) -> None:
841 """Retrieve all connections defined in the providers via Hooks."""
842 for package_name, provider in self._provider_dict.items():
843 duplicated_connection_types: set[str] = set()
844 hook_class_names_registered: set[str] = set()
845 provider_uses_connection_types = self._discover_hooks_from_connection_types(
846 hook_class_names_registered, duplicated_connection_types, package_name, provider
847 )
848 self._discover_hooks_from_hook_class_names(
849 hook_class_names_registered,
850 duplicated_connection_types,
851 package_name,
852 provider,
853 provider_uses_connection_types,
854 )
855 self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))
857 @provider_info_cache("import_all_hooks")
858 def _import_info_from_all_hooks(self):
859 """Force-import all hooks and initialize the connections/fields."""
860 # Retrieve all hooks to make sure that all of them are imported
861 _ = list(self._hooks_lazy_dict.values())
862 self._field_behaviours = dict(sorted(self._field_behaviours.items()))
864 # Widgets for connection forms are currently used in two places:
865 # 1. In the UI Connections form, where the same order as defined in the Hook is expected.
866 # 2. In the CLI command `airflow providers widgets`, where alphabetical order is expected.
867 # It is not possible to recover the original ordering after sorting,
868 # which is the main reason why the sorting was moved to the CLI part:
869 # self._connection_form_widgets = dict(sorted(self._connection_form_widgets.items()))
871 def _discover_filesystems(self) -> None:
872 """Retrieve all filesystems defined in the providers."""
873 for provider_package, provider in self._provider_dict.items():
874 for fs_module_name in provider.data.get("filesystems", []):
875 if _correctness_check(provider_package, f"{fs_module_name}.get_fs", provider):
876 self._fs_set.add(fs_module_name)
877 self._fs_set = set(sorted(self._fs_set))
879 def _discover_dataset_uri_handlers(self) -> None:
880 from airflow.datasets import normalize_noop
882 for provider_package, provider in self._provider_dict.items():
883 for handler_info in provider.data.get("dataset-uris", []):
884 try:
885 schemes = handler_info["schemes"]
886 handler_path = handler_info["handler"]
887 except KeyError:
888 continue
889 if handler_path is None:
890 handler = normalize_noop
891 elif not (handler := _correctness_check(provider_package, handler_path, provider)):
892 continue
893 self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
895 def _discover_taskflow_decorators(self) -> None:
896 for name, info in self._provider_dict.items():
897 for taskflow_decorator in info.data.get("task-decorators", []):
898 self._add_taskflow_decorator(
899 taskflow_decorator["name"], taskflow_decorator["class-name"], name
900 )
902 def _add_taskflow_decorator(self, name, decorator_class_name: str, provider_package: str) -> None:
903 if not _check_builtin_provider_prefix(provider_package, decorator_class_name):
904 return
906 if name in self._taskflow_decorators:
907 try:
908 existing = self._taskflow_decorators[name]
909 other_name = f"{existing.__module__}.{existing.__name__}"
910 except Exception:
911 # If problem importing, then get the value from the functools.partial
912 other_name = self._taskflow_decorators._raw_dict[name].args[0] # type: ignore[attr-defined]
914 log.warning(
915 "The taskflow decorator '%s' has been already registered (by %s).",
916 name,
917 other_name,
918 )
919 return
921 self._taskflow_decorators[name] = functools.partial(import_string, decorator_class_name)
923 @staticmethod
924 def _get_attr(obj: Any, attr_name: str):
925 """Retrieve attributes of an object, or warn if not found."""
926 if not hasattr(obj, attr_name):
927 log.warning("The object '%s' is missing %s attribute and cannot be registered", obj, attr_name)
928 return None
929 return getattr(obj, attr_name)
931 def _import_hook(
932 self,
933 connection_type: str | None,
934 provider_info: ProviderInfo,
935 hook_class_name: str | None = None,
936 package_name: str | None = None,
937 ) -> HookInfo | None:
938 """
939 Import hook and retrieve hook information.
941 Either connection_type (for lazy loading) or hook_class_name must be set - but not both.
942 Only needs package_name if hook_class_name is passed (for lazy loading, package_name
943 is retrieved from _connection_type_class_provider_dict together with hook_class_name).
945 :param connection_type: type of the connection
946 :param hook_class_name: name of the hook class
947 :param package_name: provider package - only needed in case connection_type is missing
948 :return: HookInfo if the hook was imported successfully, None otherwise
949 """
950 from wtforms import BooleanField, IntegerField, PasswordField, StringField
952 if connection_type is None and hook_class_name is None:
953 raise ValueError("Either connection_type or hook_class_name must be set")
954 if connection_type is not None and hook_class_name is not None:
955 raise ValueError(
956 f"Both connection_type ({connection_type} and "
957 f"hook_class_name {hook_class_name} are set. Only one should be set!"
958 )
959 if connection_type is not None:
960 class_provider = self._hook_provider_dict[connection_type]
961 package_name = class_provider.package_name
962 hook_class_name = class_provider.hook_class_name
963 else:
964 if not hook_class_name:
965 raise ValueError("Either connection_type or hook_class_name must be set")
966 if not package_name:
967 raise ValueError(
968 f"Provider package name is not set when hook_class_name ({hook_class_name}) is used"
969 )
970 allowed_field_classes = [IntegerField, PasswordField, StringField, BooleanField]
971 hook_class: type[BaseHook] | None = _correctness_check(package_name, hook_class_name, provider_info)
972 if hook_class is None:
973 return None
974 try:
975 module, class_name = hook_class_name.rsplit(".", maxsplit=1)
976 # Do not use attr here. We want to check only direct class fields not those
977 # inherited from parent hook. This way we add form fields only once for the whole
978 # hierarchy and we add it only from the parent hook that provides those!
979 if "get_connection_form_widgets" in hook_class.__dict__:
980 widgets = hook_class.get_connection_form_widgets()
982 if widgets:
983 for widget in widgets.values():
984 if widget.field_class not in allowed_field_classes:
985 log.warning(
986 "The hook_class '%s' uses field of unsupported class '%s'. "
987 "Only '%s' field classes are supported",
988 hook_class_name,
989 widget.field_class,
990 allowed_field_classes,
991 )
992 return None
993 self._add_widgets(package_name, hook_class, widgets)
994 if "get_ui_field_behaviour" in hook_class.__dict__:
995 field_behaviours = hook_class.get_ui_field_behaviour()
996 if field_behaviours:
997 self._add_customized_fields(package_name, hook_class, field_behaviours)
998 except ImportError as e:
999 if "No module named 'flask_appbuilder'" in e.msg:
1000 log.warning(
1001 "The hook_class '%s' is not fully initialized (UI widgets will be missing), because "
1002 "the 'flask_appbuilder' package is not installed, however it is not required for "
1003 "Airflow components to work",
1004 hook_class_name,
1005 )
1006 except Exception as e:
1007 log.warning(
1008 "Exception when importing '%s' from '%s' package: %s",
1009 hook_class_name,
1010 package_name,
1011 e,
1012 )
1013 return None
1014 hook_connection_type = self._get_attr(hook_class, "conn_type")
1015 if connection_type:
1016 if hook_connection_type != connection_type:
1017 log.warning(
1018 "Inconsistency! The hook class '%s' declares connection type '%s'"
1019 " but it is added by provider '%s' as connection_type '%s' in provider info. "
1020 "This should be fixed!",
1021 hook_class,
1022 hook_connection_type,
1023 package_name,
1024 connection_type,
1025 )
1026 connection_type = hook_connection_type
1027 connection_id_attribute_name: str = self._get_attr(hook_class, "conn_name_attr")
1028 hook_name: str = self._get_attr(hook_class, "hook_name")
1030 if not connection_type or not connection_id_attribute_name or not hook_name:
1031 log.warning(
1032 "The hook misses one of the key attributes: "
1033 "conn_type: %s, conn_id_attribute_name: %s, hook_name: %s",
1034 connection_type,
1035 connection_id_attribute_name,
1036 hook_name,
1037 )
1038 return None
1040 return HookInfo(
1041 hook_class_name=hook_class_name,
1042 connection_id_attribute_name=connection_id_attribute_name,
1043 package_name=package_name,
1044 hook_name=hook_name,
1045 connection_type=connection_type,
1046 connection_testable=hasattr(hook_class, "test_connection"),
1047 )
1049 def _add_widgets(self, package_name: str, hook_class: type, widgets: dict[str, Any]):
1050 conn_type = hook_class.conn_type # type: ignore
1051 for field_identifier, field in widgets.items():
1052 if field_identifier.startswith("extra__"):
1053 prefixed_field_name = field_identifier
1054 else:
1055 prefixed_field_name = f"extra__{conn_type}__{field_identifier}"
1056 if prefixed_field_name in self._connection_form_widgets:
1057 log.warning(
1058 "The field %s from class %s has already been added by another provider. Ignoring it.",
1059 field_identifier,
1060 hook_class.__name__,
1061 )
1062 # In case of inherited hooks this might be happening several times
1063 else:
1064 self._connection_form_widgets[prefixed_field_name] = ConnectionFormWidgetInfo(
1065 hook_class.__name__,
1066 package_name,
1067 field,
1068 field_identifier,
1069 hasattr(field.field_class.widget, "input_type")
1070 and field.field_class.widget.input_type == "password",
1071 )
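# Example of the prefixing performed above (a sketch with a hypothetical field name):
# a hook with conn_type "mysql" returning a widget keyed "charset" is registered under
# "extra__mysql__charset", while keys already starting with "extra__" are kept as-is.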
1073 def _add_customized_fields(self, package_name: str, hook_class: type, customized_fields: dict):
1074 try:
1075 connection_type = getattr(hook_class, "conn_type")
1077 self._customized_form_fields_schema_validator.validate(customized_fields)
1079 if connection_type:
1080 customized_fields = _ensure_prefix_for_placeholders(customized_fields, connection_type)
1082 if connection_type in self._field_behaviours:
1083 log.warning(
1084 "The connection_type %s from package %s and class %s has already been added "
1085 "by another provider. Ignoring it.",
1086 connection_type,
1087 package_name,
1088 hook_class.__name__,
1089 )
1090 return
1091 self._field_behaviours[connection_type] = customized_fields
1092 except Exception as e:
1093 log.warning(
1094 "Error when loading customized fields from package '%s' hook class '%s': %s",
1095 package_name,
1096 hook_class.__name__,
1097 e,
1098 )
1100 def _discover_auth_managers(self) -> None:
1101 """Retrieve all auth managers defined in the providers."""
1102 for provider_package, provider in self._provider_dict.items():
1103 if provider.data.get("auth-managers"):
1104 for auth_manager_class_name in provider.data["auth-managers"]:
1105 if _correctness_check(provider_package, auth_manager_class_name, provider):
1106 self._auth_manager_class_name_set.add(auth_manager_class_name)
1108 def _discover_notifications(self) -> None:
1109 """Retrieve all notifications defined in the providers."""
1110 for provider_package, provider in self._provider_dict.items():
1111 if provider.data.get("notifications"):
1112 for notification_class_name in provider.data["notifications"]:
1113 if _correctness_check(provider_package, notification_class_name, provider):
1114 self._notification_info_set.add(notification_class_name)
1116 def _discover_extra_links(self) -> None:
1117 """Retrieve all extra links defined in the providers."""
1118 for provider_package, provider in self._provider_dict.items():
1119 if provider.data.get("extra-links"):
1120 for extra_link_class_name in provider.data["extra-links"]:
1121 if _correctness_check(provider_package, extra_link_class_name, provider):
1122 self._extra_link_class_name_set.add(extra_link_class_name)
1124 def _discover_logging(self) -> None:
1125 """Retrieve all logging defined in the providers."""
1126 for provider_package, provider in self._provider_dict.items():
1127 if provider.data.get("logging"):
1128 for logging_class_name in provider.data["logging"]:
1129 if _correctness_check(provider_package, logging_class_name, provider):
1130 self._logging_class_name_set.add(logging_class_name)
1132 def _discover_secrets_backends(self) -> None:
1133 """Retrieve all secrets backends defined in the providers."""
1134 for provider_package, provider in self._provider_dict.items():
1135 if provider.data.get("secrets-backends"):
1136 for secrets_backends_class_name in provider.data["secrets-backends"]:
1137 if _correctness_check(provider_package, secrets_backends_class_name, provider):
1138 self._secrets_backend_class_name_set.add(secrets_backends_class_name)
1140 def _discover_auth_backends(self) -> None:
1141 """Retrieve all API auth backends defined in the providers."""
1142 for provider_package, provider in self._provider_dict.items():
1143 if provider.data.get("auth-backends"):
1144 for auth_backend_module_name in provider.data["auth-backends"]:
1145 if _correctness_check(provider_package, auth_backend_module_name + ".init_app", provider):
1146 self._api_auth_backend_module_names.add(auth_backend_module_name)
1148 def _discover_executors(self) -> None:
1149 """Retrieve all executors defined in the providers."""
1150 for provider_package, provider in self._provider_dict.items():
1151 if provider.data.get("executors"):
1152 for executors_class_name in provider.data["executors"]:
1153 if _correctness_check(provider_package, executors_class_name, provider):
1154 self._executor_class_name_set.add(executors_class_name)
1156 def _discover_config(self) -> None:
1157 """Retrieve all configs defined in the providers."""
1158 for provider_package, provider in self._provider_dict.items():
1159 if provider.data.get("config"):
1160 self._provider_configs[provider_package] = provider.data.get("config") # type: ignore[assignment]
1162 def _discover_plugins(self) -> None:
1163 """Retrieve all plugins defined in the providers."""
1164 for provider_package, provider in self._provider_dict.items():
1165 for plugin_dict in provider.data.get("plugins", ()):
1166 if not _correctness_check(provider_package, plugin_dict["plugin-class"], provider):
1167 log.warning("Plugin not loaded due to above correctness check problem.")
1168 continue
1169 self._plugins_set.add(
1170 PluginInfo(
1171 name=plugin_dict["name"],
1172 plugin_class=plugin_dict["plugin-class"],
1173 provider_name=provider_package,
1174 )
1175 )
1177 @provider_info_cache("triggers")
1178 def initialize_providers_triggers(self):
1179 """Initialize providers triggers."""
1180 self.initialize_providers_list()
1181 for provider_package, provider in self._provider_dict.items():
1182 for trigger in provider.data.get("triggers", []):
1183 for trigger_class_name in trigger.get("python-modules"):
1184 self._trigger_info_set.add(
1185 TriggerInfo(
1186 package_name=provider_package,
1187 trigger_class_name=trigger_class_name,
1188 integration_name=trigger.get("integration-name", ""),
1189 )
1190 )
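# A hedged sketch of the "triggers" provider metadata consumed above (values illustrative):
#
#   triggers_example = [
#       {
#           "integration-name": "SFTP",
#           "python-modules": ["airflow.providers.sftp.triggers.sftp"],
#       },
#   ]
#
# Each listed python module is recorded as a TriggerInfo together with its package and
# integration name; the modules themselves are not imported at this point.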
1192 @property
1193 def auth_managers(self) -> list[str]:
1194 """Returns information about available providers notifications class."""
1195 self.initialize_providers_auth_managers()
1196 return sorted(self._auth_manager_class_name_set)
1198 @property
1199 def notification(self) -> list[NotificationInfo]:
1200 """Returns information about available providers notifications class."""
1201 self.initialize_providers_notifications()
1202 return sorted(self._notification_info_set)
1204 @property
1205 def trigger(self) -> list[TriggerInfo]:
1206 """Returns information about available providers trigger class."""
1207 self.initialize_providers_triggers()
1208 return sorted(self._trigger_info_set, key=lambda x: x.package_name)
1210 @property
1211 def providers(self) -> dict[str, ProviderInfo]:
1212 """Returns information about available providers."""
1213 self.initialize_providers_list()
1214 return self._provider_dict
1216 @property
1217 def hooks(self) -> MutableMapping[str, HookInfo | None]:
1218 """
1219 Return dictionary of connection_type-to-hook mapping.
1221 Note that the dict can contain None values if a discovered hook cannot be imported!
1222 """
1223 self.initialize_providers_hooks()
1224 # When we return hooks here it will only be used to retrieve hook information
1225 return self._hooks_lazy_dict
1227 @property
1228 def plugins(self) -> list[PluginInfo]:
1229 """Returns information about plugins available in providers."""
1230 self.initialize_providers_plugins()
1231 return sorted(self._plugins_set, key=lambda x: x.plugin_class)
1233 @property
1234 def taskflow_decorators(self) -> dict[str, TaskDecorator]:
1235 self.initialize_providers_taskflow_decorator()
1236 return self._taskflow_decorators # type: ignore[return-value]
1238 @property
1239 def extra_links_class_names(self) -> list[str]:
1240 """Returns set of extra link class names."""
1241 self.initialize_providers_extra_links()
1242 return sorted(self._extra_link_class_name_set)
1244 @property
1245 def connection_form_widgets(self) -> dict[str, ConnectionFormWidgetInfo]:
1246 """
1247 Returns widgets for connection forms.
1249 Dictionary keys are in the same order as defined in the Hook.
1250 """
1251 self.initialize_providers_hooks()
1252 self._import_info_from_all_hooks()
1253 return self._connection_form_widgets
1255 @property
1256 def field_behaviours(self) -> dict[str, dict]:
1257 """Returns dictionary with field behaviours for connection types."""
1258 self.initialize_providers_hooks()
1259 self._import_info_from_all_hooks()
1260 return self._field_behaviours
1262 @property
1263 def logging_class_names(self) -> list[str]:
1264 """Returns set of log task handlers class names."""
1265 self.initialize_providers_logging()
1266 return sorted(self._logging_class_name_set)
1268 @property
1269 def secrets_backend_class_names(self) -> list[str]:
1270 """Returns set of secret backend class names."""
1271 self.initialize_providers_secrets_backends()
1272 return sorted(self._secrets_backend_class_name_set)
1274 @property
1275 def auth_backend_module_names(self) -> list[str]:
1276 """Returns set of API auth backend class names."""
1277 self.initialize_providers_auth_backends()
1278 return sorted(self._api_auth_backend_module_names)
1280 @property
1281 def executor_class_names(self) -> list[str]:
1282 self.initialize_providers_executors()
1283 return sorted(self._executor_class_name_set)
1285 @property
1286 def filesystem_module_names(self) -> list[str]:
1287 self.initialize_providers_filesystems()
1288 return sorted(self._fs_set)
1290 @property
1291 def dataset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]:
1292 self.initialize_providers_dataset_uri_handlers()
1293 return self._dataset_uri_handlers
1295 @property
1296 def provider_configs(self) -> list[tuple[str, dict[str, Any]]]:
1297 self.initialize_providers_configuration()
1298 return sorted(self._provider_configs.items(), key=lambda x: x[0])
1300 @property
1301 def already_initialized_provider_configs(self) -> list[tuple[str, dict[str, Any]]]:
1302 return sorted(self._provider_configs.items(), key=lambda x: x[0])
1304 def _cleanup(self):
1305 self._initialized_cache.clear()
1306 self._provider_dict.clear()
1307 self._hooks_dict.clear()
1308 self._fs_set.clear()
1309 self._taskflow_decorators.clear()
1310 self._hook_provider_dict.clear()
1311 self._hooks_lazy_dict.clear()
1312 self._connection_form_widgets.clear()
1313 self._field_behaviours.clear()
1314 self._extra_link_class_name_set.clear()
1315 self._logging_class_name_set.clear()
1316 self._auth_manager_class_name_set.clear()
1317 self._secrets_backend_class_name_set.clear()
1318 self._executor_class_name_set.clear()
1319 self._provider_configs.clear()
1320 self._api_auth_backend_module_names.clear()
1321 self._trigger_info_set.clear()
1322 self._notification_info_set.clear()
1323 self._plugins_set.clear()
1324 self._initialized = False
1325 self._initialization_stack_trace = None