1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import json
21import logging
22from typing import TYPE_CHECKING, Any
23
24from sqlalchemy import Boolean, Column, Integer, String, Text, delete, select
25from sqlalchemy.dialects.mysql import MEDIUMTEXT
26from sqlalchemy.orm import declared_attr, reconstructor, synonym
27
28from airflow.api_internal.internal_api_call import internal_api_call
29from airflow.configuration import ensure_secrets_loaded
30from airflow.models.base import ID_LEN, Base
31from airflow.models.crypto import get_fernet
32from airflow.secrets.cache import SecretCache
33from airflow.secrets.metastore import MetastoreBackend
34from airflow.utils.log.logging_mixin import LoggingMixin
35from airflow.utils.log.secrets_masker import mask_secret
36from airflow.utils.session import provide_session
37
38if TYPE_CHECKING:
39 from sqlalchemy.orm import Session
40
41log = logging.getLogger(__name__)
42
43
class Variable(Base, LoggingMixin):
    """A generic way to store and retrieve arbitrary content or settings as a simple key/value store."""

    __tablename__ = "variable"
    # Private sentinel: lets ``get`` tell "no default supplied" apart from an explicit ``None``.
    __NO_DEFAULT_SENTINEL = object()

    id = Column(Integer, primary_key=True)
    key = Column(String(ID_LEN), unique=True)
    # Raw stored value (database column "val"); read and written through the ``val`` synonym.
    _val = Column("val", Text().with_variant(MEDIUMTEXT, "mysql"))
    description = Column(Text)
    is_encrypted = Column(Boolean, unique=False, default=False)

    def __init__(self, key=None, val=None, description=None):
        """Create a Variable; ``val`` is assigned through the ``val`` synonym, i.e. via ``set_val``."""
        super().__init__()
        self.key = key
        self.val = val
        self.description = description

    @reconstructor
    def on_db_load(self):
        """Register the value with the secrets masker when the row is loaded by the ORM."""
        if self._val:
            mask_secret(self.val, self.key)

    def __repr__(self):
        # Shows the raw stored column (``_val``), not the result of ``get_val``.
        return f"{self.key} : {self._val}"

    def get_val(self):
        """Get Airflow Variable from Metadata DB and decode it using the Fernet Key.

        :return: the stored value as-is when it is ``None`` or not flagged as encrypted,
            the decrypted value otherwise, or ``None`` if decryption fails (the failure is logged).
        """
        from cryptography.fernet import InvalidToken as InvalidFernetToken

        if self._val is None or not self.is_encrypted:
            return self._val

        try:
            fernet = get_fernet()
            return fernet.decrypt(bytes(self._val, "utf-8")).decode()
        except InvalidFernetToken:
            self.log.error("Can't decrypt _val for key=%s, invalid token or value", self.key)
            return None
        except Exception:
            self.log.error("Can't decrypt _val for key=%s, FERNET_KEY configuration missing", self.key)
            return None

    def set_val(self, value):
        """Encode the specified value with Fernet Key and store it in Variables Table.

        A ``None`` value is ignored: neither ``_val`` nor ``is_encrypted`` is touched.
        """
        if value is not None:
            fernet = get_fernet()
            self._val = fernet.encrypt(bytes(value, "utf-8")).decode()
            # The fernet object reports whether it actually encrypts.
            self.is_encrypted = fernet.is_encrypted

    @declared_attr
    def val(cls):
        """Get Airflow Variable from Metadata DB and decode it using the Fernet Key."""
        return synonym("_val", descriptor=property(cls.get_val, cls.set_val))

    @classmethod
    def setdefault(cls, key, default, description=None, deserialize_json=False):
        """
        Return the current value for a key or store the default value and return it.

        Works the same as the Python builtin dict object.

        :param key: Dict key for this Variable
        :param default: Default value to set and return if the variable
            isn't already in the DB
        :param description: Description stored with the Variable when the default is written
        :param deserialize_json: Store this as a JSON encoded value in the DB
            and un-encode it when retrieving a value
        :return: Mixed
        :raises ValueError: if the variable does not exist and ``default`` is ``None``
        """
        obj = Variable.get(key, default_var=None, deserialize_json=deserialize_json)
        if obj is not None:
            return obj
        if default is None:
            raise ValueError("Default Value must be set")
        Variable.set(key, default, description=description, serialize_json=deserialize_json)
        return default

    @classmethod
    def get(
        cls,
        key: str,
        default_var: Any = __NO_DEFAULT_SENTINEL,
        deserialize_json: bool = False,
    ) -> Any:
        """Get a value for an Airflow Variable Key.

        :param key: Variable Key
        :param default_var: Default value of the Variable if the Variable doesn't exist
        :param deserialize_json: Parse the stored value with ``json.loads`` before returning it
        :raises KeyError: if the variable does not exist and no ``default_var`` was given
        """
        var_val = Variable.get_variable_from_secrets(key=key)
        if var_val is None:
            if default_var is not cls.__NO_DEFAULT_SENTINEL:
                return default_var
            raise KeyError(f"Variable {key} does not exist")

        result = json.loads(var_val) if deserialize_json else var_val
        mask_secret(result, key)
        return result

    @staticmethod
    @provide_session
    @internal_api_call
    def set(
        key: str,
        value: Any,
        description: str | None = None,
        serialize_json: bool = False,
        session: Session = None,
    ) -> None:
        """Set a value for an Airflow Variable with a given Key.

        This operation overwrites an existing variable.

        :param key: Variable Key
        :param value: Value to set for the Variable
        :param description: Description of the Variable
        :param serialize_json: Serialize the value to a JSON string
        :param session: SQLAlchemy session used for the delete and insert
        """
        # check if the secret exists in the custom secrets' backend.
        Variable.check_for_write_conflict(key)
        if serialize_json:
            stored_value = json.dumps(value, indent=2)
        else:
            stored_value = str(value)

        # Overwrite is implemented as delete + insert within the same session.
        Variable.delete(key, session=session)
        session.add(Variable(key=key, val=stored_value, description=description))
        session.flush()
        # invalidate key in cache for faster propagation
        # we cannot save the value set because it's possible that it's shadowed by a custom backend
        # (see call to check_for_write_conflict above)
        SecretCache.invalidate_variable(key)

    @staticmethod
    @provide_session
    @internal_api_call
    def update(
        key: str,
        value: Any,
        serialize_json: bool = False,
        session: Session = None,
    ) -> None:
        """Update a given Airflow Variable with the Provided value.

        The existing description is preserved.

        :param key: Variable Key
        :param value: Value to set for the Variable
        :param serialize_json: Serialize the value to a JSON string
        :param session: SQLAlchemy session used for the lookup and the write
        :raises KeyError: if no secrets backend returns a value for ``key``
        :raises AttributeError: if the variable has no row in the database
        """
        Variable.check_for_write_conflict(key)

        if Variable.get_variable_from_secrets(key=key) is None:
            raise KeyError(f"Variable {key} does not exist")
        obj = session.scalar(select(Variable).where(Variable.key == key))
        if obj is None:
            raise AttributeError(f"Variable {key} does not exist in the Database and cannot be updated.")

        # Reuse the caller's session so the lookup above and the delete/insert in ``set``
        # take part in the same transaction.
        Variable.set(
            key,
            value,
            description=obj.description,
            serialize_json=serialize_json,
            session=session,
        )

    @staticmethod
    @provide_session
    @internal_api_call
    def delete(key: str, session: Session = None) -> int:
        """Delete an Airflow Variable for a given key.

        :param key: Variable Keys
        :param session: SQLAlchemy session used to execute the delete
        :return: ``rowcount`` reported for the DELETE statement
        """
        rows = session.execute(delete(Variable).where(Variable.key == key)).rowcount
        SecretCache.invalidate_variable(key)
        return rows

    def rotate_fernet_key(self):
        """Rotate Fernet Key."""
        fernet = get_fernet()
        # Only values flagged as encrypted are rotated.
        if self._val and self.is_encrypted:
            self._val = fernet.rotate(self._val.encode("utf-8")).decode()

    @staticmethod
    def check_for_write_conflict(key: str) -> None:
        """Log a warning if a variable exists outside the metastore.

        If we try to write a variable to the metastore while the same key
        exists in an environment variable or custom secrets backend, then
        subsequent reads will not read the set value.

        Every non-metastore backend is consulted in turn until one returns a
        value; a backend that raises is logged and skipped.

        :param key: Variable Key
        """
        for secrets_backend in ensure_secrets_loaded():
            if isinstance(secrets_backend, MetastoreBackend):
                continue
            try:
                var_val = secrets_backend.get_variable(key=key)
                if var_val is not None:
                    _backend_name = type(secrets_backend).__name__
                    log.warning(
                        "The variable %s is defined in the %s secrets backend, which takes "
                        "precedence over reading from the database. The value in the database will be "
                        "updated, but to read it you have to delete the conflicting variable "
                        "from %s",
                        key,
                        _backend_name,
                        _backend_name,
                    )
                    # One warning is enough; stop at the first conflicting backend.
                    return
            except Exception:
                log.exception(
                    "Unable to retrieve variable from secrets backend (%s). "
                    "Checking subsequent secrets backend.",
                    type(secrets_backend).__name__,
                )
        return None

    @staticmethod
    def get_variable_from_secrets(key: str) -> str | None:
        """
        Get Airflow Variable by iterating over all Secret Backends.

        :param key: Variable Key
        :return: Variable Value
        """
        # check cache first
        # enabled only if SecretCache.init() has been called first
        try:
            return SecretCache.get_variable(key)
        except SecretCache.NotPresentException:
            pass  # continue business

        var_val = None
        # iterate over backends if not in cache (or expired)
        for secrets_backend in ensure_secrets_loaded():
            try:
                var_val = secrets_backend.get_variable(key=key)
                if var_val is not None:
                    break
            except Exception:
                log.exception(
                    "Unable to retrieve variable from secrets backend (%s). "
                    "Checking subsequent secrets backend.",
                    type(secrets_backend).__name__,
                )

        SecretCache.save_variable(key, var_val)  # we save None as well
        return var_val