Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/models/variable.py: 27%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

217 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import contextlib 

21import json 

22import logging 

23import sys 

24import warnings 

25from typing import TYPE_CHECKING, Any 

26 

27from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, delete, or_, select 

28from sqlalchemy.dialects.mysql import MEDIUMTEXT 

29from sqlalchemy.orm import Mapped, declared_attr, reconstructor, synonym 

30 

31from airflow._shared.secrets_masker import mask_secret 

32from airflow.configuration import conf, ensure_secrets_loaded 

33from airflow.models.base import ID_LEN, Base 

34from airflow.models.crypto import get_fernet 

35from airflow.secrets.metastore import MetastoreBackend 

36from airflow.utils.log.logging_mixin import LoggingMixin 

37from airflow.utils.session import NEW_SESSION, create_session, provide_session 

38from airflow.utils.sqlalchemy import get_dialect_name, mapped_column 

39 

40if TYPE_CHECKING: 

41 from sqlalchemy.dialects.mysql.dml import Insert as MySQLInsert 

42 from sqlalchemy.dialects.postgresql.dml import Insert as PostgreSQLInsert 

43 from sqlalchemy.dialects.sqlite.dml import Insert as SQLiteInsert 

44 from sqlalchemy.orm import Session 

45 

46log = logging.getLogger(__name__) 

47 

48 

class Variable(Base, LoggingMixin):
    """A generic way to store and retrieve arbitrary content or settings as a simple key/value store."""

    __tablename__ = "variable"
    # Sentinel used as the default of ``get(default_var=...)`` so that an explicit
    # ``None`` default can be told apart from "no default supplied".
    __NO_DEFAULT_SENTINEL = object()

    # Integer primary key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Variable name; declared unique.
    key: Mapped[str] = mapped_column(String(ID_LEN), unique=True)
    # Stored value, persisted in column "val" (MEDIUMTEXT on MySQL, Text elsewhere).
    # Read and written through the ``val`` synonym, which goes through get_val/set_val.
    _val: Mapped[str] = mapped_column("val", Text().with_variant(MEDIUMTEXT, "mysql"))
    # Optional free-text description.
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Copied from ``fernet.is_encrypted`` in ``set_val``; ``get_val`` only decrypts when it is true.
    is_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False)
    # Optional team reference (foreign key to ``team.name``, ON DELETE SET NULL).
    team_name: Mapped[str | None] = mapped_column(
        String(50),
        ForeignKey("team.name", ondelete="SET NULL"),
        nullable=True,
    )

65 

    def __init__(self, key=None, val=None, description=None, team_name=None):
        """
        Build a Variable instance.

        :param key: Variable Key
        :param val: Value of the Variable; assigned through the ``val`` property, so it is
            passed to ``set_val``
        :param description: Description of the Variable
        :param team_name: Team name associated to the variable (if any)
        """
        super().__init__()
        self.key = key
        self.val = val
        self.description = description
        self.team_name = team_name

72 

73 @reconstructor 

74 def on_db_load(self): 

75 if self._val: 

76 mask_secret(self.val, self.key) 

77 

78 def __repr__(self): 

79 # Hiding the value 

80 return f"{self.key} : {self._val}" 

81 

82 def get_val(self): 

83 """Get Airflow Variable from Metadata DB and decode it using the Fernet Key.""" 

84 from cryptography.fernet import InvalidToken as InvalidFernetToken 

85 

86 if self._val is not None and self.is_encrypted: 

87 try: 

88 fernet = get_fernet() 

89 return fernet.decrypt(bytes(self._val, "utf-8")).decode() 

90 except InvalidFernetToken: 

91 self.log.error("Can't decrypt _val for key=%s, invalid token or value", self.key) 

92 return None 

93 except Exception: 

94 self.log.error("Can't decrypt _val for key=%s, FERNET_KEY configuration missing", self.key) 

95 return None 

96 else: 

97 return self._val 

98 

99 def set_val(self, value): 

100 """Encode the specified value with Fernet Key and store it in Variables Table.""" 

101 if value is not None: 

102 fernet = get_fernet() 

103 self._val = fernet.encrypt(bytes(value, "utf-8")).decode() 

104 self.is_encrypted = fernet.is_encrypted 

105 

106 @declared_attr 

107 def val(cls): 

108 """Get Airflow Variable from Metadata DB and decode it using the Fernet Key.""" 

109 return synonym("_val", descriptor=property(cls.get_val, cls.set_val)) 

110 

111 @classmethod 

112 def setdefault(cls, key, default, description=None, deserialize_json=False): 

113 """ 

114 Return the current value for a key or store the default value and return it. 

115 

116 Works the same as the Python builtin dict object. 

117 

118 :param key: Dict key for this Variable 

119 :param default: Default value to set and return if the variable 

120 isn't already in the DB 

121 :param description: Default value to set Description of the Variable 

122 :param deserialize_json: Store this as a JSON encoded value in the DB 

123 and un-encode it when retrieving a value 

124 :param session: Session 

125 :return: Mixed 

126 """ 

