1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import contextlib
21import json
22import logging
23import sys
24import warnings
25from typing import TYPE_CHECKING, Any
26
27from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, delete, or_, select
28from sqlalchemy.dialects.mysql import MEDIUMTEXT
29from sqlalchemy.orm import Mapped, declared_attr, reconstructor, synonym
30
31from airflow._shared.secrets_masker import mask_secret
32from airflow.configuration import conf, ensure_secrets_loaded
33from airflow.models.base import ID_LEN, Base
34from airflow.models.crypto import get_fernet
35from airflow.secrets.metastore import MetastoreBackend
36from airflow.utils.log.logging_mixin import LoggingMixin
37from airflow.utils.session import NEW_SESSION, create_session, provide_session
38from airflow.utils.sqlalchemy import get_dialect_name, mapped_column
39
40if TYPE_CHECKING:
41 from sqlalchemy.dialects.mysql.dml import Insert as MySQLInsert
42 from sqlalchemy.dialects.postgresql.dml import Insert as PostgreSQLInsert
43 from sqlalchemy.dialects.sqlite.dml import Insert as SQLiteInsert
44 from sqlalchemy.orm import Session
45
# Module-level logger, used by the static helpers below that have no ``self.log``.
log = logging.getLogger(__name__)
47
48
class Variable(Base, LoggingMixin):
    """
    A generic way to store and retrieve arbitrary content or settings as a simple key/value store.

    Values are read and written through the ``val`` synonym, which passes them
    through the Fernet object returned by ``get_fernet()`` on the way to and from
    the ``val`` column. ``is_encrypted`` mirrors the flag reported by the Fernet
    object that was used for the last write.
    """

    __tablename__ = "variable"
    # Marks "no default supplied" in ``get`` so that ``None`` stays usable as a real default.
    __NO_DEFAULT_SENTINEL = object()

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    key: Mapped[str] = mapped_column(String(ID_LEN), unique=True)
    # Raw stored text; use the ``val`` synonym to read or write the plain value.
    _val: Mapped[str] = mapped_column("val", Text().with_variant(MEDIUMTEXT, "mysql"))
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    is_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False)
    # Optional owning team; the database resets it to NULL when the team row is deleted.
    team_name: Mapped[str | None] = mapped_column(
        String(50),
        ForeignKey("team.name", ondelete="SET NULL"),
        nullable=True,
    )

    def __init__(self, key=None, val=None, description=None, team_name=None):
        """
        Create a Variable; ``val`` is stored through ``set_val``.

        :param key: Variable Key
        :param val: Plain value; ``None`` leaves the stored value untouched
        :param description: Description of the Variable
        :param team_name: Team name associated to the variable (if any)
        """
        super().__init__()
        self.key = key
        self.val = val
        self.description = description
        self.team_name = team_name

    @reconstructor
    def on_db_load(self):
        """Register the decoded value with the secrets masker when a row is loaded from the DB."""
        if self._val:
            mask_secret(self.val, self.key)

    def __repr__(self):
        # Shows the stored column text (``_val``), not the decoded ``val``.
        return f"{self.key} : {self._val}"

    def get_val(self) -> str | None:
        """
        Get Airflow Variable from Metadata DB and decode it using the Fernet Key.

        :return: the decoded value, the stored text as-is when the row is not flagged
            as encrypted, or ``None`` when decryption fails (the failure is logged)
        """
        from cryptography.fernet import InvalidToken as InvalidFernetToken

        if self._val is None or not self.is_encrypted:
            return self._val

        try:
            fernet = get_fernet()
            return fernet.decrypt(bytes(self._val, "utf-8")).decode()
        except InvalidFernetToken:
            self.log.error("Can't decrypt _val for key=%s, invalid token or value", self.key)
        except Exception:
            self.log.error("Can't decrypt _val for key=%s, FERNET_KEY configuration missing", self.key)
        return None

    def set_val(self, value):
        """
        Encode the specified value with Fernet Key and store it in Variables Table.

        A ``None`` value is ignored: neither ``_val`` nor ``is_encrypted`` is changed.

        :param value: Plain string value to store
        """
        if value is not None:
            fernet = get_fernet()
            self._val = fernet.encrypt(bytes(value, "utf-8")).decode()
            self.is_encrypted = fernet.is_encrypted

    @declared_attr
    def val(cls):
        """Get Airflow Variable from Metadata DB and decode it using the Fernet Key."""
        return synonym("_val", descriptor=property(cls.get_val, cls.set_val))

    @staticmethod
    def _in_task_sdk_context() -> bool:
        """Return True when the Task SDK task runner is loaded and exposes ``SUPERVISOR_COMMS``."""
        # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer
        # perhaps) and should use the Task SDK API server path.
        # ``sys.modules.get`` yields None when the module was never imported; hasattr(None, ...) is False.
        return hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS")

    @staticmethod
    def _session_scope(session: Session | None) -> contextlib.AbstractContextManager:
        """Wrap a caller-provided session in a no-op context, or open a new session if none was given."""
        if session is not None:
            return contextlib.nullcontext(session)
        return create_session()

    @staticmethod
    def _build_upsert_stmt(
        dialect_name: str | None, *, key: str, columns: dict[str, Any]
    ) -> MySQLInsert | PostgreSQLInsert | SQLiteInsert:
        """
        Build the dialect-specific "insert, or update on key conflict" statement.

        :param dialect_name: Dialect name of the session; anything other than
            ``postgresql`` or ``mysql`` uses the SQLite construct
        :param key: Variable Key, used as the conflict target
        :param columns: Column values written on insert and on conflict alike
        """
        if dialect_name == "postgresql":
            from sqlalchemy.dialects.postgresql import insert as pg_insert

            return (
                pg_insert(Variable)
                .values(key=key, **columns)
                .on_conflict_do_update(index_elements=["key"], set_=columns)
            )

        if dialect_name == "mysql":
            from sqlalchemy.dialects.mysql import insert as mysql_insert

            return mysql_insert(Variable).values(key=key, **columns).on_duplicate_key_update(**columns)

        from sqlalchemy.dialects.sqlite import insert as sqlite_insert

        return (
            sqlite_insert(Variable)
            .values(key=key, **columns)
            .on_conflict_do_update(index_elements=["key"], set_=columns)
        )

    @classmethod
    def setdefault(cls, key, default, description=None, deserialize_json=False):
        """
        Return the current value for a key or store the default value and return it.

        Works the same as the Python builtin dict object.

        :param key: Dict key for this Variable
        :param default: Default value to set and return if the variable
            isn't already in the DB
        :param description: Default value to set Description of the Variable
        :param deserialize_json: Store this as a JSON encoded value in the DB
            and un-encode it when retrieving a value
        :return: Mixed
        :raises ValueError: if the variable does not exist and ``default`` is None
        """
        current = cls.get(key, default_var=None, deserialize_json=deserialize_json)
        if current is not None:
            return current
        if default is None:
            raise ValueError("Default Value must be set")
        cls.set(key=key, value=default, description=description, serialize_json=deserialize_json)
        return default

    @classmethod
    def get(
        cls,
        key: str,
        default_var: Any = __NO_DEFAULT_SENTINEL,
        deserialize_json: bool = False,
        team_name: str | None = None,
    ) -> Any:
        """
        Get a value for an Airflow Variable Key.

        :param key: Variable Key
        :param default_var: Default value of the Variable if the Variable doesn't exist
        :param deserialize_json: Deserialize the value to a Python dict
        :param team_name: Team name associated to the task trying to access the variable (if any)
        :return: the variable value, JSON-decoded when ``deserialize_json`` is set,
            or ``default_var`` when the variable does not exist
        :raises KeyError: if the variable does not exist and no default was given
        :raises ValueError: if ``team_name`` is given but multi-team mode is disabled
        """
        # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
        # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
        # back-compat layer
        if cls._in_task_sdk_context():
            warnings.warn(
                "Using Variable.get from `airflow.models` is deprecated. "
                "Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller, not to this module
            )
            from airflow.sdk import Variable as TaskSDKVariable

            # Only forward a default when the caller actually supplied one.
            default_kwargs = {} if default_var is cls.__NO_DEFAULT_SENTINEL else {"default": default_var}
            var_val = TaskSDKVariable.get(key, deserialize_json=deserialize_json, **default_kwargs)
            if isinstance(var_val, str):
                mask_secret(var_val, key)

            return var_val

        if team_name and not conf.getboolean("core", "multi_team"):
            raise ValueError(
                "Multi-team mode is not configured in the Airflow environment but the task trying to access the variable belongs to a team"
            )

        var_val = cls.get_variable_from_secrets(key=key, team_name=team_name)
        if var_val is None:
            if default_var is cls.__NO_DEFAULT_SENTINEL:
                raise KeyError(f"Variable {key} does not exist")
            return default_var

        result = json.loads(var_val) if deserialize_json else var_val
        mask_secret(result, key)
        return result

    @staticmethod
    def set(
        key: str,
        value: Any,
        description: str | None = None,
        serialize_json: bool = False,
        team_name: str | None = None,
        session: Session | None = None,
    ) -> None:
        """
        Set a value for an Airflow Variable with a given Key.

        This operation overwrites an existing variable using the session's dialect-specific upsert operation.

        :param key: Variable Key
        :param value: Value to set for the Variable
        :param description: Description of the Variable
        :param serialize_json: Serialize the value to a JSON string
        :param team_name: Team name associated to the variable (if any)
        :param session: optional session, use if provided or create a new one
        :raises ValueError: if ``team_name`` is given but multi-team mode is disabled
        """
        # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
        # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
        # back-compat layer
        if Variable._in_task_sdk_context():
            warnings.warn(
                "Using Variable.set from `airflow.models` is deprecated. "
                "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller, not to this module
            )
            from airflow.sdk import Variable as TaskSDKVariable

            TaskSDKVariable.set(
                key=key,
                value=value,
                description=description,
                serialize_json=serialize_json,
            )
            return

        if team_name and not conf.getboolean("core", "multi_team"):
            raise ValueError(
                "Multi-team mode is not configured in the Airflow environment. To assign a team to a variable, multi-mode must be enabled."
            )

        from airflow.sdk import SecretCache

        # check if the secret exists in the custom secrets' backend.
        Variable.check_for_write_conflict(key=key)
        stored_value = json.dumps(value, indent=2) if serialize_json else str(value)

        with Variable._session_scope(session) as session:
            # Build a transient instance only to run the value through ``set_val``,
            # which yields the text to store and the matching ``is_encrypted`` flag.
            new_variable = Variable(key=key, val=stored_value, description=description, team_name=team_name)
            columns = {
                "val": new_variable._val,
                "description": description,
                "is_encrypted": new_variable.is_encrypted,
                "team_name": team_name,
            }

            stmt = Variable._build_upsert_stmt(get_dialect_name(session), key=key, columns=columns)
            session.execute(stmt)
            # invalidate key in cache for faster propagation
            # we cannot save the value set because it's possible that it's shadowed by a custom backend
            # (see call to check_for_write_conflict above)
            SecretCache.invalidate_variable(key)

    @staticmethod
    def update(
        key: str,
        value: Any,
        serialize_json: bool = False,
        team_name: str | None = None,
        session: Session | None = None,
    ) -> None:
        """
        Update a given Airflow Variable with the Provided value.

        The existing description and team association of the variable are kept.

        :param key: Variable Key
        :param value: Value to set for the Variable
        :param serialize_json: Serialize the value to a JSON string
        :param team_name: Team name associated to the variable (if any)
        :param session: optional session, use if provided or create a new one
        :raises ValueError: if ``team_name`` is given but multi-team mode is disabled
        :raises KeyError: if no secrets backend knows the variable
        :raises AttributeError: if the variable is not stored in the database
        """
        # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
        # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
        # back-compat layer
        if Variable._in_task_sdk_context():
            warnings.warn(
                "Using Variable.update from `airflow.models` is deprecated. "
                "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller, not to this module
            )
            from airflow.sdk import Variable as TaskSDKVariable

            # set is an upsert command, it can handle updates too
            TaskSDKVariable.set(
                key=key,
                value=value,
                serialize_json=serialize_json,
            )
            return

        if team_name and not conf.getboolean("core", "multi_team"):
            raise ValueError(
                "Multi-team mode is not configured in the Airflow environment. To assign a team to a variable, multi-mode must be enabled."
            )

        Variable.check_for_write_conflict(key=key)

        if Variable.get_variable_from_secrets(key=key, team_name=team_name) is None:
            raise KeyError(f"Variable {key} does not exist")

        with Variable._session_scope(session) as session:
            # Visible rows are those owned by the caller's team plus the team-less ones.
            obj = session.scalar(
                select(Variable).where(
                    Variable.key == key, or_(Variable.team_name == team_name, Variable.team_name.is_(None))
                )
            )
            if obj is None:
                raise AttributeError(f"Variable {key} does not exist in the Database and cannot be updated.")

            # ``set`` overwrites every column, so carry over the existing description
            # and team; omitting ``team_name`` would reset the row's team to NULL.
            Variable.set(
                key=key,
                value=value,
                description=obj.description,
                serialize_json=serialize_json,
                team_name=obj.team_name,
                session=session,
            )

    @staticmethod
    def delete(key: str, team_name: str | None = None, session: Session | None = None) -> int:
        """
        Delete an Airflow Variable for a given key.

        :param key: Variable Keys
        :param team_name: Team name associated to the task trying to delete the variable (if any)
        :param session: optional session, use if provided or create a new one
        :return: number of rows deleted; always 1 when delegated to the Task SDK
        :raises ValueError: if ``team_name`` is given but multi-team mode is disabled
        """
        # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
        # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
        # back-compat layer
        if Variable._in_task_sdk_context():
            warnings.warn(
                "Using Variable.delete from `airflow.models` is deprecated. "
                "Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller, not to this module
            )
            from airflow.sdk import Variable as TaskSDKVariable

            TaskSDKVariable.delete(
                key=key,
            )
            return 1

        if team_name and not conf.getboolean("core", "multi_team"):
            raise ValueError(
                "Multi-team mode is not configured in the Airflow environment but the task trying to delete the variable belongs to a team"
            )

        from airflow.sdk import SecretCache

        with Variable._session_scope(session) as session:
            result = session.execute(
                delete(Variable).where(
                    Variable.key == key, or_(Variable.team_name == team_name, Variable.team_name.is_(None))
                )
            )
            # Fall back to 0 when the result has no ``rowcount`` or reports a falsy one.
            rows = getattr(result, "rowcount", 0) or 0
            SecretCache.invalidate_variable(key)
            return rows

    def rotate_fernet_key(self):
        """Rotate Fernet Key."""
        fernet = get_fernet()
        if self._val and self.is_encrypted:
            self._val = fernet.rotate(self._val.encode("utf-8")).decode()

    @staticmethod
    def check_for_write_conflict(key: str) -> None:
        """
        Log a warning if a variable exists outside the metastore.

        If we try to write a variable to the metastore while the same key
        exists in an environment variable or custom secrets backend, then
        subsequent reads will not read the set value.

        Only the first conflicting backend is reported. A backend that fails
        is logged and skipped.

        :param key: Variable Key
        """
        for secrets_backend in ensure_secrets_loaded():
            if isinstance(secrets_backend, MetastoreBackend):
                continue

            backend_name = type(secrets_backend).__name__
            try:
                var_val = secrets_backend.get_variable(key=key)
            except Exception:
                log.exception(
                    "Unable to retrieve variable from secrets backend (%s). "
                    "Checking subsequent secrets backend.",
                    backend_name,
                )
                continue

            if var_val is not None:
                log.warning(
                    "The variable %s is defined in the %s secrets backend, which takes "
                    "precedence over reading from the database. The value in the database will be "
                    "updated, but to read it you have to delete the conflicting variable "
                    "from %s",
                    key,
                    backend_name,
                    backend_name,
                )
                return
        return None

    @staticmethod
    def get_variable_from_secrets(key: str, team_name: str | None = None) -> str | None:
        """
        Get Airflow Variable by iterating over all Secret Backends.

        The secret cache is bypassed entirely, for both reads and writes, when
        ``team_name`` is set.

        :param key: Variable Key
        :param team_name: Team name associated to the task trying to access the variable (if any)
        :return: Variable Value
        """
        from airflow.sdk import SecretCache

        # Disable cache if the variable belongs to a team. We might enable it later
        use_cache = not team_name
        if use_cache:
            # check cache first
            # enabled only if SecretCache.init() has been called first
            try:
                return SecretCache.get_variable(key)
            except SecretCache.NotPresentException:
                pass  # continue business

        var_val = None
        # iterate over backends if not in cache (or expired)
        for secrets_backend in ensure_secrets_loaded():
            try:
                var_val = secrets_backend.get_variable(key=key, team_name=team_name)
            except Exception:
                log.exception(
                    "Unable to retrieve variable from secrets backend (%s). "
                    "Checking subsequent secrets backend.",
                    type(secrets_backend).__name__,
                )
                continue
            if var_val is not None:
                break

        # Cache entries are keyed by ``key`` alone, so a team-scoped result must not be
        # stored: a later lookup without a team would be served that value.
        if use_cache:
            SecretCache.save_variable(key, var_val)  # we save None as well
        return var_val

    @staticmethod
    @provide_session
    def get_team_name(variable_key: str, session: Session = NEW_SESSION) -> str | None:
        """
        Return the team name stored for the given variable key.

        :param variable_key: Variable Key
        :param session: SQLAlchemy session
        :return: the ``team_name`` column value, or ``None`` when the query yields no value
        """
        stmt = select(Variable.team_name).where(Variable.key == variable_key)
        return session.scalar(stmt)

    @staticmethod
    @provide_session
    def get_key_to_team_name_mapping(
        variable_keys: list[str], session: Session = NEW_SESSION
    ) -> dict[str, str | None]:
        """
        Return a mapping of variable key to team name for the given keys.

        Keys that are not stored in the database are absent from the result.

        :param variable_keys: Variable Keys to look up
        :param session: SQLAlchemy session
        """
        if not variable_keys:
            # Nothing to look up; skip the query.
            return {}
        stmt = select(Variable.key, Variable.team_name).where(Variable.key.in_(variable_keys))
        return {key: team_name for key, team_name in session.execute(stmt)}