Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/connection.py: 25%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
18from __future__ import annotations
20import json
21import logging
22import warnings
23from json import JSONDecodeError
24from typing import Any
25from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit
27import re2
28from sqlalchemy import Boolean, Column, Integer, String, Text
29from sqlalchemy.orm import declared_attr, reconstructor, synonym
31from airflow.configuration import ensure_secrets_loaded
32from airflow.exceptions import AirflowException, AirflowNotFoundException, RemovedInAirflow3Warning
33from airflow.models.base import ID_LEN, Base
34from airflow.models.crypto import get_fernet
35from airflow.secrets.cache import SecretCache
36from airflow.utils.helpers import prune_dict
37from airflow.utils.log.logging_mixin import LoggingMixin
38from airflow.utils.log.secrets_masker import mask_secret
39from airflow.utils.module_loading import import_string
log = logging.getLogger(__name__)

# Character whitelist for a sanitized `conn_id`: word characters (`\w`) plus the
# symbols #, !, -, ., :, \, / and the parentheses ( and ). The pattern is anchored
# at both ends and needs at least one character.
#
# Interactive playground for the pattern: https://regex101.com/r/69033B/1
RE_SANITIZE_CONN_ID = re2.compile(r"^[\w\#\!\(\)\-\.\:\/\\]{1,}$")

# Default upper bound on the length of a connection ID (used by `sanitize_conn_id`).
CONN_ID_MAX_LEN: int = 250
def parse_netloc_to_hostname(*args, **kwargs):
    """Do not use, this method is deprecated.

    Emits a ``RemovedInAirflow3Warning`` attributed to the caller, then forwards every
    positional and keyword argument unchanged to ``_parse_netloc_to_hostname``.
    """
    warnings.warn(
        "This method is deprecated.",
        category=RemovedInAirflow3Warning,
        stacklevel=2,
    )
    hostname = _parse_netloc_to_hostname(*args, **kwargs)
    return hostname
def sanitize_conn_id(conn_id: str | None, max_length=CONN_ID_MAX_LEN) -> str | None:
    r"""Sanitize a connection id so that only whitelisted characters are accepted.

    Accepted characters are word characters plus the symbols #,!,-,_,.,:,\,/ and (),
    between 1 and ``max_length`` (250 by default) of them. The pattern lives in
    ``RE_SANITIZE_CONN_ID``; see https://regex101.com/r/69033B/1 to experiment with it.

    The restricted character set is meant to keep javascript or other executable
    fragments out of connection ids, avoiding awkward behaviour in the front-end.

    :param conn_id: The connection id to sanitize.
    :param max_length: The max length of the connection ID, by default it is 250.
    :return: the sanitized string, or ``None`` if ``conn_id`` is not a string, is too
        long, or contains characters outside the whitelist.
    """
    # Guard clauses: reject non-strings and over-long ids before touching the regex.
    if not isinstance(conn_id, str):
        return None
    if len(conn_id) > max_length:
        return None

    matched = re2.match(RE_SANITIZE_CONN_ID, conn_id)
    if matched is None:
        return None

    # A match was found; hand back the matched text (group 0).
    return matched.group(0)
def _parse_netloc_to_hostname(uri_parts):
    """
    Parse a URI string to get the correct Hostname.

    ``urlparse(...).hostname`` / ``urlsplit(...).hostname`` lowercase the host in most
    cases (with some exceptions, e.g. https://bugs.python.org/issue32323). When the
    percent-decoded hostname actually contains a path (it includes ``/``), that default
    behaviour is unwanted, so the host is re-extracted from the raw ``netloc`` with its
    original case preserved.

    :param uri_parts: result of ``urlsplit``/``urlparse`` (anything exposing
        ``hostname`` and ``netloc``).
    :return: the percent-decoded hostname, or ``""`` when there is none.
    """
    decoded = unquote(uri_parts.hostname or "")
    if "/" not in decoded:
        return decoded

    # Path-like host: rebuild it from the raw netloc, dropping any "user:pass@"
    # prefix (split on the last "@") and any ":port" suffix (split on the first ":").
    raw_host = uri_parts.netloc.rpartition("@")[2]
    raw_host = raw_host.partition(":")[0]
    return unquote(raw_host)
