Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/models/variable.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

215 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import contextlib 

21import json 

22import logging 

23import sys 

24import warnings 

25from typing import TYPE_CHECKING, Any 

26 

27from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, delete, or_, select 

28from sqlalchemy.dialects.mysql import MEDIUMTEXT 

29from sqlalchemy.orm import Mapped, declared_attr, reconstructor, synonym 

30 

31from airflow._shared.secrets_masker import mask_secret 

32from airflow.configuration import conf, ensure_secrets_loaded 

33from airflow.models.base import ID_LEN, Base 

34from airflow.models.crypto import get_fernet 

35from airflow.sdk import SecretCache 

36from airflow.secrets.metastore import MetastoreBackend 

37from airflow.utils.log.logging_mixin import LoggingMixin 

38from airflow.utils.session import NEW_SESSION, create_session, provide_session 

39from airflow.utils.sqlalchemy import get_dialect_name, mapped_column 

40 

41if TYPE_CHECKING: 

42 from sqlalchemy.dialects.mysql.dml import Insert as MySQLInsert 

43 from sqlalchemy.dialects.postgresql.dml import Insert as PostgreSQLInsert 

44 from sqlalchemy.dialects.sqlite.dml import Insert as SQLiteInsert 

45 from sqlalchemy.orm import Session 

46 

# Module-level logger, used by the static helpers of ``Variable`` that have no ``self.log``.
log = logging.getLogger(__name__)

48 

49 