127 obj = Variable.get(key, default_var=None, deserialize_json=deserialize_json) 

128 if obj is None: 

129 if default is not None: 

130 Variable.set(key=key, value=default, description=description, serialize_json=deserialize_json) 

131 return default 

132 raise ValueError("Default Value must be set") 

133 return obj 

134 

135 @classmethod 

136 def get( 

137 cls, 

138 key: str, 

139 default_var: Any = __NO_DEFAULT_SENTINEL, 

140 deserialize_json: bool = False, 

141 team_name: str | None = None, 

142 ) -> Any: 

143 """ 

144 Get a value for an Airflow Variable Key. 

145 

146 :param key: Variable Key 

147 :param default_var: Default value of the Variable if the Variable doesn't exist 

148 :param deserialize_json: Deserialize the value to a Python dict 

149 :param team_name: Team name associated to the task trying to access the variable (if any) 

150 """ 

151 # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still 

152 # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big 

153 # back-compat layer 

154 

155 # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) 

156 # and should use the Task SDK API server path 

157 if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): 

158 warnings.warn( 

159 "Using Variable.get from `airflow.models` is deprecated." 

160 "Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead", 

161 DeprecationWarning, 

162 stacklevel=1, 

163 ) 

164 from airflow.sdk import Variable as TaskSDKVariable 

165 

166 default_kwargs = {} if default_var is cls.__NO_DEFAULT_SENTINEL else {"default": default_var} 

167 var_val = TaskSDKVariable.get(key, deserialize_json=deserialize_json, **default_kwargs) 

168 if isinstance(var_val, str): 

169 mask_secret(var_val, key) 

170 

171 return var_val 

172 

173 if team_name and not conf.getboolean("core", "multi_team"): 

174 raise ValueError( 

175 "Multi-team mode is not configured in the Airflow environment but the task trying to access the variable belongs to a team" 

176 ) 

177 

178 var_val = Variable.get_variable_from_secrets(key=key, team_name=team_name) 

179 if var_val is None: 

180 if default_var is not cls.__NO_DEFAULT_SENTINEL: 

181 return default_var 

182 raise KeyError(f"Variable {key} does not exist") 

183 if deserialize_json: 

184 obj = json.loads(var_val) 

185 mask_secret(obj, key) 

186 return obj 

187 mask_secret(var_val, key) 

188 return var_val 

189 

190 @staticmethod 

191 def set( 

192 key: str, 

193 value: Any, 

194 description: str | None = None, 

195 serialize_json: bool = False, 

196 team_name: str | None = None, 

197 session: Session | None = None, 

198 ) -> None: 

199 """ 

200 Set a value for an Airflow Variable with a given Key. 

201 

202 This operation overwrites an existing variable using the session's dialect-specific upsert operation. 

203 

204 :param key: Variable Key 

205 :param value: Value to set for the Variable 

206 :param description: Description of the Variable 

207 :param serialize_json: Serialize the value to a JSON string 

208 :param team_name: Team name associated to the variable (if any) 

209 :param session: optional session, use if provided or create a new one 

210 """ 

211 # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still 

212 # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big 

213 # back-compat layer 

214 

215 # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) 

216 # and should use the Task SDK API server path 

217 if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): 

218 warnings.warn( 

219 "Using Variable.set from `airflow.models` is deprecated." 

220 "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead", 

221 DeprecationWarning, 

222 stacklevel=1, 

223 ) 

224 from airflow.sdk import Variable as TaskSDKVariable 

225 

226 TaskSDKVariable.set( 

227 key=key, 

228 value=value, 

229 description=description, 

230 serialize_json=serialize_json, 

231 ) 

232 return 

233 

234 if team_name and not conf.getboolean("core", "multi_team"): 

235 raise ValueError( 

236 "Multi-team mode is not configured in the Airflow environment. To assign a team to a variable, multi-mode must be enabled." 

237 ) 

238 

239 # check if the secret exists in the custom secrets' backend. 

240 from airflow.sdk import SecretCache 

241 

242 Variable.check_for_write_conflict(key=key) 

243 if serialize_json: 

244 stored_value = json.dumps(value, indent=2) 

245 else: 

246 stored_value = str(value) 

247 

248 ctx: contextlib.AbstractContextManager 

249 if session is not None: 

250 ctx = contextlib.nullcontext(session) 

251 else: 

252 ctx = create_session() 

253 

254 with ctx as session: 

255 new_variable = Variable(key=key, val=stored_value, description=description, team_name=team_name) 

256 

257 val = new_variable._val 

258 is_encrypted = new_variable.is_encrypted 

259 