class Connection(Base, LoggingMixin):
    """
    Placeholder to store information about different database instances connection information.

    The idea here is that scripts use references to database instances (conn_id)
    instead of hard coding hostname, logins and passwords when using operators or hooks.

    .. seealso::
        For more information on how to use this class, see: :doc:`/howto/connection`

    :param conn_id: The connection ID.
    :param conn_type: The connection type.
    :param description: The connection description.
    :param host: The host.
    :param login: The login.
    :param password: The password.
    :param schema: The schema.
    :param port: The port number.
    :param extra: Extra metadata. Non-standard data such as private/SSH keys can be saved here. JSON
        encoded object.
    :param uri: URI address describing connection parameters.
    """

    # Query-string key used by ``get_uri`` / ``_parse_from_uri`` to carry an ``extra`` value
    # that cannot round-trip as plain ``key=value`` query parameters.
    EXTRA_KEY = "__extra__"

    __tablename__ = "connection"

    id = Column(Integer(), primary_key=True)
    conn_id = Column(String(ID_LEN), unique=True, nullable=False)
    conn_type = Column(String(500), nullable=False)
    description = Column(Text().with_variant(Text(5000), "mysql").with_variant(String(5000), "sqlite"))
    host = Column(String(500))
    schema = Column(String(500))
    login = Column(Text())
    # Stored (possibly Fernet-encrypted) password; read/write it through the ``password`` synonym.
    _password = Column("password", Text())
    port = Column(Integer())
    is_encrypted = Column(Boolean, unique=False, default=False)
    is_extra_encrypted = Column(Boolean, unique=False, default=False)
    # Stored (possibly Fernet-encrypted) extra; read/write it through the ``extra`` synonym.
    _extra = Column("extra", Text())

    def __init__(
        self,
        conn_id: str | None = None,
        conn_type: str | None = None,
        description: str | None = None,
        host: str | None = None,
        login: str | None = None,
        password: str | None = None,
        schema: str | None = None,
        port: int | None = None,
        extra: str | dict | None = None,
        uri: str | None = None,
    ):
        super().__init__()
        self.conn_id = sanitize_conn_id(conn_id)
        self.description = description
        # Non-string extras (e.g. dicts) are stored as their JSON serialization.
        if extra and not isinstance(extra, str):
            extra = json.dumps(extra)
        if uri and (conn_type or host or login or password or schema or port or extra):
            raise AirflowException(
                "You must create an object using the URI or individual values "
                "(conn_type, host, login, password, schema, port or extra). "
                "You can't mix these two ways to create this object."
            )
        if uri:
            self._parse_from_uri(uri)
        else:
            self.conn_type = conn_type
            self.host = host
            self.login = login
            self.password = password
            self.schema = schema
            self.port = port
            self.extra = extra
        if self.extra:
            self._validate_extra(self.extra, self.conn_id)

        # Register the password, both raw and URL-quoted, with the secrets masker.
        if self.password:
            mask_secret(self.password)
            mask_secret(quote(self.password))

    @staticmethod
    def _validate_extra(extra, conn_id) -> None:
        """
        Verify that ``extra`` is a JSON-encoded Python dict.

        Only warnings (``RemovedInAirflow3Warning``) are emitted; nothing is raised.
        From Airflow 3.0, we should no longer suppress these errors but raise instead.

        :param extra: the JSON string to check; ``None`` is accepted silently.
        :param conn_id: connection id, used only in the warning text.
        """
        if extra is None:
            return None
        try:
            extra_parsed = json.loads(extra)
            if not isinstance(extra_parsed, dict):
                warnings.warn(
                    "Encountered JSON value in `extra` which does not parse as a dictionary in "
                    f"connection {conn_id!r}. From Airflow 3.0, the `extra` field must contain a JSON "
                    "representation of a Python dict.",
                    RemovedInAirflow3Warning,
                    stacklevel=3,
                )
        except json.JSONDecodeError:
            warnings.warn(
                f"Encountered non-JSON in `extra` field for connection {conn_id!r}. Support for "
                "non-JSON `extra` will be removed in Airflow 3.0",
                RemovedInAirflow3Warning,
                stacklevel=2,
            )
        return None

    @reconstructor
    def on_db_load(self):
        """Register the password with the secrets masker when the object is reconstructed from the DB."""
        if self.password:
            mask_secret(self.password)
            mask_secret(quote(self.password))

    def parse_from_uri(self, **uri):
        """Use uri parameter in constructor, this method is deprecated."""
        warnings.warn(
            "This method is deprecated. Please use uri parameter in constructor.",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        self._parse_from_uri(**uri)

    @staticmethod
    def _normalize_conn_type(conn_type):
        """Map the ``postgresql`` scheme to ``postgres`` and replace ``-`` with ``_`` otherwise."""
        if conn_type == "postgresql":
            conn_type = "postgres"
        elif "-" in conn_type:
            conn_type = conn_type.replace("-", "_")
        return conn_type

    def _parse_from_uri(self, uri: str):
        """
        Populate this connection's fields from a connection URI.

        A URI may contain at most two ``://`` separators. The second one means the host
        itself carries a protocol (e.g. ``http://https://example.com``).

        :raises AirflowException: if the URI has more than two ``://`` separators, or if the
            part between them contains ``@`` or ``:``.
        """
        schemes_count_in_uri = uri.count("://")
        if schemes_count_in_uri > 2:
            raise AirflowException(f"Invalid connection string: {uri}.")
        host_with_protocol = schemes_count_in_uri == 2
        uri_parts = urlsplit(uri)
        conn_type = uri_parts.scheme
        self.conn_type = self._normalize_conn_type(conn_type)
        # Strip only the leading "<conn_type>://" prefix. Replacing every occurrence would
        # also remove a host protocol equal to the conn type (e.g. "http://http://host",
        # which ``get_uri`` produces), losing the host and corrupting the schema.
        rest_of_the_url = uri.replace(f"{conn_type}://", ("" if host_with_protocol else "//"), 1)
        if host_with_protocol:
            uri_splits = rest_of_the_url.split("://", 1)
            if "@" in uri_splits[0] or ":" in uri_splits[0]:
                raise AirflowException(f"Invalid connection string: {uri}.")
        uri_parts = urlsplit(rest_of_the_url)
        protocol = uri_parts.scheme if host_with_protocol else None
        host = _parse_netloc_to_hostname(uri_parts)
        self.host = self._create_host(protocol, host)
        quoted_schema = uri_parts.path[1:]
        self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema
        self.login = unquote(uri_parts.username) if uri_parts.username else uri_parts.username
        self.password = unquote(uri_parts.password) if uri_parts.password else uri_parts.password
        self.port = uri_parts.port
        if uri_parts.query:
            query = dict(parse_qsl(uri_parts.query, keep_blank_values=True))
            # An explicit __extra__ parameter carries the raw extra; otherwise all query
            # parameters become a JSON object.
            if self.EXTRA_KEY in query:
                self.extra = query[self.EXTRA_KEY]
            else:
                self.extra = json.dumps(query)

    @staticmethod
    def _create_host(protocol, host) -> str | None:
        """Return the connection host with the protocol."""
        if not host:
            return host
        if protocol:
            return f"{protocol}://{host}"
        return host

    def get_uri(self) -> str:
        """
        Return connection in URI format.

        The conn type is lowercased and ``_`` is replaced with ``-``. Login, password, host
        and schema are percent-quoted. ``extra`` is emitted as query parameters when it
        round-trips through ``urlencode``/``parse_qsl``; otherwise it is emitted verbatim
        under the ``__extra__`` key.
        """
        if self.conn_type and "_" in self.conn_type:
            self.log.warning(
                "Connection schemes (type: %s) shall not contain '_' according to RFC3986.",
                self.conn_type,
            )

        if self.conn_type:
            uri = f"{self.conn_type.lower().replace('_', '-')}://"
        else:
            uri = "//"

        if self.host and "://" in self.host:
            protocol, host = self.host.split("://", 1)
        else:
            protocol, host = None, self.host

        if protocol:
            uri += f"{protocol}://"

        authority_block = ""
        if self.login is not None:
            authority_block += quote(self.login, safe="")

        if self.password is not None:
            authority_block += ":" + quote(self.password, safe="")

        if authority_block:
            authority_block += "@"

        uri += authority_block

        host_block = ""
        if host:
            host_block += quote(host, safe="")

        if self.port:
            # With neither host nor credentials, prefix "@" so the port is not parsed
            # back as part of an authority without a host.
            if host_block == "" and authority_block == "":
                host_block += f"@:{self.port}"
            else:
                host_block += f":{self.port}"

        if self.schema:
            host_block += f"/{quote(self.schema, safe='')}"

        uri += host_block

        if self.extra:
            # Deserialize once; extra_dejson re-parses (and re-masks) on every access.
            extra_dict = self.extra_dejson
            try:
                query: str | None = urlencode(extra_dict)
            except TypeError:
                query = None
            # Only use plain query params if decoding them gives back exactly the same dict.
            if query and extra_dict == dict(parse_qsl(query, keep_blank_values=True)):
                uri += ("?" if self.schema else "/?") + query
            else:
                uri += ("?" if self.schema else "/?") + urlencode({self.EXTRA_KEY: self.extra})

        return uri

    def get_password(self) -> str | None:
        """
        Return the password, decrypting it with Fernet when it is stored encrypted.

        :raises AirflowException: if the stored password is encrypted but the Fernet
            instance reports it cannot decrypt (FERNET_KEY missing).
        """
        if self._password and self.is_encrypted:
            fernet = get_fernet()
            if not fernet.is_encrypted:
                raise AirflowException(
                    f"Can't decrypt encrypted password for login={self.login}, "
                    "FERNET_KEY configuration is missing"
                )
            return fernet.decrypt(bytes(self._password, "utf-8")).decode()
        else:
            return self._password

    def set_password(self, value: str | None):
        """Encrypt password and set in object attribute."""
        # NOTE(review): falsy values (None / "") are ignored, so assigning them does not clear
        # an existing password. This looks intentional (callers may rely on it); confirm before changing.
        if value:
            fernet = get_fernet()
            self._password = fernet.encrypt(bytes(value, "utf-8")).decode()
            self.is_encrypted = fernet.is_encrypted

    @declared_attr
    def password(cls):
        """Password. The value is decrypted/encrypted when reading/setting the value."""
        return synonym("_password", descriptor=property(cls.get_password, cls.set_password))

    def get_extra(self) -> str:
        """
        Return the extra-data, decrypting it with Fernet when it is stored encrypted.

        Non-empty values are passed through ``_validate_extra``, which may emit deprecation warnings.

        :raises AirflowException: if the stored extra is encrypted but the Fernet instance
            reports it cannot decrypt (FERNET_KEY missing).
        """
        if self._extra and self.is_extra_encrypted:
            fernet = get_fernet()
            if not fernet.is_encrypted:
                raise AirflowException(
                    f"Can't decrypt `extra` params for login={self.login}, "
                    "FERNET_KEY configuration is missing"
                )
            extra_val = fernet.decrypt(bytes(self._extra, "utf-8")).decode()
        else:
            extra_val = self._extra
        if extra_val:
            self._validate_extra(extra_val, self.conn_id)
        return extra_val

    def set_extra(self, value: str):
        """Encrypt extra-data and save it in the object attribute; falsy values are stored unencrypted."""
        if value:
            self._validate_extra(value, self.conn_id)
            fernet = get_fernet()
            self._extra = fernet.encrypt(bytes(value, "utf-8")).decode()
            self.is_extra_encrypted = fernet.is_encrypted
        else:
            self._extra = value
            self.is_extra_encrypted = False

    @declared_attr
    def extra(cls):
        """Extra data. The value is decrypted/encrypted when reading/setting the value."""
        return synonym("_extra", descriptor=property(cls.get_extra, cls.set_extra))

    def rotate_fernet_key(self):
        """Encrypts data with a new key. See: :ref:`security/fernet`."""
        fernet = get_fernet()
        if self._password and self.is_encrypted:
            self._password = fernet.rotate(self._password.encode("utf-8")).decode()
        if self._extra and self.is_extra_encrypted:
            self._extra = fernet.rotate(self._extra.encode("utf-8")).decode()

    def get_hook(self, *, hook_params=None):
        """
        Return hook based on conn_type.

        :param hook_params: extra keyword arguments passed to the hook constructor.
        :raises AirflowException: if no hook is registered for ``conn_type``.
        :raises ImportError: if the registered hook class cannot be imported (logged first).
        """
        from airflow.providers_manager import ProvidersManager

        hook = ProvidersManager().hooks.get(self.conn_type, None)

        if hook is None:
            raise AirflowException(f'Unknown hook type "{self.conn_type}"')
        try:
            hook_class = import_string(hook.hook_class_name)
        except ImportError:
            log.error(
                "Could not import %s when discovering %s %s",
                hook.hook_class_name,
                hook.hook_name,
                hook.package_name,
            )
            raise
        if hook_params is None:
            hook_params = {}
        return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params)

    def __repr__(self):
        """Return the connection id, or an empty string when it is unset."""
        return self.conn_id or ""

    def log_info(self):
        """
        Read each field individually or use the default representation (`__repr__`).

        This method is deprecated.
        """
        warnings.warn(
            "This method is deprecated. You can read each field individually or "
            "use the default representation (__repr__).",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return (
            f"id: {self.conn_id}. Host: {self.host}, Port: {self.port}, Schema: {self.schema}, "
            f"Login: {self.login}, Password: {'XXXXXXXX' if self.password else None}, "
            f"extra: {'XXXXXXXX' if self.extra_dejson else None}"
        )

    def debug_info(self):
        """
        Read each field individually or use the default representation (`__repr__`).

        This method is deprecated.
        """
        warnings.warn(
            "This method is deprecated. You can read each field individually or "
            "use the default representation (__repr__).",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return (
            f"id: {self.conn_id}. Host: {self.host}, Port: {self.port}, Schema: {self.schema}, "
            f"Login: {self.login}, Password: {'XXXXXXXX' if self.password else None}, "
            f"extra: {self.extra_dejson}"
        )

    def test_connection(self):
        """
        Build this connection's hook and run its ``test_connection`` method, if it has one.

        Any exception is caught and its text returned as the message.

        :return: ``(status, message)``; status is ``False`` unless the hook reports success.
        """
        status, message = False, ""
        try:
            hook = self.get_hook()
            if getattr(hook, "test_connection", False):
                status, message = hook.test_connection()
            else:
                message = (
                    f"Hook {hook.__class__.__name__} doesn't implement or inherit test_connection method"
                )
        except Exception as e:
            message = str(e)

        return status, message

    @property
    def extra_dejson(self) -> dict:
        """
        Return the ``extra`` field deserialized from JSON.

        Returns ``{}`` when extra is empty or not valid JSON; the parse failure is logged.
        The result is passed to ``mask_secret`` before being returned.
        """
        obj = {}
        if self.extra:
            try:
                obj = json.loads(self.extra)

            except JSONDecodeError:
                self.log.exception("Failed parsing the json for conn_id %s", self.conn_id)

            # Mask sensitive keys from this list
            mask_secret(obj)

        return obj

    @classmethod
    def get_connection_from_secrets(cls, conn_id: str) -> Connection:
        """
        Get connection by conn_id.

        The secret cache is consulted first. Otherwise each configured secrets backend is
        tried in order; a backend that raises is logged and skipped.

        :param conn_id: connection id
        :return: connection
        :raises AirflowNotFoundException: if no backend returns the connection.
        """
        # check cache first
        # enabled only if SecretCache.init() has been called first
        try:
            uri = SecretCache.get_connection_uri(conn_id)
            return Connection(conn_id=conn_id, uri=uri)
        except SecretCache.NotPresentException:
            pass  # continue business

        # iterate over backends if not in cache (or expired)
        for secrets_backend in ensure_secrets_loaded():
            try:
                conn = secrets_backend.get_connection(conn_id=conn_id)
                if conn:
                    SecretCache.save_connection_uri(conn_id, conn.get_uri())
                    return conn
            except Exception:
                log.exception(
                    "Unable to retrieve connection from secrets backend (%s). "
                    "Checking subsequent secrets backend.",
                    type(secrets_backend).__name__,
                )

        raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")

    def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[str, Any]:
        """
        Convert Connection to json-serializable dictionary.

        :param prune_empty: Whether or not remove empty values.
        :param validate: Validate dictionary is JSON-serializable

        :meta private:
        """
        conn = {
            "conn_id": self.conn_id,
            "conn_type": self.conn_type,
            "description": self.description,
            "host": self.host,
            "login": self.login,
            "password": self.password,
            "schema": self.schema,
            "port": self.port,
        }
        if prune_empty:
            conn = prune_dict(val=conn, mode="strict")
        if (extra := self.extra_dejson) or not prune_empty:
            conn["extra"] = extra

        if validate:
            # Raises if any value is not JSON-serializable.
            json.dumps(conn)
        return conn

    @classmethod
    def from_json(cls, value, conn_id=None) -> Connection:
        """
        Build a Connection from its JSON representation (see ``as_json``).

        :param value: JSON string whose keys are constructor arguments.
        :param conn_id: connection id to assign.
        :raises ValueError: if ``port`` cannot be converted to ``int``.
        """
        kwargs = json.loads(value)
        extra = kwargs.pop("extra", None)
        if extra:
            kwargs["extra"] = extra if isinstance(extra, str) else json.dumps(extra)
        conn_type = kwargs.pop("conn_type", None)
        if conn_type:
            kwargs["conn_type"] = cls._normalize_conn_type(conn_type)
        port = kwargs.pop("port", None)
        if port:
            try:
                kwargs["port"] = int(port)
            except ValueError as err:
                raise ValueError(f"Expected integer value for `port`, but got {port!r} instead.") from err
        return Connection(conn_id=conn_id, **kwargs)

    def as_json(self) -> str:
        """Convert Connection to JSON-string object (empty fields and ``conn_id`` omitted)."""
        conn_repr = self.to_dict(prune_empty=True, validate=False)
        conn_repr.pop("conn_id", None)
        return json.dumps(conn_repr)