1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import contextlib
21import json
22import logging
23import sys
24import warnings
25from typing import TYPE_CHECKING, Any
26
27from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, delete, or_, select
28from sqlalchemy.dialects.mysql import MEDIUMTEXT
29from sqlalchemy.orm import Mapped, declared_attr, reconstructor, synonym
30
31from airflow._shared.secrets_masker import mask_secret
32from airflow.configuration import conf, ensure_secrets_loaded
33from airflow.models.base import ID_LEN, Base
34from airflow.models.crypto import get_fernet
35from airflow.sdk import SecretCache
36from airflow.secrets.metastore import MetastoreBackend
37from airflow.utils.log.logging_mixin import LoggingMixin
38from airflow.utils.session import NEW_SESSION, create_session, provide_session
39from airflow.utils.sqlalchemy import get_dialect_name, mapped_column
40
41if TYPE_CHECKING:
42 from sqlalchemy.dialects.mysql.dml import Insert as MySQLInsert
43 from sqlalchemy.dialects.postgresql.dml import Insert as PostgreSQLInsert
44 from sqlalchemy.dialects.sqlite.dml import Insert as SQLiteInsert
45 from sqlalchemy.orm import Session
46
# Module-level logger for code paths that have no instance (and thus no ``self.log``),
# e.g. the static methods ``check_for_write_conflict`` and ``get_variable_from_secrets``.
log = logging.getLogger(__name__)
48
49
class Variable(Base, LoggingMixin):
    """
    A generic way to store and retrieve arbitrary content or settings as a simple key/value store.

    Rows live in the ``variable`` table. The raw column is ``_val``; the public ``val``
    attribute is a synonym that routes reads through :meth:`get_val` and writes through
    :meth:`set_val`, both of which use the object returned by ``get_fernet()``.
    """

    __tablename__ = "variable"
    # Lets ``get`` tell "no default supplied" apart from an explicit ``default_var=None``.
    __NO_DEFAULT_SENTINEL = object()

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    key: Mapped[str] = mapped_column(String(ID_LEN), unique=True)
    # Stored representation of the value (column name is "val"); use ``val`` to read/write it.
    _val: Mapped[str] = mapped_column("val", Text().with_variant(MEDIUMTEXT, "mysql"))
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    is_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False)
    # Optional owning team; the foreign key is declared with ON DELETE SET NULL.
    team_name: Mapped[str | None] = mapped_column(
        String(50),
        ForeignKey("team.name", ondelete="SET NULL"),
        nullable=True,
    )

    def __init__(self, key=None, val=None, description=None, team_name=None):
        """
        Build a Variable; assigning ``val`` goes through :meth:`set_val`.

        :param key: Variable Key
        :param val: Plain value to store (left unset when ``None``)
        :param description: Description of the Variable
        :param team_name: Team name associated to the variable (if any)
        """
        super().__init__()
        self.key = key
        self.val = val
        self.description = description
        self.team_name = team_name

    @reconstructor
    def on_db_load(self):
        """Pass the decoded value of a loaded row to ``mask_secret`` (skipped when nothing is stored)."""
        if self._val:
            mask_secret(self.val, self.key)

    def __repr__(self):
        # Hiding the value: this prints the stored ``_val`` column, never the output of ``get_val``.
        return f"{self.key} : {self._val}"

    def get_val(self):
        """
        Get Airflow Variable from Metadata DB and decode it using the Fernet Key.

        :return: The decrypted value when the row is flagged ``is_encrypted``, the stored value
            as-is otherwise, or ``None`` when decryption fails (the failure is logged, not raised).
        """
        from cryptography.fernet import InvalidToken as InvalidFernetToken

        if self._val is None or not self.is_encrypted:
            return self._val

        try:
            fernet = get_fernet()
            return fernet.decrypt(bytes(self._val, "utf-8")).decode()
        except InvalidFernetToken:
            self.log.error("Can't decrypt _val for key=%s, invalid token or value", self.key)
            return None
        except Exception:
            # Deliberate best-effort: any other failure is reported as a missing Fernet key.
            self.log.error("Can't decrypt _val for key=%s, FERNET_KEY configuration missing", self.key)
            return None

    def set_val(self, value):
        """
        Encode the specified value with Fernet Key and store it in Variables Table.

        ``None`` is ignored, leaving ``_val`` and ``is_encrypted`` untouched.

        :param value: Plain string value to store
        """
        if value is None:
            return
        fernet = get_fernet()
        self._val = fernet.encrypt(bytes(value, "utf-8")).decode()
        self.is_encrypted = fernet.is_encrypted

    @declared_attr
    def val(cls):
        """Get Airflow Variable from Metadata DB and decode it using the Fernet Key."""
        return synonym("_val", descriptor=property(cls.get_val, cls.set_val))

    @staticmethod
    def _in_task_sdk_context() -> bool:
        """Return True when the Task SDK task-runner module is loaded and has ``SUPERVISOR_COMMS``."""
        # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
        # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
        # back-compat layer

        # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer
        # perhaps) and should use the Task SDK API server path
        return hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS")

    @staticmethod
    def _session_scope(session: Session | None) -> contextlib.AbstractContextManager:
        """Yield the caller's session unchanged, or a new one from ``create_session`` when none is given."""
        if session is not None:
            # ``nullcontext`` does nothing on exit, so a supplied session is left to its owner.
            return contextlib.nullcontext(session)
        return create_session()

    @classmethod
    def setdefault(cls, key, default, description=None, deserialize_json=False):
        """
        Return the current value for a key or store the default value and return it.

        Works the same as the Python builtin dict object.

        :param key: Dict key for this Variable
        :param default: Default value to set and return if the variable
            isn't already in the DB
        :param description: Default value to set Description of the Variable
        :param deserialize_json: Store this as a JSON encoded value in the DB
            and un-encode it when retrieving a value
        :return: Mixed
        :raises ValueError: if the variable does not exist and ``default`` is ``None``
        """
        obj = Variable.get(key, default_var=None, deserialize_json=deserialize_json)
        if obj is not None:
            return obj
        if default is None:
            raise ValueError("Default Value must be set")
        Variable.set(key=key, value=default, description=description, serialize_json=deserialize_json)
        return default

    @classmethod
    def get(
        cls,
        key: str,
        default_var: Any = __NO_DEFAULT_SENTINEL,
        deserialize_json: bool = False,
        team_name: str | None = None,
    ) -> Any:
        """
        Get a value for an Airflow Variable Key.

        Inside a Task SDK execution context the call is delegated to ``airflow.sdk.Variable.get``
        (with a ``DeprecationWarning``); ``team_name`` is not forwarded on that path.

        :param key: Variable Key
        :param default_var: Default value of the Variable if the Variable doesn't exist
        :param deserialize_json: Deserialize the value to a Python dict
        :param team_name: Team name associated to the task trying to access the variable (if any)
        :return: The variable value, JSON-decoded when ``deserialize_json`` is set, or ``default_var``
        :raises ValueError: if ``team_name`` is given but ``[core] multi_team`` is not enabled
        :raises KeyError: if the variable does not exist and no ``default_var`` was supplied
        """
        if cls._in_task_sdk_context():
            warnings.warn(
                "Using Variable.get from `airflow.models` is deprecated. "
                "Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead",
                DeprecationWarning,
                stacklevel=1,
            )
            from airflow.sdk import Variable as TaskSDKVariable

            # Only forward ``default`` when the caller actually supplied one.
            default_kwargs = {} if default_var is cls.__NO_DEFAULT_SENTINEL else {"default": default_var}
            var_val = TaskSDKVariable.get(key, deserialize_json=deserialize_json, **default_kwargs)
            if isinstance(var_val, str):
                mask_secret(var_val, key)

            return var_val

        if team_name and not conf.getboolean("core", "multi_team"):
            raise ValueError(
                "Multi-team mode is not configured in the Airflow environment but the task trying to access the variable belongs to a team"
            )

        var_val = Variable.get_variable_from_secrets(key=key, team_name=team_name)
        if var_val is None:
            if default_var is not cls.__NO_DEFAULT_SENTINEL:
                return default_var
            raise KeyError(f"Variable {key} does not exist")
        if deserialize_json:
            obj = json.loads(var_val)
            mask_secret(obj, key)
            return obj
        mask_secret(var_val, key)
        return var_val

    @staticmethod
    def set(
        key: str,
        value: Any,
        description: str | None = None,
        serialize_json: bool = False,
        team_name: str | None = None,
        session: Session | None = None,
    ) -> None:
        """
        Set a value for an Airflow Variable with a given Key.

        This operation overwrites an existing variable using the session's dialect-specific upsert operation.
        Inside a Task SDK execution context the call is delegated to ``airflow.sdk.Variable.set``
        (with a ``DeprecationWarning``); ``team_name`` and ``session`` are not used on that path.

        :param key: Variable Key
        :param value: Value to set for the Variable
        :param description: Description of the Variable
        :param serialize_json: Serialize the value to a JSON string
        :param team_name: Team name associated to the variable (if any)
        :param session: optional session, use if provided or create a new one
        :raises ValueError: if ``team_name`` is given but ``[core] multi_team`` is not enabled
        """
        if Variable._in_task_sdk_context():
            warnings.warn(
                "Using Variable.set from `airflow.models` is deprecated. "
                "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead",
                DeprecationWarning,
                stacklevel=1,
            )
            from airflow.sdk import Variable as TaskSDKVariable

            TaskSDKVariable.set(
                key=key,
                value=value,
                description=description,
                serialize_json=serialize_json,
            )
            return

        if team_name and not conf.getboolean("core", "multi_team"):
            raise ValueError(
                "Multi-team mode is not configured in the Airflow environment. To assign a team to a variable, multi-mode must be enabled."
            )

        # check if the secret exists in the custom secrets' backend.
        Variable.check_for_write_conflict(key=key)
        stored_value = json.dumps(value, indent=2) if serialize_json else str(value)

        with Variable._session_scope(session) as session:
            # A transient instance is built only to run the value through ``set_val``; the row
            # itself is written by the upsert statement below, not by adding this object.
            new_variable = Variable(key=key, val=stored_value, description=description, team_name=team_name)

            insert_values = {
                "key": key,
                "val": new_variable._val,
                "description": description,
                "is_encrypted": new_variable.is_encrypted,
                "team_name": team_name,
            }
            # On conflict every column except the conflict target ``key`` is overwritten.
            update_values = {
                column: column_value for column, column_value in insert_values.items() if column != "key"
            }

            # Create dialect-specific upsert statement
            dialect_name = get_dialect_name(session)
            stmt: MySQLInsert | PostgreSQLInsert | SQLiteInsert

            if dialect_name == "postgresql":
                from sqlalchemy.dialects.postgresql import insert as pg_insert

                stmt = (
                    pg_insert(Variable)
                    .values(**insert_values)
                    .on_conflict_do_update(index_elements=["key"], set_=update_values)
                )
            elif dialect_name == "mysql":
                from sqlalchemy.dialects.mysql import insert as mysql_insert

                stmt = mysql_insert(Variable).values(**insert_values).on_duplicate_key_update(**update_values)
            else:
                # Any other dialect name falls back to the SQLite insert construct.
                from sqlalchemy.dialects.sqlite import insert as sqlite_insert

                stmt = (
                    sqlite_insert(Variable)
                    .values(**insert_values)
                    .on_conflict_do_update(index_elements=["key"], set_=update_values)
                )

            session.execute(stmt)
            # invalidate key in cache for faster propagation
            # we cannot save the value set because it's possible that it's shadowed by a custom backend
            # (see call to check_for_write_conflict above)
            SecretCache.invalidate_variable(key)

    @staticmethod
    def update(
        key: str,
        value: Any,
        serialize_json: bool = False,
        team_name: str | None = None,
        session: Session | None = None,
    ) -> None:
        """
        Update a given Airflow Variable with the Provided value.

        The existing description and team association of the row are preserved. Inside a Task SDK
        execution context the call is delegated to ``airflow.sdk.Variable.set`` (with a
        ``DeprecationWarning``); ``team_name`` and ``session`` are not used on that path.

        :param key: Variable Key
        :param value: Value to set for the Variable
        :param serialize_json: Serialize the value to a JSON string
        :param team_name: Team name associated to the variable (if any)
        :param session: optional session, use if provided or create a new one
        :raises ValueError: if ``team_name`` is given but ``[core] multi_team`` is not enabled
        :raises KeyError: if no secrets backend returns a value for ``key``
        :raises AttributeError: if the variable is not present in the database
        """
        if Variable._in_task_sdk_context():
            warnings.warn(
                "Using Variable.update from `airflow.models` is deprecated. "
                "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.",
                DeprecationWarning,
                stacklevel=1,
            )
            from airflow.sdk import Variable as TaskSDKVariable

            # set is an upsert command, it can handle updates too
            TaskSDKVariable.set(
                key=key,
                value=value,
                serialize_json=serialize_json,
            )
            return

        if team_name and not conf.getboolean("core", "multi_team"):
            raise ValueError(
                "Multi-team mode is not configured in the Airflow environment. To assign a team to a variable, multi-mode must be enabled."
            )

        Variable.check_for_write_conflict(key=key)

        if Variable.get_variable_from_secrets(key=key, team_name=team_name) is None:
            raise KeyError(f"Variable {key} does not exist")

        with Variable._session_scope(session) as session:
            # Match the caller's team or a variable with no team at all.
            obj = session.scalar(
                select(Variable).where(
                    Variable.key == key, or_(Variable.team_name == team_name, Variable.team_name.is_(None))
                )
            )
            if obj is None:
                raise AttributeError(f"Variable {key} does not exist in the Database and cannot be updated.")

            # ``set`` overwrites ``team_name`` on conflict, so forward the row's current team;
            # leaving it out would reset the column to NULL and drop the team association.
            Variable.set(
                key=key,
                value=value,
                description=obj.description,
                serialize_json=serialize_json,
                team_name=obj.team_name,
                session=session,
            )

    @staticmethod
    def delete(key: str, team_name: str | None = None, session: Session | None = None) -> int:
        """
        Delete an Airflow Variable for a given key.

        Inside a Task SDK execution context the call is delegated to ``airflow.sdk.Variable.delete``
        (with a ``DeprecationWarning``) and ``1`` is returned unconditionally.

        :param key: Variable Keys
        :param team_name: Team name associated to the task trying to delete the variable (if any)
        :param session: optional session, use if provided or create a new one
        :return: ``rowcount`` reported by the DELETE statement (``0`` when unavailable)
        :raises ValueError: if ``team_name`` is given but ``[core] multi_team`` is not enabled
        """
        if Variable._in_task_sdk_context():
            warnings.warn(
                "Using Variable.delete from `airflow.models` is deprecated. "
                "Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead",
                DeprecationWarning,
                stacklevel=1,
            )
            from airflow.sdk import Variable as TaskSDKVariable

            TaskSDKVariable.delete(
                key=key,
            )
            return 1

        if team_name and not conf.getboolean("core", "multi_team"):
            raise ValueError(
                "Multi-team mode is not configured in the Airflow environment but the task trying to delete the variable belongs to a team"
            )

        with Variable._session_scope(session) as session:
            # Match the caller's team or a variable with no team at all.
            result = session.execute(
                delete(Variable).where(
                    Variable.key == key, or_(Variable.team_name == team_name, Variable.team_name.is_(None))
                )
            )
            rows = getattr(result, "rowcount", 0) or 0
            SecretCache.invalidate_variable(key)
            return rows

    def rotate_fernet_key(self):
        """Rotate Fernet Key (no-op unless a value is stored and flagged ``is_encrypted``)."""
        fernet = get_fernet()
        if self._val and self.is_encrypted:
            self._val = fernet.rotate(self._val.encode("utf-8")).decode()

    @staticmethod
    def check_for_write_conflict(key: str) -> None:
        """
        Log a warning if a variable exists outside the metastore.

        If we try to write a variable to the metastore while the same key
        exists in an environment variable or custom secrets backend, then
        subsequent reads will not read the set value.

        :param key: Variable Key
        """
        for secrets_backend in ensure_secrets_loaded():
            if isinstance(secrets_backend, MetastoreBackend):
                continue
            try:
                var_val = secrets_backend.get_variable(key=key)
            except Exception:
                # Best-effort check: a failing backend is logged and the next one is tried.
                log.exception(
                    "Unable to retrieve variable from secrets backend (%s). "
                    "Checking subsequent secrets backend.",
                    type(secrets_backend).__name__,
                )
                continue
            if var_val is not None:
                _backend_name = type(secrets_backend).__name__
                log.warning(
                    "The variable %s is defined in the %s secrets backend, which takes "
                    "precedence over reading from the database. The value in the database will be "
                    "updated, but to read it you have to delete the conflicting variable "
                    "from %s",
                    key,
                    _backend_name,
                    _backend_name,
                )
                # One warning is enough; stop at the first conflicting backend.
                return

    @staticmethod
    def get_variable_from_secrets(key: str, team_name: str | None = None) -> str | None:
        """
        Get Airflow Variable by iterating over all Secret Backends.

        :param key: Variable Key
        :param team_name: Team name associated to the task trying to access the variable (if any)
        :return: Variable Value
        """
        # Disable cache if the variable belongs to a team. We might enable it later
        use_cache = not team_name
        if use_cache:
            # check cache first
            # enabled only if SecretCache.init() has been called first
            try:
                return SecretCache.get_variable(key)
            except SecretCache.NotPresentException:
                pass  # continue business

        var_val = None
        # iterate over backends if not in cache (or expired)
        for secrets_backend in ensure_secrets_loaded():
            try:
                var_val = secrets_backend.get_variable(key=key, team_name=team_name)
                if var_val is not None:
                    break
            except Exception:
                log.exception(
                    "Unable to retrieve variable from secrets backend (%s). "
                    "Checking subsequent secrets backend.",
                    type(secrets_backend).__name__,
                )

        # The cache is keyed by ``key`` alone, so a team-scoped result must not be stored in it:
        # a later lookup without a team would otherwise be served the team's value.
        if use_cache:
            SecretCache.save_variable(key, var_val)  # we save None as well
        return var_val

    @staticmethod
    @provide_session
    def get_team_name(variable_key: str, session=NEW_SESSION) -> str | None:
        """
        Return the ``team_name`` column of the variable with the given key.

        :param variable_key: Variable Key
        :param session: Session
        :return: The team name, or ``None`` when the query yields no value
        """
        stmt = select(Variable.team_name).where(Variable.key == variable_key)
        return session.scalar(stmt)

    @staticmethod
    @provide_session
    def get_key_to_team_name_mapping(variable_keys: list[str], session=NEW_SESSION) -> dict[str, str | None]:
        """
        Map each of the given keys found in the database to its ``team_name``.

        :param variable_keys: Variable Keys to look up
        :param session: Session
        :return: ``{key: team_name}`` for the matching rows; keys with no row are absent
        """
        stmt = select(Variable.key, Variable.team_name).where(Variable.key.in_(variable_keys))
        return {key: team_name for key, team_name in session.execute(stmt)}