260 # Create dialect-specific upsert statement 

261 dialect_name = get_dialect_name(session) 

262 stmt: MySQLInsert | PostgreSQLInsert | SQLiteInsert 

263 

264 if dialect_name == "postgresql": 

265 from sqlalchemy.dialects.postgresql import insert as pg_insert 

266 

267 pg_stmt = pg_insert(Variable).values( 

268 key=key, 

269 val=val, 

270 description=description, 

271 is_encrypted=is_encrypted, 

272 team_name=team_name, 

273 ) 

274 stmt = pg_stmt.on_conflict_do_update( 

275 index_elements=["key"], 

276 set_=dict( 

277 val=val, 

278 description=description, 

279 is_encrypted=is_encrypted, 

280 team_name=team_name, 

281 ), 

282 ) 

283 elif dialect_name == "mysql": 

284 from sqlalchemy.dialects.mysql import insert as mysql_insert 

285 

286 mysql_stmt = mysql_insert(Variable).values( 

287 key=key, 

288 val=val, 

289 description=description, 

290 is_encrypted=is_encrypted, 

291 team_name=team_name, 

292 ) 

293 stmt = mysql_stmt.on_duplicate_key_update( 

294 val=val, 

295 description=description, 

296 is_encrypted=is_encrypted, 

297 team_name=team_name, 

298 ) 

299 else: 

300 from sqlalchemy.dialects.sqlite import insert as sqlite_insert 

301 

302 sqlite_stmt = sqlite_insert(Variable).values( 

303 key=key, 

304 val=val, 

305 description=description, 

306 is_encrypted=is_encrypted, 

307 team_name=team_name, 

308 ) 

309 stmt = sqlite_stmt.on_conflict_do_update( 

310 index_elements=["key"], 

311 set_=dict( 

312 val=val, 

313 description=description, 

314 is_encrypted=is_encrypted, 

315 team_name=team_name, 

316 ), 

317 ) 

318 

319 session.execute(stmt) 

320 # invalidate key in cache for faster propagation 

321 # we cannot save the value set because it's possible that it's shadowed by a custom backend 

322 # (see call to check_for_write_conflict above) 

323 SecretCache.invalidate_variable(key) 

324 

325 @staticmethod 

326 def update( 

327 key: str, 

328 value: Any, 

329 serialize_json: bool = False, 

330 team_name: str | None = None, 

331 session: Session | None = None, 

332 ) -> None: 

333 """ 

334 Update a given Airflow Variable with the Provided value. 

335 

336 :param key: Variable Key 

337 :param value: Value to set for the Variable 

338 :param serialize_json: Serialize the value to a JSON string 

339 :param team_name: Team name associated to the variable (if any) 

340 :param session: optional session, use if provided or create a new one 

341 """ 

342 # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still 

343 # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big 

344 # back-compat layer 

345 

346 # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) 

347 # and should use the Task SDK API server path 

348 if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): 

349 warnings.warn( 

350 "Using Variable.update from `airflow.models` is deprecated." 

351 "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.", 

352 DeprecationWarning, 

353 stacklevel=1, 

354 ) 

355 from airflow.sdk import Variable as TaskSDKVariable 

356 

357 # set is an upsert command, it can handle updates too 

358 TaskSDKVariable.set( 

359 key=key, 

360 value=value, 

361 serialize_json=serialize_json, 

362 ) 

363 return 

364 

365 if team_name and not conf.getboolean("core", "multi_team"): 

366 raise ValueError( 

367 "Multi-team mode is not configured in the Airflow environment. To assign a team to a variable, multi-mode must be enabled." 

368 ) 

369 

370 Variable.check_for_write_conflict(key=key) 

371 

372 if Variable.get_variable_from_secrets(key=key, team_name=team_name) is None: 

373 raise KeyError(f"Variable {key} does not exist") 

374 

375 ctx: contextlib.AbstractContextManager 

376 if session is not None: 

377 ctx = contextlib.nullcontext(session) 

378 else: 

379 ctx = create_session() 

380 

381 with ctx as session: 

382 obj = session.scalar( 

383 select(Variable).where( 

384 Variable.key == key, or_(Variable.team_name == team_name, Variable.team_name.is_(None)) 

385 ) 

386 ) 

387 if obj is None: 

388 raise AttributeError(f"Variable {key} does not exist in the Database and cannot be updated.") 

389 

390 Variable.set( 

391 key=key, 

392 value=value, 

393 description=obj.description, 

394 serialize_json=serialize_json, 

395 session=session, 

396 ) 

397 

398 @staticmethod 

399 def delete(key: str, team_name: str | None = None, session: Session | None = None) -> int: 

400 """ 

401 Delete an Airflow Variable for a given key. 

402 

403 :param key: Variable Keys 

404 :param team_name: Team name associated to the task trying to delete the variable (if any) 

405 :param session: optional session, use if provided or create a new one 

406 """ 

