Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/connection.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

302 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import json 

21import logging 

22import warnings 

23from json import JSONDecodeError 

24from typing import Any 

25from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit 

26 

27import re2 

28from sqlalchemy import Boolean, Column, Integer, String, Text 

29from sqlalchemy.orm import declared_attr, reconstructor, synonym 

30 

31from airflow.configuration import ensure_secrets_loaded 

32from airflow.exceptions import AirflowException, AirflowNotFoundException, RemovedInAirflow3Warning 

33from airflow.models.base import ID_LEN, Base 

34from airflow.models.crypto import get_fernet 

35from airflow.secrets.cache import SecretCache 

36from airflow.utils.helpers import prune_dict 

37from airflow.utils.log.logging_mixin import LoggingMixin 

38from airflow.utils.log.secrets_masker import mask_secret 

39from airflow.utils.module_loading import import_string 

40 

# Module-level logger, used by code in this module that has no instance logger.
log = logging.getLogger(__name__)

# Pattern used to sanitize a `conn_id`: the whole string must consist of one or
# more word characters or any of the symbols #,!,-,_,.,:,\,/ and ().
#
# You can try the regex here: https://regex101.com/r/69033B/1
RE_SANITIZE_CONN_ID = re2.compile(r"^[\w\#\!\(\)\-\.\:\/\\]{1,}$")

# Default maximum number of characters accepted for a connection ID.
CONN_ID_MAX_LEN: int = 250

49 

50 

def parse_netloc_to_hostname(*args, **kwargs):
    """Do not use, this method is deprecated.

    Emits a ``RemovedInAirflow3Warning`` attributed to the caller and then forwards
    all arguments unchanged to ``_parse_netloc_to_hostname``.
    """
    warnings.warn(
        message="This method is deprecated.",
        category=RemovedInAirflow3Warning,
        stacklevel=2,
    )
    hostname = _parse_netloc_to_hostname(*args, **kwargs)
    return hostname

55 

56 

def sanitize_conn_id(conn_id: str | None, max_length=CONN_ID_MAX_LEN) -> str | None:
    r"""Validate a connection id against the set of permitted characters.

    The id is accepted only when it is a string, is no longer than ``max_length``
    (250 by default) and consists entirely of alphanumeric characters plus the
    symbols #,!,-,_,.,:,\,/ and (), with at least one character present.

    You can try to play with the regex here: https://regex101.com/r/69033B/1

    The character selection is such that it prevents the injection of javascript or
    executable bits to avoid any awkward behaviour in the front-end.

    :param conn_id: The connection id to sanitize.
    :param max_length: The max length of the connection ID, by default it is 250.
    :return: the sanitized string, `None` otherwise.
    """
    # Anything that is not a string (including ``None``) is rejected outright.
    if not isinstance(conn_id, str):
        return None

    # Over-long ids are rejected before the pattern is consulted.
    if len(conn_id) > max_length:
        return None

    matched = re2.match(RE_SANITIZE_CONN_ID, conn_id)
    if matched is None:
        return None

    # The pattern is anchored at both ends, so group 0 is the accepted id.
    return matched.group(0)

80 

81 

82def _parse_netloc_to_hostname(uri_parts): 

83 """ 

84 Parse a URI string to get the correct Hostname. 

85 

86 ``urlparse(...).hostname`` or ``urlsplit(...).hostname`` returns value into the lowercase in most cases, 

87 there are some exclusion exists for specific cases such as https://bugs.python.org/issue32323 

88 In case if expected to get a path as part of hostname path, 

89 then default behavior ``urlparse``/``urlsplit`` is unexpected. 

90 """ 

91 hostname = unquote(uri_parts.hostname or "") 

92 if "/" in hostname: 

93 hostname = uri_parts.netloc 

94 if "@" in hostname: 

95 hostname = hostname.rsplit("@", 1)[1] 

96 if ":" in hostname: 

97 hostname = hostname.split(":", 1)[0] 

98 hostname = unquote(hostname) 

99 return hostname 

100 

101 