class Variable(Base, LoggingMixin):
    """A generic way to store and retrieve arbitrary content or settings as a simple key/value store."""

    __tablename__ = "variable"
    # Private marker that distinguishes "no default given" from an explicit ``default_var=None`` in ``get``.
    __NO_DEFAULT_SENTINEL = object()

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    key: Mapped[str] = mapped_column(String(ID_LEN), unique=True)
    # Raw stored value (DB column ``val``): ciphertext when ``is_encrypted`` is True.
    # Read/write through the ``val`` synonym to get encryption/decryption.
    _val: Mapped[str] = mapped_column("val", Text().with_variant(MEDIUMTEXT, "mysql"))
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    is_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False)
    team_name: Mapped[str | None] = mapped_column(
        String(50),
        ForeignKey("team.name", ondelete="SET NULL"),
        nullable=True,
    )

    def __init__(self, key=None, val=None, description=None, team_name=None):
        super().__init__()
        self.key = key
        # Assigning through the ``val`` synonym routes the value through ``set_val`` (Fernet encryption).
        self.val = val
        self.description = description
        self.team_name = team_name

    @reconstructor
    def on_db_load(self):
        """Register the decoded value with the secrets masker when the row is loaded from the DB."""
        if self._val:
            mask_secret(self.val, self.key)

    def __repr__(self):
        # Hide the value: encrypted ciphertext may be shown, but a value stored without
        # encryption (no Fernet key configured) would be plaintext, so it is masked.
        shown = self._val if self.is_encrypted else "***"
        return f"{self.key} : {shown}"

    def get_val(self):
        """
        Get Airflow Variable from Metadata DB and decode it using the Fernet Key.

        :return: the decrypted value; the raw stored value when it is not encrypted;
            or None when decryption fails (the failure is logged).
        """
        from cryptography.fernet import InvalidToken as InvalidFernetToken

        if self._val is None or not self.is_encrypted:
            return self._val
        try:
            fernet = get_fernet()
            return fernet.decrypt(bytes(self._val, "utf-8")).decode()
        except InvalidFernetToken:
            self.log.error("Can't decrypt _val for key=%s, invalid token or value", self.key)
            return None
        except Exception:
            self.log.error("Can't decrypt _val for key=%s, FERNET_KEY configuration missing", self.key)
            return None

    def set_val(self, value):
        """
        Encode the specified value with Fernet Key and store it in Variables Table.

        A ``None`` value is ignored and leaves the stored value untouched.
        """
        if value is not None:
            fernet = get_fernet()
            self._val = fernet.encrypt(bytes(value, "utf-8")).decode()
            self.is_encrypted = fernet.is_encrypted

    @declared_attr
    def val(cls):
        """Get Airflow Variable from Metadata DB and decode it using the Fernet Key."""
        return synonym("_val", descriptor=property(cls.get_val, cls.set_val))

    @staticmethod
    def _in_execution_context() -> bool:
        """Return True when the Task SDK supervisor comms are present (task, Dag parse or triggerer)."""
        return hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS")

    @staticmethod
    def _session_scope(session: Session | None) -> contextlib.AbstractContextManager:
        """Reuse ``session`` when one is given, otherwise open a new one with ``create_session``."""
        if session is not None:
            return contextlib.nullcontext(session)
        return create_session()

    @classmethod
    def setdefault(cls, key, default, description=None, deserialize_json=False):
        """
        Return the current value for a key or store the default value and return it.

        Works the same as the Python builtin dict object.

        :param key: Dict key for this Variable
        :param default: Default value to set and return if the variable
            isn't already in the DB
        :param description: Description to store if the variable is created
        :param deserialize_json: Store this as a JSON encoded value in the DB
            and un-encode it when retrieving a value
        :return: Mixed
        :raises ValueError: if the variable does not exist and ``default`` is None
        """
        obj = Variable.get(key, default_var=None, deserialize_json=deserialize_json)
        if obj is None:
            if default is not None:
                Variable.set(key=key, value=default, description=description, serialize_json=deserialize_json)
                return default
            raise ValueError("Default Value must be set")
        return obj

    @classmethod
    def get(
        cls,
        key: str,
        default_var: Any = __NO_DEFAULT_SENTINEL,
        deserialize_json: bool = False,
        team_name: str | None = None,
    ) -> Any:
        """
        Get a value for an Airflow Variable Key.

        :param key: Variable Key
        :param default_var: Default value of the Variable if the Variable doesn't exist
        :param deserialize_json: Deserialize the value to a Python dict
        :param team_name: Team name associated to the task trying to access the variable (if any)
        :raises KeyError: if the variable does not exist and no default was given
        :raises ValueError: if ``team_name`` is given while multi-team mode is disabled
        """
        # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
        # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
        # back-compat layer

        # In an execution context (Task, Dag Parse or Triggerer perhaps) use the Task SDK API server path
        if cls._in_execution_context():
            warnings.warn(
                "Using Variable.get from `airflow.models` is deprecated. "
                "Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead",
                DeprecationWarning,
                # Point the warning at the caller rather than at this module.
                stacklevel=2,
            )
            from airflow.sdk import Variable as TaskSDKVariable

            default_kwargs = {} if default_var is cls.__NO_DEFAULT_SENTINEL else {"default": default_var}
            var_val = TaskSDKVariable.get(key, deserialize_json=deserialize_json, **default_kwargs)
            if isinstance(var_val, str):
                mask_secret(var_val, key)

            return var_val

        if team_name and not conf.getboolean("core", "multi_team"):
            raise ValueError(
                "Multi-team mode is not configured in the Airflow environment but the task trying to access the variable belongs to a team"
            )

        var_val = Variable.get_variable_from_secrets(key=key, team_name=team_name)
        if var_val is None:
            if default_var is not cls.__NO_DEFAULT_SENTINEL:
                return default_var
            raise KeyError(f"Variable {key} does not exist")
        if deserialize_json:
            obj = json.loads(var_val)
            mask_secret(obj, key)
            return obj
        mask_secret(var_val, key)
        return var_val

    @staticmethod
    def set(
        key: str,
        value: Any,
        description: str | None = None,
        serialize_json: bool = False,
        team_name: str | None = None,
        session: Session | None = None,
    ) -> None:
        """
        Set a value for an Airflow Variable with a given Key.

        This operation overwrites an existing variable using the session's dialect-specific upsert operation.

        :param key: Variable Key
        :param value: Value to set for the Variable
        :param description: Description of the Variable
        :param serialize_json: Serialize the value to a JSON string
        :param team_name: Team name associated to the variable (if any)
        :param session: optional session, use if provided or create a new one
        :raises ValueError: if ``team_name`` is given while multi-team mode is disabled
        """
        # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
        # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
        # back-compat layer

        # In an execution context (Task, Dag Parse or Triggerer perhaps) use the Task SDK API server path
        if Variable._in_execution_context():
            warnings.warn(
                "Using Variable.set from `airflow.models` is deprecated. "
                "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead",
                DeprecationWarning,
                stacklevel=2,
            )
            from airflow.sdk import Variable as TaskSDKVariable

            TaskSDKVariable.set(
                key=key,
                value=value,
                description=description,
                serialize_json=serialize_json,
            )
            return

        if team_name and not conf.getboolean("core", "multi_team"):
            raise ValueError(
                "Multi-team mode is not configured in the Airflow environment. To assign a team to a variable, multi-mode must be enabled."
            )

        # check if the secret exists in the custom secrets' backend.
        Variable.check_for_write_conflict(key=key)
        stored_value = json.dumps(value, indent=2) if serialize_json else str(value)

        with Variable._session_scope(session) as session:
            # Transient instance used only to run the value through ``set_val`` (Fernet encryption).
            new_variable = Variable(key=key, val=stored_value, description=description, team_name=team_name)
            # Keyed by DB column name; used both for the INSERT and for the conflict UPDATE.
            columns = dict(
                val=new_variable._val,
                description=description,
                is_encrypted=new_variable.is_encrypted,
                team_name=team_name,
            )

            # Create dialect-specific upsert statement, conflicting on the unique ``key`` column
            dialect_name = get_dialect_name(session)
            stmt: MySQLInsert | PostgreSQLInsert | SQLiteInsert

            if dialect_name == "postgresql":
                from sqlalchemy.dialects.postgresql import insert as pg_insert

                stmt = (
                    pg_insert(Variable)
                    .values(key=key, **columns)
                    .on_conflict_do_update(index_elements=["key"], set_=dict(columns))
                )
            elif dialect_name == "mysql":
                from sqlalchemy.dialects.mysql import insert as mysql_insert

                stmt = mysql_insert(Variable).values(key=key, **columns).on_duplicate_key_update(**columns)
            else:
                from sqlalchemy.dialects.sqlite import insert as sqlite_insert

                stmt = (
                    sqlite_insert(Variable)
                    .values(key=key, **columns)
                    .on_conflict_do_update(index_elements=["key"], set_=dict(columns))
                )

            session.execute(stmt)
            # invalidate key in cache for faster propagation
            # we cannot save the value set because it's possible that it's shadowed by a custom backend
            # (see call to check_for_write_conflict above)
            SecretCache.invalidate_variable(key)

    @staticmethod
    def update(
        key: str,
        value: Any,
        serialize_json: bool = False,
        team_name: str | None = None,
        session: Session | None = None,
    ) -> None:
        """
        Update a given Airflow Variable with the Provided value.

        The existing description and team association of the row are preserved.

        :param key: Variable Key
        :param value: Value to set for the Variable
        :param serialize_json: Serialize the value to a JSON string
        :param team_name: Team name associated to the variable (if any)
        :param session: optional session, use if provided or create a new one
        :raises KeyError: if no secrets backend knows the variable
        :raises AttributeError: if the variable is not stored in the database
        :raises ValueError: if ``team_name`` is given while multi-team mode is disabled
        """
        # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
        # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
        # back-compat layer

        # In an execution context (Task, Dag Parse or Triggerer perhaps) use the Task SDK API server path
        if Variable._in_execution_context():
            warnings.warn(
                "Using Variable.update from `airflow.models` is deprecated. "
                "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.",
                DeprecationWarning,
                stacklevel=2,
            )
            from airflow.sdk import Variable as TaskSDKVariable

            # set is an upsert command, it can handle updates too
            TaskSDKVariable.set(
                key=key,
                value=value,
                serialize_json=serialize_json,
            )
            return

        if team_name and not conf.getboolean("core", "multi_team"):
            raise ValueError(
                "Multi-team mode is not configured in the Airflow environment. To assign a team to a variable, multi-mode must be enabled."
            )

        Variable.check_for_write_conflict(key=key)

        if Variable.get_variable_from_secrets(key=key, team_name=team_name) is None:
            raise KeyError(f"Variable {key} does not exist")

        with Variable._session_scope(session) as session:
            obj = session.scalar(
                select(Variable).where(
                    Variable.key == key, or_(Variable.team_name == team_name, Variable.team_name.is_(None))
                )
            )
            if obj is None:
                raise AttributeError(f"Variable {key} does not exist in the Database and cannot be updated.")

            Variable.set(
                key=key,
                value=value,
                description=obj.description,
                serialize_json=serialize_json,
                # ``set`` upserts team_name too; pass the row's current team so it is not reset to NULL.
                team_name=obj.team_name,
                session=session,
            )

    @staticmethod
    def delete(key: str, team_name: str | None = None, session: Session | None = None) -> int:
        """
        Delete an Airflow Variable for a given key.

        :param key: Variable Keys
        :param team_name: Team name associated to the task trying to delete the variable (if any)
        :param session: optional session, use if provided or create a new one
        :return: number of deleted rows (always 1 on the Task SDK path)
        :raises ValueError: if ``team_name`` is given while multi-team mode is disabled
        """
        # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
        # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
        # back-compat layer

        # In an execution context (Task, Dag Parse or Triggerer perhaps) use the Task SDK API server path
        if Variable._in_execution_context():
            warnings.warn(
                "Using Variable.delete from `airflow.models` is deprecated. "
                "Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead",
                DeprecationWarning,
                stacklevel=2,
            )
            from airflow.sdk import Variable as TaskSDKVariable

            TaskSDKVariable.delete(
                key=key,
            )
            return 1

        if team_name and not conf.getboolean("core", "multi_team"):
            raise ValueError(
                "Multi-team mode is not configured in the Airflow environment but the task trying to delete the variable belongs to a team"
            )

        with Variable._session_scope(session) as session:
            result = session.execute(
                delete(Variable).where(
                    Variable.key == key, or_(Variable.team_name == team_name, Variable.team_name.is_(None))
                )
            )
            rows = getattr(result, "rowcount", 0) or 0
            SecretCache.invalidate_variable(key)
        return rows

    def rotate_fernet_key(self):
        """Rotate Fernet Key."""
        fernet = get_fernet()
        if self._val and self.is_encrypted:
            self._val = fernet.rotate(self._val.encode("utf-8")).decode()

    @staticmethod
    def check_for_write_conflict(key: str) -> None:
        """
        Log a warning if a variable exists outside the metastore.

        If we try to write a variable to the metastore while the same key
        exists in an environment variable or custom secrets backend, then
        subsequent reads will not read the set value.

        :param key: Variable Key
        """
        for secrets_backend in ensure_secrets_loaded():
            if isinstance(secrets_backend, MetastoreBackend):
                continue
            try:
                var_val = secrets_backend.get_variable(key=key)
                if var_val is not None:
                    _backend_name = type(secrets_backend).__name__
                    log.warning(
                        "The variable %s is defined in the %s secrets backend, which takes "
                        "precedence over reading from the database. The value in the database will be "
                        "updated, but to read it you have to delete the conflicting variable "
                        "from %s",
                        key,
                        _backend_name,
                        _backend_name,
                    )
                    return
            except Exception:
                log.exception(
                    "Unable to retrieve variable from secrets backend (%s). "
                    "Checking subsequent secrets backend.",
                    type(secrets_backend).__name__,
                )
        return None

    @staticmethod
    def get_variable_from_secrets(key: str, team_name: str | None = None) -> str | None:
        """
        Get Airflow Variable by iterating over all Secret Backends.

        The cache is keyed by ``key`` only, so it is neither read nor written for team-scoped lookups.

        :param key: Variable Key
        :param team_name: Team name associated to the task trying to access the variable (if any)
        :return: Variable Value
        """
        # Disable cache if the variable belongs to a team. We might enable it later
        use_cache = not team_name
        if use_cache:
            # check cache first
            # enabled only if SecretCache.init() has been called first
            try:
                return SecretCache.get_variable(key)
            except SecretCache.NotPresentException:
                pass  # continue business

        var_val = None
        # iterate over backends if not in cache (or expired)
        for secrets_backend in ensure_secrets_loaded():
            try:
                var_val = secrets_backend.get_variable(key=key, team_name=team_name)
                if var_val is not None:
                    break
            except Exception:
                log.exception(
                    "Unable to retrieve variable from secrets backend (%s). "
                    "Checking subsequent secrets backend.",
                    type(secrets_backend).__name__,
                )

        if use_cache:
            # A team-scoped result must not land in the key-only cache, where later
            # non-team lookups of the same key would read it.
            SecretCache.save_variable(key, var_val)  # we save None as well
        return var_val

    @staticmethod
    @provide_session
    def get_team_name(variable_key: str, session=NEW_SESSION) -> str | None:
        """Return the team stored for ``variable_key`` in the database, or None if unset or missing."""
        stmt = select(Variable.team_name).where(Variable.key == variable_key)
        return session.scalar(stmt)

    @staticmethod
    @provide_session
    def get_key_to_team_name_mapping(variable_keys: list[str], session=NEW_SESSION) -> dict[str, str | None]:
        """
        Map each stored variable key in ``variable_keys`` to its team name.

        Keys not present in the database are absent from the result.
        """
        if not variable_keys:
            # Nothing to look up; skip the database round trip.
            return {}
        stmt = select(Variable.key, Variable.team_name).where(Variable.key.in_(variable_keys))
        return {key: team_name for key, team_name in session.execute(stmt)}