407 # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still 

408 # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big 

409 # back-compat layer 

410 

411 # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) 

412 # and should use the Task SDK API server path 

413 if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): 

414 warnings.warn( 

415 "Using Variable.delete from `airflow.models` is deprecated." 

416 "Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead", 

417 DeprecationWarning, 

418 stacklevel=1, 

419 ) 

420 from airflow.sdk import Variable as TaskSDKVariable 

421 

422 TaskSDKVariable.delete( 

423 key=key, 

424 ) 

425 return 1 

426 

427 if team_name and not conf.getboolean("core", "multi_team"): 

428 raise ValueError( 

429 "Multi-team mode is not configured in the Airflow environment but the task trying to delete the variable belongs to a team" 

430 ) 

431 

432 from airflow.sdk import SecretCache 

433 

434 ctx: contextlib.AbstractContextManager 

435 if session is not None: 

436 ctx = contextlib.nullcontext(session) 

437 else: 

438 ctx = create_session() 

439 

440 with ctx as session: 

441 result = session.execute( 

442 delete(Variable).where( 

443 Variable.key == key, or_(Variable.team_name == team_name, Variable.team_name.is_(None)) 

444 ) 

445 ) 

446 rows = getattr(result, "rowcount", 0) or 0 

447 SecretCache.invalidate_variable(key) 

448 return rows 

449 

450 def rotate_fernet_key(self): 

451 """Rotate Fernet Key.""" 

452 fernet = get_fernet() 

453 if self._val and self.is_encrypted: 

454 self._val = fernet.rotate(self._val.encode("utf-8")).decode() 

455 

456 @staticmethod 

457 def check_for_write_conflict(key: str) -> None: 

458 """ 

459 Log a warning if a variable exists outside the metastore. 

460 

461 If we try to write a variable to the metastore while the same key 

462 exists in an environment variable or custom secrets backend, then 

463 subsequent reads will not read the set value. 

464 

465 :param key: Variable Key 

466 """ 

467 for secrets_backend in ensure_secrets_loaded(): 

468 if not isinstance(secrets_backend, MetastoreBackend): 

469 try: 

470 var_val = secrets_backend.get_variable(key=key) 

471 if var_val is not None: 

472 _backend_name = type(secrets_backend).__name__ 

473 log.warning( 

474 "The variable %s is defined in the %s secrets backend, which takes " 

475 "precedence over reading from the database. The value in the database will be " 

476 "updated, but to read it you have to delete the conflicting variable " 

477 "from %s", 

478 key, 

479 _backend_name, 

480 _backend_name, 

481 ) 

482 return 

483 except Exception: 

484 log.exception( 

485 "Unable to retrieve variable from secrets backend (%s). " 

486 "Checking subsequent secrets backend.", 

487 type(secrets_backend).__name__, 

488 ) 

489 return None 

490 

491 @staticmethod 

492 def get_variable_from_secrets(key: str, team_name: str | None = None) -> str | None: 

493 """ 

494 Get Airflow Variable by iterating over all Secret Backends. 

495 

496 :param key: Variable Key 

497 :param team_name: Team name associated to the task trying to access the variable (if any) 

498 :return: Variable Value 

499 """ 

500 from airflow.sdk import SecretCache 

501 

502 # Disable cache if the variable belongs to a team. We might enable it later 

503 if not team_name: 

504 # check cache first 

505 # enabled only if SecretCache.init() has been called first 

506 try: 

507 return SecretCache.get_variable(key) 

508 except SecretCache.NotPresentException: 

509 pass # continue business 

510 

511 var_val = None 

512 # iterate over backends if not in cache (or expired) 

513 for secrets_backend in ensure_secrets_loaded(): 

514 try: 

515 var_val = secrets_backend.get_variable(key=key, team_name=team_name) 

516 if var_val is not None: 

517 break 

518 except Exception: 

519 log.exception( 

520 "Unable to retrieve variable from secrets backend (%s). " 

521 "Checking subsequent secrets backend.", 

522 type(secrets_backend).__name__, 

523 ) 

524 

525 SecretCache.save_variable(key, var_val) # we save None as well 

526 return var_val 

527 

528 @staticmethod 

529 @provide_session 

530 def get_team_name(variable_key: str, session=NEW_SESSION) -> str | None: 

531 stmt = select(Variable.team_name).where(Variable.key == variable_key) 

532 return session.scalar(stmt) 

533 

534 @staticmethod 

535 @provide_session 

536 def get_key_to_team_name_mapping(variable_keys: list[str], session=NEW_SESSION) -> dict[str, str | None]: 

537 stmt = select(Variable.key, Variable.team_name).where(Variable.key.in_(variable_keys)) 

538 return {key: team_name for key, team_name in session.execute(stmt)}