1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
18
19from abc import ABC
20
21
class BaseSecretsBackend(ABC):
    """Abstract base class to retrieve Connection object given a conn_id or Variable given a key."""

    @staticmethod
    def build_path(path_prefix: str, secret_id: str, sep: str = "/") -> str:
        """
        Build the lookup path for a secret in the Secrets Backend.

        :param path_prefix: Prefix of the path to get secret
        :param secret_id: Secret id
        :param sep: separator used to concatenate ``path_prefix`` and ``secret_id``. Default: "/"
        :return: ``path_prefix``, ``sep`` and ``secret_id`` concatenated in that order
        """
        return f"{path_prefix}{sep}{secret_id}"

    def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None:
        """
        Retrieve from Secrets Backend a string value representing the Connection object.

        If the client your secrets backend uses already returns a python dict, you should override
        ``get_connection`` instead.

        :param conn_id: connection id
        :param team_name: Team name associated to the task trying to access the connection (if any)
        :return: the serialized connection (a JSON object string or a connection URI), or None
        :raises NotImplementedError: always, in this base class; subclasses must override it
        """
        raise NotImplementedError()

    def get_variable(self, key: str, team_name: str | None = None) -> str | None:
        """
        Return value for Airflow Variable.

        :param key: Variable Key
        :param team_name: Team name associated to the task trying to access the variable (if any)
        :return: Variable Value
        :raises NotImplementedError: always, in this base class; subclasses must override it
        """
        raise NotImplementedError()

    def get_config(self, key: str) -> str | None:
        """
        Return value for Airflow Config Key.

        The base implementation always returns None.

        :param key: Config Key
        :return: Config Value
        """
        return None

    @staticmethod
    def _get_connection_class() -> type:
        """
        Detect which Connection class to use based on execution context.

        Reads the ``_AIRFLOW_PROCESS_CONTEXT`` environment variable. If it equals "client"
        (case-insensitive), the SDK Connection is returned; otherwise the core Connection is returned.

        :return: the Connection class appropriate for the current process
        """
        import os

        process_context = os.environ.get("_AIRFLOW_PROCESS_CONTEXT", "").lower()
        if process_context == "client":
            # Client context (worker, dag processor, triggerer)
            from airflow.sdk.definitions.connection import Connection

            return Connection

        # Server context (scheduler, API server, etc.)
        from airflow.models.connection import Connection

        return Connection

    def deserialize_connection(self, conn_id: str, value: str):
        """
        Given a serialized representation of the airflow Connection, return an instance.

        Auto-detects which Connection class to use based on execution context.
        Uses Connection.from_json() for JSON format. For URI format it uses Connection.from_uri()
        when the class provides it, and Connection(uri=...) otherwise.

        :param conn_id: connection id
        :param value: the serialized representation of the Connection object; surrounding
            whitespace is ignored
        :return: the deserialized Connection
        """
        conn_class = self._get_connection_class()

        value = value.strip()
        # ``startswith`` instead of ``value[0]``: an empty / whitespace-only value must not
        # blow up with IndexError before it even reaches the Connection parser.
        if value.startswith("{"):
            return conn_class.from_json(value=value, conn_id=conn_id)

        # TODO: Only sdk has from_uri defined on it. Is it worthwhile developing the core path or not?
        if hasattr(conn_class, "from_uri"):
            return conn_class.from_uri(conn_id=conn_id, uri=value)
        return conn_class(conn_id=conn_id, uri=value)

    def get_connection(self, conn_id: str, team_name: str | None = None):
        """
        Return connection object with a given ``conn_id``.

        :param conn_id: connection id
        :param team_name: Team name associated to the task trying to access the connection (if any)
        :return: Connection object, or None when the backend returned no value or only whitespace
        """
        value = self.get_conn_value(conn_id=conn_id, team_name=team_name)
        # A blank (empty or whitespace-only) secret carries no connection; treat it like a miss
        # instead of handing it to the deserializer.
        if value and value.strip():
            return self.deserialize_connection(conn_id=conn_id, value=value)
        return None