1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Shared provider discovery utilities."""
19
20from __future__ import annotations
21
22import contextlib
23import json
24import pathlib
25from collections.abc import Callable, MutableMapping
26from dataclasses import dataclass
27from functools import wraps
28from importlib.resources import files as resource_files
29from time import perf_counter
30from typing import Any, NamedTuple, ParamSpec, Protocol, cast
31
32import structlog
33from packaging.utils import canonicalize_name
34
35from ..module_loading import entry_points_with_dist
36
# Module-level structlog logger used by the helpers in this file.
log = structlog.getLogger(__name__)


# Parameter specification used to type the decorator built by ``provider_info_cache``, so the
# wrapped function keeps the signature of the function it decorates.
PS = ParamSpec("PS")


# Pairs of (provider package name, error message text).
# NOTE(review): not referenced in the visible part of this file; presumably consulted by callers
# to recognise import errors of optional features that are known and should not be reported - confirm.
KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS = [("apache-airflow-providers-google", "No module named 'paramiko'")]
44
45
class ProvidersManagerProtocol(Protocol):
    """
    Structural type of the manager object handled by ``provider_info_cache``.

    It exists for type checking only and declares the single attribute that the
    caching decorator reads and writes on the instance it receives.
    """

    # Maps a cache name to ``True`` once the initializer registered under that name has completed.
    _initialized_cache: dict[str, bool]
50
51
@dataclass
class ProviderInfo:
    """
    Version and metadata of a single provider.

    :param version: version string
    :param data: dictionary with information about the provider
    """

    version: str
    data: dict
63
64
class HookClassProvider(NamedTuple):
    """Pairs the name of a hook class with the name of the package it comes from."""

    hook_class_name: str
    package_name: str
70
71
class HookInfo(NamedTuple):
    """
    Descriptive record about a hook.

    All fields except ``dialects`` are required and are positional in the order declared below.
    """

    hook_class_name: str
    connection_id_attribute_name: str
    package_name: str
    hook_name: str
    connection_type: str
    connection_testable: bool
    # NOTE: this default list object is shared by every instance created without an explicit
    # ``dialects`` value, so it must be treated as read-only.
    dialects: list[str] = []
82
83
class ConnectionFormWidgetInfo(NamedTuple):
    """Descriptive record about a connection form widget."""

    hook_class_name: str
    package_name: str
    # The widget/field object itself; its concrete type is not constrained here.
    field: Any
    field_name: str
    is_sensitive: bool
92
93
class PluginInfo(NamedTuple):
    """Groups a plugin's name, its plugin class and the name of the provider it comes from."""

    name: str
    plugin_class: str
    provider_name: str
100
101
class NotificationInfo(NamedTuple):
    """Pairs the name of a notification class with the name of the package it comes from."""

    notification_class_name: str
    package_name: str
107
108
class TriggerInfo(NamedTuple):
    """Groups a trigger class name, the package it comes from and an integration name."""

    trigger_class_name: str
    package_name: str
    integration_name: str
115
116
class DialectInfo(NamedTuple):
    """Groups a dialect's name, its dialect class name and the name of the provider it comes from."""

    name: str
    dialect_class_name: str
    provider_name: str
123
124
class LazyDictWithCache(MutableMapping):
    """
    Lazy-loaded cached dictionary.

    A mapping whose values may be callables. The first time a key holding a callable is read
    through ``__getitem__``, the callable is invoked without arguments and its result replaces
    the callable in the underlying dict, so later reads return the cached result. The result may
    itself be a callable; it is then returned as-is and never invoked.

    Assigning a new value to a key resets that bookkeeping, so a newly assigned callable is
    again resolved on its first read. Membership tests, iteration over keys and ``len`` never
    trigger resolution.
    """

    __slots__ = ["_resolved", "_raw_dict"]

    def __init__(self, *args, **kw):
        # Keys whose stored callable has already been swapped for its result.
        self._resolved = set()
        self._raw_dict = dict(*args, **kw)

    def __setitem__(self, key, value):
        # A freshly assigned value has not been resolved yet, even if an earlier value stored
        # under the same key was. Without this reset a new callable would be returned raw.
        self._resolved.discard(key)
        self._raw_dict[key] = value

    def __getitem__(self, key):
        value = self._raw_dict[key]
        if key not in self._resolved and callable(value):
            # exchange callable with result of calling it -- but only once! allow resolver to return a
            # callable itself
            value = value()
            self._resolved.add(key)
            self._raw_dict[key] = value
        return value

    def __delitem__(self, key):
        # ``discard`` tolerates keys that were never resolved; the dict deletion below still
        # raises ``KeyError`` for keys that are not present at all.
        self._resolved.discard(key)
        del self._raw_dict[key]

    def __iter__(self):
        return iter(self._raw_dict)

    def __len__(self):
        return len(self._raw_dict)

    def __contains__(self, key):
        # Checked against the raw dict so that a membership test does not resolve the value.
        return key in self._raw_dict

    def clear(self):
        self._resolved.clear()
        self._raw_dict.clear()
