Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/utils/sqlalchemy.py: 34% (209 statements)


#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import contextlib
import copy
import datetime
import json
import logging
from importlib import metadata
from typing import TYPE_CHECKING, Any, Generator, Iterable, overload

from dateutil import relativedelta
from packaging import version
from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst, tuple_
from sqlalchemy.dialects import mysql
from sqlalchemy.types import JSON, Text, TypeDecorator

from airflow.configuration import conf
from airflow.serialization.enums import Encoding
from airflow.utils.timezone import make_naive, utc

if TYPE_CHECKING:
    from kubernetes.client.models.v1_pod import V1Pod
    from sqlalchemy.exc import OperationalError
    from sqlalchemy.orm import Query, Session
    from sqlalchemy.sql import ColumnElement, Select
    from sqlalchemy.sql.expression import ColumnOperators
    from sqlalchemy.types import TypeEngine

log = logging.getLogger(__name__)


class UtcDateTime(TypeDecorator):
    """
    Similar to :class:`~sqlalchemy.types.TIMESTAMP` with ``timezone=True``, with some differences.

    - It never silently accepts a naive :class:`~datetime.datetime`; it
      always raises :exc:`ValueError` unless the value is time zone aware.
    - A :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
      is always converted to UTC.
    - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`, it
      never returns a naive :class:`~datetime.datetime`, but always a time
      zone aware value, even with SQLite or MySQL.
    - Always returns TIMESTAMP in UTC.
    """

    impl = TIMESTAMP(timezone=True)

    cache_ok = True

    def process_bind_param(self, value, dialect):
        if not isinstance(value, datetime.datetime):
            if value is None:
                return None
            raise TypeError(f"expected datetime.datetime, not {value!r}")
        elif value.tzinfo is None:
            raise ValueError("naive datetime is disallowed")
        elif dialect.name == "mysql":
            # For MySQL versions prior to 8.0.19 we should send timestamps as naive values in UTC.
            # See: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-literals.html
            return make_naive(value, timezone=utc)
        return value.astimezone(utc)

    def process_result_value(self, value, dialect):
        """
        Process datetimes from the DB, making sure to always return UTC.

        Not using timezone.convert_to_utc as that converts to the configured
        TIMEZONE while the DB might be running with some other setting. We
        assume UTC datetimes in the database.
        """
        if value is not None:
            if value.tzinfo is None:
                value = value.replace(tzinfo=utc)
            else:
                value = value.astimezone(utc)

        return value

    def load_dialect_impl(self, dialect):
        if dialect.name == "mysql":
            return mysql.TIMESTAMP(fsp=6)
        return super().load_dialect_impl(dialect)
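

# A minimal usage sketch (hypothetical, not part of the original module). A
# stand-in dialect object suffices because process_bind_param only inspects
# ``dialect.name``.
def _example_utc_datetime_bind():
    class _FakeDialect:
        name = "postgresql"

    col_type = UtcDateTime()
    aware = datetime.datetime(2024, 1, 1, 12, 0, tzinfo=datetime.timezone.utc)
    # Aware datetimes pass through, normalized to UTC.
    assert col_type.process_bind_param(aware, _FakeDialect()) == aware
    try:
        col_type.process_bind_param(datetime.datetime(2024, 1, 1, 12, 0), _FakeDialect())
    except ValueError:
        pass  # naive datetimes are rejected by design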


class ExtendedJSON(TypeDecorator):
    """
    A version of the JSON column that uses the Airflow extended JSON serialization.

    See airflow.serialization.
    """

    impl = Text

    cache_ok = True

    def load_dialect_impl(self, dialect) -> TypeEngine:
        return dialect.type_descriptor(JSON)

    def process_bind_param(self, value, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        if value is None:
            return None

        return BaseSerialization.serialize(value)

    def process_result_value(self, value, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        if value is None:
            return None

        return BaseSerialization.deserialize(value)
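

# A minimal column-definition sketch (hypothetical model, assuming SQLAlchemy
# 1.4+ and a working Airflow environment for BaseSerialization):
def _example_extended_json_column():
    from sqlalchemy import Column, Integer
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class ExampleRecord(Base):  # illustrative only, not an Airflow model
        __tablename__ = "example_record"
        id = Column(Integer, primary_key=True)
        payload = Column(ExtendedJSON)  # round-trips Airflow-serializable values

    return ExampleRecord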


def sanitize_for_serialization(obj: V1Pod):
    """
    Convert a pod to a dict... but *safely*.

    When pod objects created with one k8s version are unpickled in a Python
    env with a more recent k8s version (in which the object attrs may have
    changed), the unpickled obj may throw an error because an attr expected
    on the new obj may not be there on the unpickled obj.

    This function still converts the pod to a dict; the only difference is
    that it populates missing attrs with None. You may compare with
    https://github.com/kubernetes-client/python/blob/5a96bbcbe21a552cc1f9cda13e0522fafb0dbac8/kubernetes/client/api_client.py#L202

    If obj is None, return None.
    If obj is str, int, float, bool, or bytes, return it directly.
    If obj is datetime.datetime or datetime.date, convert it to a string in
    ISO 8601 format.
    If obj is a list or tuple, sanitize each element.
    If obj is a dict, sanitize each value.
    If obj is an OpenAPI model, return its sanitized properties dict.

    :param obj: The data to serialize.
    :return: The serialized form of data.

    :meta private:
    """
    if obj is None:
        return None
    elif isinstance(obj, (float, bool, bytes, str, int)):
        return obj
    elif isinstance(obj, list):
        return [sanitize_for_serialization(sub_obj) for sub_obj in obj]
    elif isinstance(obj, tuple):
        return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj)
    elif isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()

    if isinstance(obj, dict):
        obj_dict = obj
    else:
        obj_dict = {
            obj.attribute_map[attr]: getattr(obj, attr)
            for attr, _ in obj.openapi_types.items()
            # below is the only line we change from upstream: we add default=None for getattr
            if getattr(obj, attr, None) is not None
        }

    return {key: sanitize_for_serialization(val) for key, val in obj_dict.items()}
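

# A runnable sketch with plain data (no kubernetes objects are needed for the
# scalar/list/dict/datetime branches):
def _example_sanitize_for_serialization():
    data = {
        "name": "base",
        "started": datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
        "ports": [80, 443],
    }
    sanitized = sanitize_for_serialization(data)
    assert sanitized["started"] == "2024-01-01T00:00:00+00:00"
    assert sanitized["ports"] == [80, 443]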


def ensure_pod_is_valid_after_unpickling(pod: V1Pod) -> V1Pod | None:
    """
    Convert a pod to JSON and back so that the pod is safe to use.

    The pod_override in executor_config is a V1Pod object.
    Such objects, created with one k8s version and unpickled in an env with
    an upgraded k8s version, may blow up when ``to_dict`` is called, because
    the openapi client codegen calls getattr for every attr in openapi_types
    on each object, and when new attrs have been added to that list, getattr
    will fail.

    Here we re-serialize it to ensure it is not going to blow up.

    :meta private:
    """
    try:
        # if to_dict works, the pod is fine
        pod.to_dict()
        return pod
    except AttributeError:
        pass
    try:
        from kubernetes.client.models.v1_pod import V1Pod
    except ImportError:
        return None
    if not isinstance(pod, V1Pod):
        return None
    try:
        try:
            from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
        except ImportError:
            from airflow.kubernetes.pre_7_4_0_compatibility.pod_generator import (  # type: ignore[assignment]
                PodGenerator,
            )
        # now we actually reserialize / deserialize the pod
        pod_dict = sanitize_for_serialization(pod)
        return PodGenerator.deserialize_model_dict(pod_dict)
    except Exception:
        return None


class ExecutorConfigType(PickleType):
    """
    Adds special handling for K8s executor config.

    If we unpickle a k8s object that was pickled under an earlier k8s library version, then
    the unpickled object may throw an error when to_dict is called. To be more tolerant of
    version changes we convert to JSON using Airflow's serializer before pickling.
    """

    cache_ok = True

    def bind_processor(self, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().bind_processor(dialect)

        def process(value):
            val_copy = copy.copy(value)
            if isinstance(val_copy, dict) and "pod_override" in val_copy:
                val_copy["pod_override"] = BaseSerialization.serialize(val_copy["pod_override"])
            return super_process(val_copy)

        return process

    def result_processor(self, dialect, coltype):
        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().result_processor(dialect, coltype)

        def process(value):
            value = super_process(value)  # unpickle

            if isinstance(value, dict) and "pod_override" in value:
                pod_override = value["pod_override"]

                if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
                    # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
                    value["pod_override"] = BaseSerialization.deserialize(pod_override)
                else:
                    # Backcompat path: we no longer pickle raw pods, but this code
                    # may be reached when accessing executor configs created in a
                    # prior version.
                    new_pod = ensure_pod_is_valid_after_unpickling(pod_override)
                    if new_pod:
                        value["pod_override"] = new_pod
            return value

        return process

    def compare_values(self, x, y):
        """
        Compare x and y using self.comparator if available; else, use __eq__.

        The TaskInstance.executor_config attribute is a pickled object that may contain kubernetes objects.

        If the installed library version has changed since the object was originally pickled,
        due to the underlying ``__eq__`` method on these objects (which converts them to JSON),
        we may encounter attribute errors. In this case we should replace the stored object.

        Since https://github.com/apache/airflow/pull/24356 we use Airflow's serializer to store
        k8s objects, but there could still be raw pickled k8s objects in the database,
        stored from an earlier version, so we still compare them defensively here.
        """
        if self.comparator:
            return self.comparator(x, y)
        else:
            try:
                return x == y
            except AttributeError:
                return False
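

# A minimal column-definition sketch (hypothetical model; in Airflow itself,
# TaskInstance uses this type for its ``executor_config`` column):
def _example_executor_config_column():
    from sqlalchemy import Column, Integer
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class ExampleTask(Base):  # illustrative only
        __tablename__ = "example_task"
        id = Column(Integer, primary_key=True)
        # Any "pod_override" entry in the stored dict is serialized with
        # Airflow's serializer before the whole value is pickled.
        executor_config = Column(ExecutorConfigType())

    return ExampleTask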


class Interval(TypeDecorator):
    """Base class representing a time interval."""

    impl = Text

    cache_ok = True

    attr_keys = {
        datetime.timedelta: ("days", "seconds", "microseconds"),
        relativedelta.relativedelta: (
            "years",
            "months",
            "days",
            "leapdays",
            "hours",
            "minutes",
            "seconds",
            "microseconds",
            "year",
            "month",
            "day",
            "hour",
            "minute",
            "second",
            "microsecond",
        ),
    }

    def process_bind_param(self, value, dialect):
        if isinstance(value, tuple(self.attr_keys)):
            attrs = {key: getattr(value, key) for key in self.attr_keys[type(value)]}
            return json.dumps({"type": type(value).__name__, "attrs": attrs})
        return json.dumps(value)

    def process_result_value(self, value, dialect):
        if not value:
            return value
        data = json.loads(value)
        if isinstance(data, dict):
            type_map = {key.__name__: key for key in self.attr_keys}
            return type_map[data["type"]](**data["attrs"])
        return data
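

# A runnable round-trip sketch: Interval stores a timedelta (or relativedelta)
# as a small JSON document. The dialect argument is unused here, so None is
# passed for illustration.
def _example_interval_roundtrip():
    interval = Interval()
    encoded = interval.process_bind_param(datetime.timedelta(days=1, seconds=30), None)
    # encoded == '{"type": "timedelta", "attrs": {"days": 1, "seconds": 30, "microseconds": 0}}'
    decoded = interval.process_result_value(encoded, None)
    assert decoded == datetime.timedelta(days=1, seconds=30)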


def nulls_first(col, session: Session) -> ColumnElement:
    """Specify *NULLS FIRST* for the column ordering.

    This is only done for PostgreSQL, currently the only supported backend
    that needs it. Other databases do not need it since NULL values are
    considered lower than any other value, and so already appear first when
    the order is ASC (ascending).
    """
    if session.bind.dialect.name == "postgresql":
        return nullsfirst(col)
    else:
        return col
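

# A usage sketch (hypothetical model and ``priority`` column, for illustration
# only): order NULLs first on PostgreSQL, plain ascending order elsewhere.
def _example_nulls_first(session, model):
    return session.query(model).order_by(nulls_first(model.priority, session))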


USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)


def with_row_locks(
    query: Query,
    session: Session,
    *,
    nowait: bool = False,
    skip_locked: bool = False,
    **kwargs,
) -> Query:
    """
    Apply with_for_update to the SQLAlchemy query if row-level locking is in use.

    This wrapper is needed so we don't use the syntax on unsupported database
    engines. In particular, MySQL (prior to 8.0) and MariaDB do not support
    row locking, and on those we neither support nor recommend running an HA
    scheduler. If a user ignores this and tries anyway, everything will still
    work, just slightly slower in some circumstances.

    See https://jira.mariadb.org/browse/MDEV-13115

    :param query: An SQLAlchemy Query object
    :param session: ORM Session
    :param nowait: If set to True, will pass NOWAIT to supported database backends.
    :param skip_locked: If set to True, will pass SKIP LOCKED to supported database backends.
    :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc.)
    :return: updated query
    """
    dialect = session.bind.dialect

    # Don't use row-level locks if the MySQL dialect (MariaDB & MySQL < 8) does not support them.
    if not USE_ROW_LEVEL_LOCKING:
        return query
    if dialect.name == "mysql" and not dialect.supports_for_update_of:
        return query
    if nowait:
        kwargs["nowait"] = True
    if skip_locked:
        kwargs["skip_locked"] = True
    return query.with_for_update(**kwargs)
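

# A usage sketch (hypothetical query; not part of this module): lock the
# selected rows, skipping any already locked by another scheduler. On backends
# without FOR UPDATE support this degrades to a plain query.
def _example_with_row_locks(session, query):
    locked = with_row_locks(query, session, skip_locked=True)
    return locked.all()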


@contextlib.contextmanager
def lock_rows(query: Query, session: Session) -> Generator[None, None, None]:
    """Lock database rows during the context manager block.

    This is a convenient method for ``with_row_locks`` when we don't need the
    locked rows.

    :meta private:
    """
    locked_rows = with_row_locks(query, session)
    yield
    del locked_rows


class CommitProhibitorGuard:
    """Context manager class that powers prohibit_commit."""

    expected_commit = False

    def __init__(self, session: Session):
        self.session = session

    def _validate_commit(self, _):
        if self.expected_commit:
            self.expected_commit = False
            return
        raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!")

    def __enter__(self):
        event.listen(self.session, "before_commit", self._validate_commit)
        return self

    def __exit__(self, *exc_info):
        event.remove(self.session, "before_commit", self._validate_commit)

    def commit(self):
        """
        Commit the session.

        This is the required way to commit when the guard is in scope.
        """
        self.expected_commit = True
        self.session.commit()


def prohibit_commit(session):
    """
    Return a context manager that will disallow any commit that isn't done via the context manager.

    The aim of this is to ensure that transaction lifetime is strictly controlled, which is especially
    important in the core scheduler loop. Any commit on the session that is _not_ via this context manager
    will result in a RuntimeError.

    Example usage:

    .. code:: python

        with prohibit_commit(session) as guard:
            # ... do something with session
            guard.commit()

            # This would throw an error
            # session.commit()
    """
    return CommitProhibitorGuard(session)


def is_lock_not_available_error(error: OperationalError):
    """Check if the error is about not being able to acquire a lock."""
    # DB-specific error codes:
    # Postgres: 55P03
    # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT
    #        is set.'
    # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction'
    #        (when NOWAIT isn't available)
    db_err_code = getattr(error.orig, "pgcode", None) or error.orig.args[0]

    # We could test if error.orig is an instance of
    # psycopg2.errors.LockNotAvailable/_mysql_exceptions.OperationalError, but that involves
    # importing those libraries. Checking the error code doesn't.
    if db_err_code in ("55P03", 1205, 3572):
        return True
    return False
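

# A usage sketch (hypothetical retry flow; not part of this module): treat
# "lock not available" errors from a NOWAIT-locked query as retryable.
def _example_handle_lock_error(session, query):
    from sqlalchemy.exc import OperationalError

    try:
        return with_row_locks(query, session, nowait=True).all()
    except OperationalError as e:
        session.rollback()
        if is_lock_not_available_error(e):
            return None  # another process holds the lock; the caller may retry
        raise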


@overload
def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators: ...


@overload
def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Select,
    *,
    session: Session,
) -> ColumnOperators: ...


def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any] | Select,
    *,
    session: Session | None = None,
) -> ColumnOperators:
    """
    Generate a tuple-in-collection operator to use in ``.where()``.

    For most SQL backends, this generates a simple ``(col, ...) IN (condition)``
    clause.

    :meta private:
    """
    return tuple_(*columns).in_(collection)
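

# A usage sketch (hypothetical model with ``dag_id``/``task_id`` columns, for
# illustration only): match composite keys in a single WHERE clause.
def _example_tuple_in_condition(session, model):
    keys = [("dag_a", "task_1"), ("dag_a", "task_2")]
    condition = tuple_in_condition((model.dag_id, model.task_id), keys)
    return session.query(model).filter(condition).all()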


@overload
def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators: ...


@overload
def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Select,
    *,
    session: Session,
) -> ColumnOperators: ...


def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any] | Select,
    *,
    session: Session | None = None,
) -> ColumnOperators:
    """
    Generate a tuple-not-in-collection operator to use in ``.where()``.

    This is similar to ``tuple_in_condition`` except it generates ``NOT IN``.

    :meta private:
    """
    return tuple_(*columns).not_in(collection)


def get_orm_mapper():
    """Get the correct ORM mapper for the installed SQLAlchemy version."""
    import sqlalchemy.orm.mapper

    return sqlalchemy.orm.mapper if is_sqlalchemy_v1() else sqlalchemy.orm.Mapper


def is_sqlalchemy_v1() -> bool:
    """Return True if the installed SQLAlchemy major version is 1."""
    return version.parse(metadata.version("sqlalchemy")).major == 1