#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import contextlib
import copy
import datetime
import logging
from collections.abc import Generator
from importlib import metadata
from typing import TYPE_CHECKING, Any

from packaging import version
from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import JSON, Text, TypeDecorator

from airflow._shared.timezones.timezone import make_naive, utc
from airflow.configuration import conf
from airflow.serialization.enums import Encoding

if TYPE_CHECKING:
    from collections.abc import Iterable

    from kubernetes.client.models.v1_pod import V1Pod
    from sqlalchemy.exc import OperationalError
    from sqlalchemy.orm import Session
    from sqlalchemy.sql import Select
    from sqlalchemy.sql.elements import ColumnElement
    from sqlalchemy.types import TypeEngine

    from airflow.typing_compat import Self


log = logging.getLogger(__name__)

try:
    from sqlalchemy.orm import mapped_column
except ImportError:
    # fallback for SQLAlchemy < 2.0
    def mapped_column(*args, **kwargs):  # type: ignore[misc]
        from sqlalchemy import Column

        return Column(*args, **kwargs)


def get_dialect_name(session: Session) -> str | None:
    """Safely get the name of the dialect associated with the given session."""
    if (bind := session.get_bind()) is None:
        raise ValueError("No bind/engine is associated with the provided Session")
    return getattr(bind.dialect, "name", None)
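
# Illustrative usage sketch (not part of the original module): given an ORM ``session``
# bound to an engine, the helper returns the backend name, e.g.:
#
#     if get_dialect_name(session) == "postgresql":
#         ...  # apply Postgres-only behaviour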


class UtcDateTime(TypeDecorator):
    """
    Similar to :class:`~sqlalchemy.types.TIMESTAMP` with ``timezone=True`` option, with some differences.

    - Never silently accepts a naive :class:`~datetime.datetime`; a
      :exc:`ValueError` is raised unless the value is time zone aware.
    - A :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
      is always converted to UTC.
    - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`, it
      never returns a naive :class:`~datetime.datetime`, but a time zone
      aware value, even with SQLite or MySQL.
    - Always returns TIMESTAMP in UTC.
    """

    impl = TIMESTAMP(timezone=True)

    cache_ok = True

    def process_bind_param(self, value, dialect):
        if not isinstance(value, datetime.datetime):
            if value is None:
                return None
            raise TypeError(f"expected datetime.datetime, not {value!r}")
        if value.tzinfo is None:
            raise ValueError("naive datetime is disallowed")
        if dialect.name == "mysql":
            # For MySQL versions prior to 8.0.19 we should send timestamps as naive values in UTC
            # see: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-literals.html
            return make_naive(value, timezone=utc)
        return value.astimezone(utc)

    def process_result_value(self, value, dialect):
        """
        Process DateTimes from the DB, making sure to always return UTC.

        Not using timezone.convert_to_utc as that converts to the configured TIMEZONE
        while the DB might be running with some other setting. We assume UTC
        datetimes in the database.
        """
        if value is not None:
            if value.tzinfo is None:
                value = value.replace(tzinfo=utc)
            else:
                value = value.astimezone(utc)

        return value

    def load_dialect_impl(self, dialect):
        if dialect.name == "mysql":
            return mysql.TIMESTAMP(fsp=6)
        return super().load_dialect_impl(dialect)
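
# Illustrative usage sketch (assumption, not part of the original module): a column
# declared with UtcDateTime accepts only aware datetimes and always round-trips UTC.
# ``Base``, ``Integer`` and ``MyModel`` below are hypothetical, e.g.:
#
#     class MyModel(Base):
#         __tablename__ = "my_model"
#         id = mapped_column(Integer, primary_key=True)
#         created_at = mapped_column(UtcDateTime())
#
#     MyModel(created_at=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc))  # OK, stored as UTC
#     MyModel(created_at=datetime.datetime(2024, 1, 1))  # raises ValueError when bound: naive datetime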


class ExtendedJSON(TypeDecorator):
    """
    A version of the JSON column that uses the Airflow extended JSON serialization.

    See airflow.serialization.
    """

    impl = Text

    cache_ok = True

    should_evaluate_none = True

    def load_dialect_impl(self, dialect) -> TypeEngine:
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB)
        return dialect.type_descriptor(JSON)

    def process_bind_param(self, value, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        if value is None:
            return None

        return BaseSerialization.serialize(value)

    def process_result_value(self, value, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        if value is None:
            return None

        return BaseSerialization.deserialize(value)
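
# Illustrative usage sketch (assumption, not part of the original module): ExtendedJSON is
# used like any other column type; values pass through Airflow's BaseSerialization on the
# way in and out, and JSONB is used on Postgres while other backends get plain JSON, e.g.:
#
#     extra = mapped_column(ExtendedJSON, nullable=True)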


def sanitize_for_serialization(obj: V1Pod):
    """
    Convert pod to dict... but *safely*.

    When pod objects created with one k8s version are unpickled in a python
    env with a more recent k8s version (in which the object attrs may have
    changed) the unpickled obj may throw an error because the attr
    expected on the new obj may not be there on the unpickled obj.

    This function still converts the pod to a dict; the only difference is
    it populates missing attrs with None. You may compare with
    https://github.com/kubernetes-client/python/blob/5a96bbcbe21a552cc1f9cda13e0522fafb0dbac8/kubernetes/client/api_client.py#L202

    If obj is None, return None.
    If obj is str, int, float, bool, return directly.
    If obj is datetime.datetime, datetime.date
        convert to string in iso8601 format.
    If obj is list, sanitize each element in the list.
    If obj is dict, return the dict.
    If obj is an OpenAPI model, return the properties dict.

    :param obj: The data to serialize.
    :return: The serialized form of data.

    :meta private:
    """
    if obj is None:
        return None
    if isinstance(obj, (float, bool, bytes, str, int)):
        return obj
    if isinstance(obj, list):
        return [sanitize_for_serialization(sub_obj) for sub_obj in obj]
    if isinstance(obj, tuple):
        return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj)
    if isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()

    if isinstance(obj, dict):
        obj_dict = obj
    else:
        obj_dict = {
            obj.attribute_map[attr]: getattr(obj, attr)
            for attr, _ in obj.openapi_types.items()
            # below is the only line we change, and we just add default=None for getattr
            if getattr(obj, attr, None) is not None
        }

    return {key: sanitize_for_serialization(val) for key, val in obj_dict.items()}
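
# Illustrative sketch (assumption, not part of the original module): mirrors the k8s
# client's own serializer but tolerates attributes missing from unpickled objects, e.g.:
#
#     from kubernetes.client import models as k8s
#
#     pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="example"))
#     sanitize_for_serialization(pod)  # -> {"metadata": {"name": "example"}}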


def ensure_pod_is_valid_after_unpickling(pod: V1Pod) -> V1Pod | None:
    """
    Convert pod to json and back so that pod is safe.

    The pod_override in executor_config is a V1Pod object.
    Such objects created with one k8s version, when unpickled in
    an env with an upgraded k8s version, may blow up when
    `to_dict` is called, because openapi client code gen calls
    getattr on all attrs in openapi_types for each object, and when
    new attrs are added to that list, getattr will fail.

    Here we re-serialize it to ensure it is not going to blow up.

    :meta private:
    """
    try:
        # if to_dict works, the pod is fine
        pod.to_dict()
        return pod
    except AttributeError:
        pass
    try:
        from kubernetes.client.models.v1_pod import V1Pod
    except ImportError:
        return None
    if not isinstance(pod, V1Pod):
        return None
    try:
        from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator

        # now we actually reserialize / deserialize the pod
        pod_dict = sanitize_for_serialization(pod)
        return PodGenerator.deserialize_model_dict(pod_dict)
    except Exception:
        return None


class ExecutorConfigType(PickleType):
    """
    Adds special handling for K8s executor config.

    If we unpickle a k8s object that was pickled under an earlier k8s library version, then
    the unpickled object may throw an error when to_dict is called. To be more tolerant of
    version changes we convert to JSON using Airflow's serializer before pickling.
    """

    cache_ok = True

    def bind_processor(self, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().bind_processor(dialect)

        def process(value):
            val_copy = copy.copy(value)
            if isinstance(val_copy, dict) and "pod_override" in val_copy:
                val_copy["pod_override"] = BaseSerialization.serialize(val_copy["pod_override"])
            return super_process(val_copy)

        return process

    def result_processor(self, dialect, coltype):
        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().result_processor(dialect, coltype)

        def process(value):
            value = super_process(value)  # unpickle

            if isinstance(value, dict) and "pod_override" in value:
                pod_override = value["pod_override"]

                if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
                    # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
                    value["pod_override"] = BaseSerialization.deserialize(pod_override)
                else:
                    # backcompat path
                    # we no longer pickle raw pods but this code may be reached
                    # when accessing executor configs created in a prior version
                    new_pod = ensure_pod_is_valid_after_unpickling(pod_override)
                    if new_pod:
                        value["pod_override"] = new_pod
            return value

        return process

    def compare_values(self, x, y):
        """
        Compare x and y using self.comparator if available. Else, use __eq__.

        The TaskInstance.executor_config attribute is a pickled object that may contain kubernetes objects.

        If the installed library version has changed since the object was originally pickled,
        due to the underlying ``__eq__`` method on these objects (which converts them to JSON),
        we may encounter attribute errors. In this case we should replace the stored object.

        From https://github.com/apache/airflow/pull/24356 we use our serializer to store
        k8s objects, but there could still be raw pickled k8s objects in the database,
        stored from an earlier version, so we still compare them defensively here.
        """
        if self.comparator:
            return self.comparator(x, y)
        try:
            return x == y
        except AttributeError:
            return False
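
# Illustrative sketch (assumption, not part of the original module): when used as the
# column type for an executor_config attribute, only the "pod_override" entry is run
# through Airflow's serializer before pickling, and restored on the way out
# (``ti`` and ``pod`` below are hypothetical), e.g.:
#
#     executor_config = mapped_column(ExecutorConfigType())
#
#     ti.executor_config = {"pod_override": pod, "extra_flag": True}
#     # bind: {"pod_override": BaseSerialization.serialize(pod), "extra_flag": True} -> pickled
#     # result: unpickled -> "pod_override" deserialized back into a V1Pod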


def nulls_first(col: ColumnElement, session: Session) -> ColumnElement:
    """
    Specify *NULLS FIRST* for the column ordering.

    This is only done for Postgres, currently the only backend that supports it.
    Other databases do not need it since NULL values are considered lower than
    any other values, and appear first when the order is ASC (ascending).
    """
    if get_dialect_name(session) == "postgresql":
        return nullsfirst(col)
    return col
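
# Illustrative usage sketch (assumption, not part of the original module): wrap an ORDER BY
# expression so NULLs sort first on Postgres while other backends keep their default
# behaviour (NULLs already sort first ascending there), e.g.:
#
#     from sqlalchemy import select
#     from airflow.models import DagRun
#
#     stmt = select(DagRun).order_by(nulls_first(DagRun.last_scheduling_decision, session))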


USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)


def with_row_locks(
    query: Select[Any],
    session: Session,
    *,
    nowait: bool = False,
    skip_locked: bool = False,
    key_share: bool = True,
    **kwargs,
) -> Select[Any]:
    """
    Apply with_for_update to the SQLAlchemy query if row level locking is in use.

    This wrapper is needed so we don't use the syntax on unsupported database
    engines. In particular, MySQL (prior to 8.0) and MariaDB do not support
    row locking, and on those backends we neither support nor recommend running
    an HA scheduler. If a user ignores this and tries anyway, everything will
    still work, just slightly slower in some circumstances.

    See https://jira.mariadb.org/browse/MDEV-13115

    :param query: An SQLAlchemy Select object
    :param session: ORM Session
    :param nowait: If set to True, will pass NOWAIT to supported database backends.
    :param skip_locked: If set to True, will pass SKIP LOCKED to supported database backends.
    :param key_share: If true, will lock with FOR NO KEY UPDATE (at least on postgres).
    :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
    :return: updated query
    """
    try:
        dialect_name = get_dialect_name(session)
    except ValueError:
        return query
    if not dialect_name:
        return query

    if not USE_ROW_LEVEL_LOCKING:
        return query
    # Don't use row level locks if the MySQL dialect (MariaDB & MySQL < 8) does not support them.
    if dialect_name == "mysql" and not getattr(
        session.bind.dialect if session.bind else None, "supports_for_update_of", False
    ):
        return query
    if nowait:
        kwargs["nowait"] = True
    if skip_locked:
        kwargs["skip_locked"] = True
    if key_share:
        kwargs["key_share"] = True
    return query.with_for_update(**kwargs)
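
# Illustrative usage sketch (assumption, not part of the original module): the scheduler-style
# pattern of selecting rows and locking them, skipping rows already locked elsewhere, e.g.:
#
#     from sqlalchemy import select
#     from airflow.models import TaskInstance
#
#     stmt = select(TaskInstance).where(TaskInstance.state == "scheduled")
#     stmt = with_row_locks(stmt, session=session, skip_locked=True)
#     tis = session.scalars(stmt).all()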


@contextlib.contextmanager
def lock_rows(query: Select, session: Session) -> Generator[None, None, None]:
    """
    Lock database rows during the context manager block.

    This is a convenient method for ``with_row_locks`` when we don't need the
    locked rows.

    :meta private:
    """
    locked_rows = with_row_locks(query, session)
    yield
    del locked_rows
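
# Illustrative usage sketch (assumption, not part of the original module): hold the locks
# for the duration of a critical section without needing the selected rows themselves, e.g.:
#
#     from sqlalchemy import select
#     from airflow.models import DagModel
#
#     with lock_rows(select(DagModel).where(DagModel.dag_id == dag_id), session):
#         ...  # work that must not race with other schedulers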


class CommitProhibitorGuard:
    """Context manager class that powers prohibit_commit."""

    expected_commit = False

    def __init__(self, session: Session):
        self.session = session

    def _validate_commit(self, _):
        if self.expected_commit:
            self.expected_commit = False
            return
        raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!")

    def __enter__(self) -> Self:
        event.listen(self.session, "before_commit", self._validate_commit)
        return self

    def __exit__(self, *exc_info):
        event.remove(self.session, "before_commit", self._validate_commit)

    def commit(self):
        """
        Commit the session.

        This is the required way to commit when the guard is in scope.
        """
        self.expected_commit = True
        self.session.commit()


def prohibit_commit(session):
    """
    Return a context manager that will disallow any commit that isn't done via the context manager.

    The aim of this is to ensure that transaction lifetime is strictly controlled, which is especially
    important in the core scheduler loop. Any commit on the session that is _not_ via this context manager
    will result in a RuntimeError.

    Example usage:

    .. code:: python

        with prohibit_commit(session) as guard:
            # ... do something with session
            guard.commit()

            # This would throw an error
            # session.commit()
    """
    return CommitProhibitorGuard(session)


def is_lock_not_available_error(error: OperationalError):
    """Check if the Error is about not being able to acquire lock."""
    # DB specific error codes:
    # Postgres: 55P03
    # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT
    #   is set.'
    # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction'
    #   (when NOWAIT isn't available)
    db_err_code = getattr(error.orig, "pgcode", None) or (
        error.orig.args[0] if error.orig and error.orig.args else None
    )

    # We could test if error.orig is an instance of
    # psycopg2.errors.LockNotAvailable/_mysql_exceptions.OperationalError, but that involves
    # importing it. Checking the error code doesn't need the import.
    if db_err_code in ("55P03", 1205, 3572):
        return True
    return False
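
# Illustrative usage sketch (assumption, not part of the original module): a typical call
# site rolls back and retries later when the lock could not be acquired, e.g.:
#
#     from sqlalchemy.exc import OperationalError
#
#     try:
#         tis = session.scalars(with_row_locks(stmt, session=session, nowait=True)).all()
#     except OperationalError as e:
#         if not is_lock_not_available_error(e):
#             raise
#         session.rollback()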

469 

470 

471def get_orm_mapper(): 

472 """Get the correct ORM mapper for the installed SQLAlchemy version.""" 

473 import sqlalchemy.orm.mapper 

474 

475 return sqlalchemy.orm.mapper if is_sqlalchemy_v1() else sqlalchemy.orm.Mapper 

476 

477 

478def is_sqlalchemy_v1() -> bool: 

479 return version.parse(metadata.version("sqlalchemy")).major == 1 

480 

481 

482def make_dialect_kwarg(dialect: str) -> dict[str, str | Iterable[str]]: 

483 """Create an SQLAlchemy-version-aware dialect keyword argument.""" 

484 return {"dialect_name": dialect} if is_sqlalchemy_v1() else {"dialect_names": (dialect,)}