169
170
def _read_schema_from_resources_or_local_file(filename: str) -> dict:
    """
    Load a JSON schema by file name.

    The file is first looked up among the resources of the ``airflow`` package. When that
    attempt raises ``TypeError`` or ``FileNotFoundError``, the file is read from the directory
    that contains this module instead.

    :param filename: name of the schema file to load
    :return: the parsed JSON document
    """
    try:
        packaged_schema = resource_files("airflow").joinpath(filename)
        with packaged_schema.open("rb") as stream:
            return json.load(stream)
    except (TypeError, FileNotFoundError):
        local_schema = pathlib.Path(__file__).parent / filename
        with local_schema.open("rb") as stream:
            return json.load(stream)
180
181
def _create_provider_info_schema_validator():
    """
    Build a JSON schema validator for ``provider_info.schema.json``.

    The validator class is the one ``jsonschema`` selects for the loaded schema.

    :return: a validator instance bound to the provider info schema
    """
    # Imported here rather than at module level, so jsonschema is only loaded when a validator is built.
    import jsonschema

    provider_info_schema = _read_schema_from_resources_or_local_file("provider_info.schema.json")
    validator_class = jsonschema.validators.validator_for(provider_info_schema)
    return validator_class(provider_info_schema)
190
191
def _create_customized_form_field_behaviours_schema_validator():
    """
    Build a JSON schema validator for ``customized_form_field_behaviours.schema.json``.

    The validator class is the one ``jsonschema`` selects for the loaded schema.

    :return: a validator instance bound to the customized form field behaviours schema
    """
    # Imported here rather than at module level, so jsonschema is only loaded when a validator is built.
    import jsonschema

    behaviours_schema = _read_schema_from_resources_or_local_file(
        "customized_form_field_behaviours.schema.json"
    )
    validator_class = jsonschema.validators.validator_for(behaviours_schema)
    return validator_class(behaviours_schema)
200
201
def _check_builtin_provider_prefix(provider_package: str, class_name: str) -> bool:
    """
    Check that a class of a builtin provider lives under the expected module prefix.

    Only packages whose name starts with ``apache-airflow`` are checked. For those, the expected
    prefix is the package name without the leading ``apache-`` and with dashes turned into dots.
    A warning is logged when the class name does not start with that prefix.

    :param provider_package: name of the provider package
    :param class_name: class name to verify
    :return: ``False`` when the package is checked and the prefix does not match, ``True`` otherwise
    """
    if not provider_package.startswith("apache-airflow"):
        return True

    expected_prefix = provider_package.removeprefix("apache-").replace("-", ".")
    if class_name.startswith(expected_prefix):
        return True

    log.warning(
        "Coherence check failed when importing '%s' from '%s' package. It should start with '%s'",
        class_name,
        provider_package,
        expected_prefix,
    )
    return False
215
216
def _ensure_prefix_for_placeholders(field_behaviors: dict[str, Any], conn_type: str):
    """
    Namespace the placeholder keys of custom connection fields.

    When ``field_behaviors`` has a ``"placeholders"`` entry, every placeholder key that is
    neither a built-in connection attribute nor already starting with ``extra__`` is renamed to
    ``extra__<conn_type>__<key>``. All custom connection fields live in the same dictionary, so
    they are namespaced internally, while users may still supply the unprefixed name.

    The ``"placeholders"`` entry is replaced on the dict that was passed in, which is also the
    object that is returned.

    :param field_behaviors: field behaviour description, possibly containing ``"placeholders"``
    :param conn_type: connection type used to build the prefix
    :return: the same ``field_behaviors`` object
    """
    if "placeholders" not in field_behaviors:
        return field_behaviors

    builtin_attrs = {"host", "schema", "login", "password", "port", "extra"}
    namespace = f"extra__{conn_type}__"

    prefixed = {}
    for name, hint in field_behaviors["placeholders"].items():
        if name in builtin_attrs or name.startswith("extra__"):
            prefixed[name] = hint
        else:
            prefixed[f"{namespace}{name}"] = hint

    field_behaviors["placeholders"] = prefixed
    return field_behaviors
242
243
def log_optional_feature_disabled(class_name, e, provider_package):
    """
    Report that an optional provider feature was disabled by a failed import.

    The exception details are attached to a debug-level record only; the info-level record
    carries just the class and package names.

    :param class_name: name of the class that was being imported
    :param e: the exception, passed to the logger as ``exc_info``
    :param provider_package: name of the provider package the class belongs to
    """
    origin = (class_name, provider_package)
    log.debug(
        "Optional feature disabled on exception when importing '%s' from '%s' package",
        *origin,
        exc_info=e,
    )
    log.info(
        "Optional provider feature disabled when importing '%s' from '%s' package",
        *origin,
    )