class Connection(Base, LoggingMixin):
    """
    Placeholder to store information about different database instances connection information.

    The idea here is that scripts use references to database instances (conn_id)
    instead of hard coding hostname, logins and passwords when using operators or hooks.

    .. seealso::
        For more information on how to use this class, see: :doc:`/howto/connection`

    :param conn_id: The connection ID.
    :param conn_type: The connection type.
    :param description: The connection description.
    :param host: The host.
    :param login: The login.
    :param password: The password.
    :param schema: The schema.
    :param port: The port number.
    :param extra: Extra metadata. Non-standard data such as private/SSH keys can be saved here. JSON
        encoded object.
    :param uri: URI address describing connection parameters.
    """

    # Query-string key under which the whole ``extra`` string is carried in a URI
    # when it cannot be represented as plain query parameters.
    EXTRA_KEY = "__extra__"

    __tablename__ = "connection"

    id = Column(Integer(), primary_key=True)
    conn_id = Column(String(ID_LEN), unique=True, nullable=False)
    conn_type = Column(String(500), nullable=False)
    description = Column(Text().with_variant(Text(5000), "mysql").with_variant(String(5000), "sqlite"))
    host = Column(String(500))
    schema = Column(String(500))
    login = Column(Text())
    # Stored under the "password" column; accessed through the ``password`` synonym.
    _password = Column("password", Text())
    port = Column(Integer())
    is_encrypted = Column(Boolean, unique=False, default=False)
    is_extra_encrypted = Column(Boolean, unique=False, default=False)
    # Stored under the "extra" column; accessed through the ``extra`` synonym.
    _extra = Column("extra", Text())

    def __init__(
        self,
        conn_id: str | None = None,
        conn_type: str | None = None,
        description: str | None = None,
        host: str | None = None,
        login: str | None = None,
        password: str | None = None,
        schema: str | None = None,
        port: int | None = None,
        extra: str | dict | None = None,
        uri: str | None = None,
    ):
        """
        Build a connection either from a URI or from individual values.

        :raises AirflowException: if ``uri`` is combined with any of ``conn_type``, ``host``,
            ``login``, ``password``, ``schema``, ``port`` or ``extra``.
        """
        super().__init__()
        self.conn_id = sanitize_conn_id(conn_id)
        self.description = description
        # A non-string ``extra`` (e.g. a dict) is stored as its JSON representation.
        if extra and not isinstance(extra, str):
            extra = json.dumps(extra)
        if uri and (conn_type or host or login or password or schema or port or extra):
            raise AirflowException(
                "You must create an object using the URI or individual values "
                "(conn_type, host, login, password, schema, port or extra). "
                "You can't mix these two ways to create this object."
            )
        if uri:
            self._parse_from_uri(uri)
        else:
            self.conn_type = conn_type
            self.host = host
            self.login = login
            self.password = password
            self.schema = schema
            self.port = port
            self.extra = extra
        if self.extra:
            self._validate_extra(self.extra, self.conn_id)

        # Hand both the raw and the URL-quoted password to the secrets masker.
        if self.password:
            mask_secret(self.password)
            mask_secret(quote(self.password))

    @staticmethod
    def _validate_extra(extra, conn_id) -> None:
        """
        Verify that ``extra`` is a JSON-encoded Python dict.

        Problems are reported as ``RemovedInAirflow3Warning`` warnings rather than raised.
        From Airflow 3.0, we should no longer suppress these errors but raise instead.

        :param extra: the ``extra`` string to check; ``None`` is accepted silently.
        :param conn_id: connection id, only used in the warning text.
        """
        if extra is None:
            return None
        try:
            extra_parsed = json.loads(extra)
            if not isinstance(extra_parsed, dict):
                warnings.warn(
                    "Encountered JSON value in `extra` which does not parse as a dictionary in "
                    f"connection {conn_id!r}. From Airflow 3.0, the `extra` field must contain a JSON "
                    "representation of a Python dict.",
                    RemovedInAirflow3Warning,
                    stacklevel=3,
                )
        except json.JSONDecodeError:
            warnings.warn(
                f"Encountered non-JSON in `extra` field for connection {conn_id!r}. Support for "
                "non-JSON `extra` will be removed in Airflow 3.0",
                RemovedInAirflow3Warning,
                stacklevel=2,
            )
        return None

    @reconstructor
    def on_db_load(self):
        """Pass the password (raw and URL-quoted) to the secrets masker; registered as a reconstructor."""
        if self.password:
            mask_secret(self.password)
            mask_secret(quote(self.password))

    def parse_from_uri(self, **uri):
        """Use uri parameter in constructor, this method is deprecated."""
        warnings.warn(
            "This method is deprecated. Please use uri parameter in constructor.",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        self._parse_from_uri(**uri)

    @staticmethod
    def _normalize_conn_type(conn_type):
        """Map ``postgresql`` to ``postgres``; otherwise replace every ``-`` with ``_``."""
        if conn_type == "postgresql":
            conn_type = "postgres"
        elif "-" in conn_type:
            conn_type = conn_type.replace("-", "_")
        return conn_type

    def _parse_from_uri(self, uri: str):
        """
        Populate the connection fields from a connection URI.

        The URI may contain a second scheme directly after the connection type
        (``conn-type://protocol://host``); that protocol is then kept as part of ``host``.

        :param uri: the connection URI to parse.
        :raises AirflowException: if the URI contains more than two ``://`` separators, or if
            the text in front of the host protocol contains ``@`` or ``:``.
        """
        schemes_count_in_uri = uri.count("://")
        if schemes_count_in_uri > 2:
            raise AirflowException(f"Invalid connection string: {uri}.")
        host_with_protocol = schemes_count_in_uri == 2
        uri_parts = urlsplit(uri)
        conn_type = uri_parts.scheme
        self.conn_type = self._normalize_conn_type(conn_type)
        # Strip only the leading connection-type scheme. Limiting the replacement to the
        # first occurrence keeps a host protocol that ends with the same text intact
        # (e.g. ``http://http://host``).
        rest_of_the_url = uri.replace(f"{conn_type}://", ("" if host_with_protocol else "//"), 1)
        if host_with_protocol:
            uri_splits = rest_of_the_url.split("://", 1)
            if "@" in uri_splits[0] or ":" in uri_splits[0]:
                raise AirflowException(f"Invalid connection string: {uri}.")
        uri_parts = urlsplit(rest_of_the_url)
        protocol = uri_parts.scheme if host_with_protocol else None
        host = _parse_netloc_to_hostname(uri_parts)
        self.host = self._create_host(protocol, host)
        # Drop the leading "/" of the path; what remains is the (quoted) schema.
        quoted_schema = uri_parts.path[1:]
        self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema
        self.login = unquote(uri_parts.username) if uri_parts.username else uri_parts.username
        self.password = unquote(uri_parts.password) if uri_parts.password else uri_parts.password
        self.port = uri_parts.port
        if uri_parts.query:
            query = dict(parse_qsl(uri_parts.query, keep_blank_values=True))
            if self.EXTRA_KEY in query:
                # The whole extra string was transported under the dedicated key.
                self.extra = query[self.EXTRA_KEY]
            else:
                self.extra = json.dumps(query)

    @staticmethod
    def _create_host(protocol, host) -> str | None:
        """Return the connection host with the protocol."""
        if not host:
            return host
        if protocol:
            return f"{protocol}://{host}"
        return host

    def get_uri(self) -> str:
        """Return connection in URI format."""
        if self.conn_type and "_" in self.conn_type:
            self.log.warning(
                "Connection schemes (type: %s) shall not contain '_' according to RFC3986.",
                self.conn_type,
            )

        if self.conn_type:
            uri = f"{self.conn_type.lower().replace('_', '-')}://"
        else:
            uri = "//"

        # A host may carry its own protocol ("protocol://host"); emit it right after the scheme.
        if self.host and "://" in self.host:
            protocol, host = self.host.split("://", 1)
        else:
            protocol, host = None, self.host

        if protocol:
            uri += f"{protocol}://"

        authority_block = ""
        if self.login is not None:
            authority_block += quote(self.login, safe="")

        if self.password is not None:
            authority_block += ":" + quote(self.password, safe="")

        if authority_block:
            authority_block += "@"

        uri += authority_block

        host_block = ""
        if host:
            host_block += quote(host, safe="")

        if self.port:
            if host_block == "" and authority_block == "":
                # Neither credentials nor host: an "@" is emitted in front of the port.
                host_block += f"@:{self.port}"
            else:
                host_block += f":{self.port}"

        if self.schema:
            host_block += f"/{quote(self.schema, safe='')}"

        uri += host_block

        if self.extra:
            try:
                query: str | None = urlencode(self.extra_dejson)
            except TypeError:
                query = None
            # Use plain query parameters only when they decode back to the same extra;
            # otherwise carry the raw extra string under EXTRA_KEY.
            if query and self.extra_dejson == dict(parse_qsl(query, keep_blank_values=True)):
                uri += ("?" if self.schema else "/?") + query
            else:
                uri += ("?" if self.schema else "/?") + urlencode({self.EXTRA_KEY: self.extra})

        return uri

    def get_password(self) -> str | None:
        """
        Return the password, decrypting it first when it is flagged as encrypted.

        :raises AirflowException: if the password is flagged as encrypted but the current
            fernet reports that encryption is not enabled.
        """
        if self._password and self.is_encrypted:
            fernet = get_fernet()
            if not fernet.is_encrypted:
                raise AirflowException(
                    f"Can't decrypt encrypted password for login={self.login}, "
                    "FERNET_KEY configuration is missing"
                )
            return fernet.decrypt(bytes(self._password, "utf-8")).decode()
        else:
            return self._password

    def set_password(self, value: str | None):
        """
        Encrypt password and set in object attribute.

        A falsy ``value`` leaves the stored password and the encryption flag untouched.
        """
        if value:
            fernet = get_fernet()
            self._password = fernet.encrypt(bytes(value, "utf-8")).decode()
            self.is_encrypted = fernet.is_encrypted

    @declared_attr
    def password(cls):
        """Password. The value is decrypted/encrypted when reading/setting the value."""
        return synonym("_password", descriptor=property(cls.get_password, cls.set_password))

    def get_extra(self) -> str:
        """
        Return the extra data, decrypting it first when it is flagged as encrypted.

        A non-empty result is also passed through ``_validate_extra``.

        :raises AirflowException: if the extra is flagged as encrypted but the current
            fernet reports that encryption is not enabled.
        """
        if self._extra and self.is_extra_encrypted:
            fernet = get_fernet()
            if not fernet.is_encrypted:
                raise AirflowException(
                    f"Can't decrypt `extra` params for login={self.login}, "
                    f"FERNET_KEY configuration is missing"
                )
            extra_val = fernet.decrypt(bytes(self._extra, "utf-8")).decode()
        else:
            extra_val = self._extra
        if extra_val:
            self._validate_extra(extra_val, self.conn_id)
        return extra_val

    def set_extra(self, value: str):
        """
        Encrypt extra-data and save in object attribute to object.

        A falsy ``value`` is stored as-is and the encryption flag is reset.
        """
        if value:
            self._validate_extra(value, self.conn_id)
            fernet = get_fernet()
            self._extra = fernet.encrypt(bytes(value, "utf-8")).decode()
            self.is_extra_encrypted = fernet.is_encrypted
        else:
            self._extra = value
            self.is_extra_encrypted = False

    @declared_attr
    def extra(cls):
        """Extra data. The value is decrypted/encrypted when reading/setting the value."""
        return synonym("_extra", descriptor=property(cls.get_extra, cls.set_extra))

    def rotate_fernet_key(self):
        """Encrypts data with a new key. See: :ref:`security/fernet`."""
        fernet = get_fernet()
        # Only values currently flagged as encrypted are rotated.
        if self._password and self.is_encrypted:
            self._password = fernet.rotate(self._password.encode("utf-8")).decode()
        if self._extra and self.is_extra_encrypted:
            self._extra = fernet.rotate(self._extra.encode("utf-8")).decode()

    def get_hook(self, *, hook_params=None):
        """
        Return hook based on conn_type.

        :param hook_params: optional extra keyword arguments for the hook constructor.
        :raises AirflowException: if no hook is registered for ``conn_type``.
        :raises ImportError: if the hook class cannot be imported (logged, then re-raised).
        """
        # Imported here rather than at module level.
        from airflow.providers_manager import ProvidersManager

        hook = ProvidersManager().hooks.get(self.conn_type, None)

        if hook is None:
            raise AirflowException(f'Unknown hook type "{self.conn_type}"')
        try:
            hook_class = import_string(hook.hook_class_name)
        except ImportError:
            log.error(
                "Could not import %s when discovering %s %s",
                hook.hook_class_name,
                hook.hook_name,
                hook.package_name,
            )
            raise
        if hook_params is None:
            hook_params = {}
        # The hook receives this connection's id under the attribute name it declares.
        return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params)

    def __repr__(self):
        """Return the connection id, or an empty string when it is not set."""
        return self.conn_id or ""

    def log_info(self):
        """
        Read each field individually or use the default representation (`__repr__`).

        This method is deprecated.
        """
        warnings.warn(
            "This method is deprecated. You can read each field individually or "
            "use the default representation (__repr__).",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return (
            f"id: {self.conn_id}. Host: {self.host}, Port: {self.port}, Schema: {self.schema}, "
            f"Login: {self.login}, Password: {'XXXXXXXX' if self.password else None}, "
            f"extra: {'XXXXXXXX' if self.extra_dejson else None}"
        )

    def debug_info(self):
        """
        Read each field individually or use the default representation (`__repr__`).

        Unlike ``log_info``, the deserialized extra is included in the returned text.

        This method is deprecated.
        """
        warnings.warn(
            "This method is deprecated. You can read each field individually or "
            "use the default representation (__repr__).",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return (
            f"id: {self.conn_id}. Host: {self.host}, Port: {self.port}, Schema: {self.schema}, "
            f"Login: {self.login}, Password: {'XXXXXXXX' if self.password else None}, "
            f"extra: {self.extra_dejson}"
        )

    def test_connection(self):
        """
        Calls out get_hook method and executes test_connection method on that.

        :return: a ``(status, message)`` tuple; any exception is caught and reported as
            ``(False, str(exception))``.
        """
        status, message = False, ""
        try:
            hook = self.get_hook()
            if getattr(hook, "test_connection", False):
                status, message = hook.test_connection()
            else:
                message = (
                    f"Hook {hook.__class__.__name__} doesn't implement or inherit test_connection method"
                )
        except Exception as e:
            # Deliberately broad: failures are reported to the caller as the message.
            message = str(e)

        return status, message

    @property
    def extra_dejson(self) -> dict:
        """
        Returns the extra property by deserializing json.

        An empty dict is returned when there is no extra or when it is not valid JSON
        (the parsing error is logged).
        """
        obj = {}
        if self.extra:
            try:
                obj = json.loads(self.extra)

            except JSONDecodeError:
                self.log.exception("Failed parsing the json for conn_id %s", self.conn_id)

            # Mask sensitive keys from this list
            mask_secret(obj)

        return obj

    @classmethod
    def get_connection_from_secrets(cls, conn_id: str) -> Connection:
        """
        Get connection by conn_id.

        :param conn_id: connection id
        :return: connection
        :raises AirflowNotFoundException: if neither the cache nor any secrets backend
            provides the connection.
        """
        # check cache first
        # enabled only if SecretCache.init() has been called first
        try:
            uri = SecretCache.get_connection_uri(conn_id)
            return Connection(conn_id=conn_id, uri=uri)
        except SecretCache.NotPresentException:
            pass  # continue business

        # iterate over backends if not in cache (or expired)
        for secrets_backend in ensure_secrets_loaded():
            try:
                conn = secrets_backend.get_connection(conn_id=conn_id)
                if conn:
                    SecretCache.save_connection_uri(conn_id, conn.get_uri())
                    return conn
            except Exception:
                # Best effort: a failing backend is logged and the next one is tried.
                log.exception(
                    "Unable to retrieve connection from secrets backend (%s). "
                    "Checking subsequent secrets backend.",
                    type(secrets_backend).__name__,
                )

        raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")

    def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[str, Any]:
        """
        Convert Connection to json-serializable dictionary.

        :param prune_empty: Whether or not remove empty values.
        :param validate: Validate dictionary is JSON-serializable

        :meta private:
        """
        conn = {
            "conn_id": self.conn_id,
            "conn_type": self.conn_type,
            "description": self.description,
            "host": self.host,
            "login": self.login,
            "password": self.password,
            "schema": self.schema,
            "port": self.port,
        }
        if prune_empty:
            conn = prune_dict(val=conn, mode="strict")
        # The extra is added after pruning: always when not pruning, otherwise only if non-empty.
        if (extra := self.extra_dejson) or not prune_empty:
            conn["extra"] = extra

        if validate:
            # Raises if the dictionary cannot be serialized; the result itself is discarded.
            json.dumps(conn)
        return conn

    @classmethod
    def from_json(cls, value, conn_id=None) -> Connection:
        """
        Create a connection from a JSON object string.

        :param value: JSON text whose keys are passed on as constructor keyword arguments.
        :param conn_id: connection id given to the new object.
        :raises ValueError: if ``port`` is present but cannot be converted to an integer.
        """
        kwargs = json.loads(value)
        extra = kwargs.pop("extra", None)
        if extra:
            kwargs["extra"] = extra if isinstance(extra, str) else json.dumps(extra)
        conn_type = kwargs.pop("conn_type", None)
        if conn_type:
            kwargs["conn_type"] = cls._normalize_conn_type(conn_type)
        port = kwargs.pop("port", None)
        if port:
            try:
                kwargs["port"] = int(port)
            except ValueError:
                # The replacement message already names the offending value.
                raise ValueError(f"Expected integer value for `port`, but got {port!r} instead.") from None
        return Connection(conn_id=conn_id, **kwargs)

    def as_json(self) -> str:
        """Convert Connection to JSON-string object."""
        conn_repr = self.to_dict(prune_empty=True, validate=False)
        # The connection id is not part of the JSON representation.
        conn_repr.pop("conn_id", None)
        return json.dumps(conn_repr)

566 

567 def as_json(self) -> str: 

568 """Convert Connection to JSON-string object.""" 

569 conn_repr = self.to_dict(prune_empty=True, validate=False) 

570 conn_repr.pop("conn_id", None) 

571 return json.dumps(conn_repr)