1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
18
19import logging
20from typing import TYPE_CHECKING, Any
21
22from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin
23
24if TYPE_CHECKING:
25 from airflow.sdk.definitions.connection import Connection
26
# Module-level logger named after this module. It is used where no hook
# instance exists, e.g. inside the connection-lookup classmethods.
log = logging.getLogger(__name__)
28
29
class BaseHook(LoggingMixin):
    """
    Abstract base class for hooks.

    Hooks are meant as an interface for interacting with external systems.
    Concrete hooks such as MySqlHook, HiveHook or PigHook return objects that
    handle the connection to, and the interaction with, specific instances of
    those systems, and expose consistent methods for working with them.

    :param logger_name: Name of the logger used by the Hook to emit logs.
        If set to `None` (default), the logger name will fall back to
        `airflow.task.hooks.{class.__module__}.{class.__name__}` (e.g. DbApiHook will have
        *airflow.task.hooks.airflow.providers.common.sql.hooks.sql.DbApiHook* as logger).
    """

    def __init__(self, logger_name: str | None = None):
        super().__init__()
        # Logger-naming settings; presumably read by LoggingMixin to build the
        # logger name described in the class docstring (confirm in LoggingMixin).
        self._log_config_logger_name = "airflow.task.hooks"
        self._logger_name = logger_name

    @classmethod
    def get_connection(cls, conn_id: str) -> Connection:
        """
        Fetch a connection by its id through the Task SDK.

        :param conn_id: id of the connection to look up
        :return: the ``Connection`` object returned by ``Connection.get``
        """
        # Imported at call time: at module level this name is only imported
        # under TYPE_CHECKING, i.e. for type checkers.
        from airflow.sdk.definitions.connection import Connection

        found = Connection.get(conn_id)
        log.debug("Connection Retrieved '%s' (via task-sdk)", found.conn_id)
        return found

    @classmethod
    async def aget_connection(cls, conn_id: str) -> Connection:
        """
        Fetch a connection by its id through the Task SDK (coroutine version).

        :param conn_id: id of the connection to look up
        :return: the ``Connection`` object produced by awaiting ``Connection.async_get``
        """
        # Imported at call time for the same reason as in ``get_connection``.
        from airflow.sdk.definitions.connection import Connection

        found = await Connection.async_get(conn_id)
        log.debug("Connection Retrieved '%s' (via task-sdk)", found.conn_id)
        return found

    @classmethod
    def get_hook(cls, conn_id: str, hook_params: dict | None = None):
        """
        Build the default hook for a connection id.

        The connection is resolved with ``cls.get_connection`` and then asked
        for its hook.

        :param conn_id: id of the connection to resolve
        :param hook_params: parameters forwarded unchanged to ``Connection.get_hook``
        :return: whatever ``Connection.get_hook`` returns for that connection
        """
        return cls.get_connection(conn_id).get_hook(hook_params=hook_params)

    def get_conn(self) -> Any:
        """
        Return the connection for the hook.

        The base class provides no implementation; subclasses are expected to
        override this method.

        :raises NotImplementedError: always, in the base class
        """
        raise NotImplementedError

    @classmethod
    def get_connection_form_widgets(cls) -> dict[str, Any]:
        """
        Return extra connection-form widgets for this hook type.

        The base implementation contributes nothing and returns a new empty
        dict; subclasses may override it. (Presumably consumed by the
        connection UI form; confirm against the caller.)
        """
        return dict()

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """
        Return custom UI field behaviour for this hook type.

        The base implementation defines none and returns a new empty dict;
        subclasses may override it. (Presumably consumed by the connection UI
        form; confirm against the caller.)
        """
        return dict()