257
258
def log_import_warning(class_name, e, provider_package):
    """
    Emit a warning-level record about an exception raised while importing a class.

    :param class_name: name of the class that was being imported
    :param e: the exception, passed to the logger as ``exc_info``
    :param provider_package: name of the provider package the class belongs to
    """
    origin = (class_name, provider_package)
    log.warning("Exception when importing '%s' from '%s' package", *origin, exc_info=e)
267
268
def provider_info_cache(cache_name: str) -> Callable[[Callable[PS, None]], Callable[PS, None]]:
    """
    Build a decorator that runs an initializer only once per instance.

    The decorated function is expected to receive, as its first positional argument, an object
    with an ``_initialized_cache`` dict. When ``cache_name`` is already a key of that dict, the
    call returns immediately. Otherwise the function runs, ``cache_name`` is recorded in the
    dict and the elapsed time is logged at debug level. If the function raises, nothing is
    recorded, so a later call runs it again.

    :param cache_name: Name of the cache
    :return: decorator to apply to the initializer
    """

    def decorate(initializer: Callable[PS, None]) -> Callable[PS, None]:
        @wraps(initializer)
        def run_once(*args: PS.args, **kwargs: PS.kwargs) -> None:
            manager = cast("ProvidersManagerProtocol", args[0])

            if cache_name not in manager._initialized_cache:
                started_at = perf_counter()
                log.debug("Initializing Provider Manager[%s]", cache_name)
                initializer(*args, **kwargs)
                # Recorded only after the initializer returned without raising.
                manager._initialized_cache[cache_name] = True
                elapsed = perf_counter() - started_at
                log.debug(
                    "Initialization of Provider Manager[%s] took %.2f seconds",
                    cache_name,
                    elapsed,
                )

        return run_once

    return decorate
297
298
def _documentation_url_from_project_urls(project_urls: list[str] | None) -> str | None:
    """Return the URL of the first ``Project-URL`` entry labelled "documentation", or ``None``."""
    for entry in project_urls or ():
        # Each entry has the form "<label>, <url>". Split on the first comma only, because the
        # URL part may itself contain commas; entries without any comma are ignored.
        label, separator, url = entry.partition(",")
        if separator and label.strip().lower() == "documentation":
            # Drop the whitespace that follows the separating comma.
            return url.strip()
    return None


def discover_all_providers_from_packages(
    provider_dict: dict[str, ProviderInfo],
    provider_schema_validator,
) -> None:
    """
    Discover all providers by scanning packages installed.

    The list of providers should be returned via the 'apache_airflow_provider'
    entrypoint as a dictionary conforming to the 'airflow/provider_info.schema.json'
    schema. Note that the schema is different at runtime than provider.yaml.schema.json.
    The development version of provider schema is more strict and changes together with
    the code. The runtime version is more relaxed (allows for additional properties)
    and verifies only the subset of fields that are needed at runtime.

    Distributions without metadata and packages already present in ``provider_dict`` are
    skipped. Each registered provider info additionally gets a ``"documentation-url"`` key
    holding the documentation URL found in the distribution's ``Project-URL`` metadata, or
    ``None`` when there is none.

    :param provider_dict: Dictionary to populate with discovered providers
    :param provider_schema_validator: JSON schema validator for provider info
    :raises ValueError: when the distribution name differs from the ``package-name`` declared
        in the provider info
    """
    for entry_point, dist in entry_points_with_dist("apache_airflow_provider"):
        if not dist.metadata:
            continue
        package_name = canonicalize_name(dist.metadata["name"])
        if package_name in provider_dict:
            continue
        log.debug("Loading %s from package %s", entry_point, package_name)
        version = dist.version
        # The entry point resolves to a callable that returns the provider info dictionary.
        provider_info = entry_point.load()()
        provider_schema_validator.validate(provider_info)
        provider_info_package_name = provider_info["package-name"]
        if package_name != provider_info_package_name:
            raise ValueError(
                f"The package '{package_name}' from packaging information "
                f"{provider_info_package_name} do not match. Please make sure they are aligned"
            )

        # issue-59576: Retrieve the project.urls.documentation from dist.metadata
        provider_info["documentation-url"] = _documentation_url_from_project_urls(
            dist.metadata.get_all("Project-URL")
        )

        if package_name not in provider_dict:
            provider_dict[package_name] = ProviderInfo(version, provider_info)
        else:
            # Defensive: the early ``continue`` above already skips registered packages, so this
            # can only trigger if the package got registered while its entry point was loading.
            log.warning(
                "The provider for package '%s' could not be registered from because providers for that "
                "package name have already been registered",
                package_name,
            )