Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/utils/sqlalchemy.py: 28% (189 statements)


#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import contextlib
import copy
import datetime
import logging
from collections.abc import Generator
from typing import TYPE_CHECKING

from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import JSON, Text, TypeDecorator

from airflow._shared.timezones.timezone import make_naive, utc
from airflow.configuration import conf
from airflow.serialization.enums import Encoding

if TYPE_CHECKING:
    from collections.abc import Iterable

    from kubernetes.client.models.v1_pod import V1Pod
    from sqlalchemy.exc import OperationalError
    from sqlalchemy.orm import Session
    from sqlalchemy.sql import Select
    from sqlalchemy.sql.elements import ColumnElement
    from sqlalchemy.types import TypeEngine

    from airflow.typing_compat import Self


log = logging.getLogger(__name__)


def get_dialect_name(session: Session) -> str | None:
    """Safely get the name of the dialect associated with the given session."""
    if (bind := session.get_bind()) is None:
        raise ValueError("No bind/engine is associated with the provided Session")
    return getattr(bind.dialect, "name", None)

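# A minimal usage sketch for get_dialect_name (not part of the original
# module); the in-memory SQLite engine is an assumption chosen purely for
# illustration, and any bound Session behaves the same way.
def _example_get_dialect_name() -> None:
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    session = sessionmaker(bind=create_engine("sqlite://"))()
    # The helper reads the dialect name off the session's bound engine.
    assert get_dialect_name(session) == "sqlite"
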

class UtcDateTime(TypeDecorator):
    """
    Similar to :class:`~sqlalchemy.types.TIMESTAMP` with ``timezone=True``, with some differences.

    - Never silently accepts a naive :class:`~datetime.datetime`; it always
      raises :exc:`ValueError` unless the value is time zone aware.
    - A :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
      is always converted to UTC.
    - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`, it
      never returns a naive :class:`~datetime.datetime`, but a time zone
      aware value, even with SQLite or MySQL.
    - Always returns TIMESTAMP in UTC.
    """

    impl = TIMESTAMP(timezone=True)

    cache_ok = True

    def process_bind_param(self, value, dialect):
        if not isinstance(value, datetime.datetime):
            if value is None:
                return None
            raise TypeError(f"expected datetime.datetime, not {value!r}")
        if value.tzinfo is None:
            raise ValueError("naive datetime is disallowed")
        if dialect.name == "mysql":
            # For MySQL versions prior to 8.0.19 we should send timestamps as naive values in UTC.
            # See: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-literals.html
            return make_naive(value, timezone=utc)
        return value.astimezone(utc)

    def process_result_value(self, value, dialect):
        """
        Process DateTimes from the DB, making sure to always return UTC.

        Not using timezone.convert_to_utc as that converts to the configured
        TIMEZONE while the DB might be running with some other setting. We
        assume UTC datetimes in the database.
        """
        if value is not None:
            if value.tzinfo is None:
                value = value.replace(tzinfo=utc)
            else:
                value = value.astimezone(utc)

        return value

    def load_dialect_impl(self, dialect):
        if dialect.name == "mysql":
            return mysql.TIMESTAMP(fsp=6)
        return super().load_dialect_impl(dialect)

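# A minimal usage sketch for UtcDateTime, assuming a hypothetical table name
# and an in-memory SQLite engine purely for illustration.
def _example_utc_datetime() -> None:
    from sqlalchemy import Column, Integer, MetaData, Table, create_engine, insert

    metadata = MetaData()
    events = Table(
        "example_events",  # hypothetical table, not part of Airflow's schema
        metadata,
        Column("id", Integer, primary_key=True),
        Column("created_at", UtcDateTime()),
    )
    engine = create_engine("sqlite://")
    metadata.create_all(engine)
    with engine.begin() as conn:
        # An aware datetime is accepted and normalized to UTC on the way in;
        # a naive one raises ValueError inside process_bind_param.
        conn.execute(insert(events).values(created_at=datetime.datetime.now(datetime.timezone.utc)))
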

class ExtendedJSON(TypeDecorator):
    """
    A version of the JSON column that uses the Airflow extended JSON serialization.

    See airflow.serialization.
    """

    impl = Text

    cache_ok = True

    should_evaluate_none = True

    def load_dialect_impl(self, dialect) -> TypeEngine:
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB)
        return dialect.type_descriptor(JSON)

    def process_bind_param(self, value, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        if value is None:
            return None

        return BaseSerialization.serialize(value)

    def process_result_value(self, value, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        if value is None:
            return None

        return BaseSerialization.deserialize(value)

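# A minimal sketch of ExtendedJSON's per-dialect dispatch, assuming an
# in-memory SQLite engine purely for illustration.
def _example_extended_json_dialect() -> None:
    from sqlalchemy import create_engine

    dialect = create_engine("sqlite://").dialect
    # On PostgreSQL this would resolve to JSONB instead of the generic JSON type.
    assert isinstance(ExtendedJSON().load_dialect_impl(dialect), JSON)
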

def sanitize_for_serialization(obj: V1Pod):
    """
    Convert pod to dict... but *safely*.

    When pod objects created with one k8s version are unpickled in a python
    env with a more recent k8s version (in which the object attrs may have
    changed) the unpickled obj may throw an error because an attr
    expected on the new obj may not be there on the unpickled obj.

    This function still converts the pod to a dict; the only difference is
    it populates missing attrs with None. You may compare with
    https://github.com/kubernetes-client/python/blob/5a96bbcbe21a552cc1f9cda13e0522fafb0dbac8/kubernetes/client/api_client.py#L202

    If obj is None, return None.
    If obj is str, bytes, int, float, or bool, return it directly.
    If obj is datetime.datetime or datetime.date, convert it to a string in
    ISO 8601 format.
    If obj is a list or tuple, sanitize each element.
    If obj is a dict, sanitize each value in the dict.
    If obj is an OpenAPI model, return its properties dict.

    :param obj: The data to serialize.
    :return: The serialized form of data.

    :meta private:
    """
    if obj is None:
        return None
    if isinstance(obj, (float, bool, bytes, str, int)):
        return obj
    if isinstance(obj, list):
        return [sanitize_for_serialization(sub_obj) for sub_obj in obj]
    if isinstance(obj, tuple):
        return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj)
    if isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()

    if isinstance(obj, dict):
        obj_dict = obj
    else:
        obj_dict = {
            obj.attribute_map[attr]: getattr(obj, attr)
            for attr, _ in obj.openapi_types.items()
            # below is the only line we change from upstream: add default=None to getattr
            if getattr(obj, attr, None) is not None
        }

    return {key: sanitize_for_serialization(val) for key, val in obj_dict.items()}

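# A minimal sketch of sanitize_for_serialization on plain data; the payload
# here is hypothetical, chosen only to show the datetime-to-ISO-8601 rule.
def _example_sanitize_for_serialization() -> None:
    payload = {"count": 1, "when": datetime.date(2024, 1, 1), "tags": ["a", "b"]}
    # Scalars and containers pass through; dates become ISO 8601 strings.
    assert sanitize_for_serialization(payload) == {"count": 1, "when": "2024-01-01", "tags": ["a", "b"]}
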

def ensure_pod_is_valid_after_unpickling(pod: V1Pod) -> V1Pod | None:
    """
    Convert pod to json and back so that the pod is safe to use.

    The pod_override in executor_config is a V1Pod object.
    Such objects created with one k8s version, when unpickled in
    an env with an upgraded k8s version, may blow up when
    `to_dict` is called, because openapi client code gen calls
    getattr on all attrs in openapi_types for each object, and when
    new attrs are added to that list, getattr will fail.

    Here we re-serialize it to ensure it is not going to blow up.

    :meta private:
    """
    try:
        # if to_dict works, the pod is fine
        pod.to_dict()
        return pod
    except AttributeError:
        pass
    try:
        from kubernetes.client.models.v1_pod import V1Pod
    except ImportError:
        return None
    if not isinstance(pod, V1Pod):
        return None
    try:
        from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator

        # now we actually reserialize / deserialize the pod
        pod_dict = sanitize_for_serialization(pod)
        return PodGenerator.deserialize_model_dict(pod_dict)
    except Exception:
        return None


class ExecutorConfigType(PickleType):
    """
    Adds special handling for K8s executor config.

    If we unpickle a k8s object that was pickled under an earlier k8s library version, then
    the unpickled object may throw an error when to_dict is called. To be more tolerant of
    version changes we convert to JSON using Airflow's serializer before pickling.
    """

    cache_ok = True

    def bind_processor(self, dialect):
        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().bind_processor(dialect)

        def process(value):
            val_copy = copy.copy(value)
            if isinstance(val_copy, dict) and "pod_override" in val_copy:
                val_copy["pod_override"] = BaseSerialization.serialize(val_copy["pod_override"])
            return super_process(val_copy)

        return process

    def result_processor(self, dialect, coltype):
        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().result_processor(dialect, coltype)

        def process(value):
            value = super_process(value)  # unpickle

            if isinstance(value, dict) and "pod_override" in value:
                pod_override = value["pod_override"]

                if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
                    # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
                    value["pod_override"] = BaseSerialization.deserialize(pod_override)
                else:
                    # Backcompat path: we no longer pickle raw pods, but this code
                    # may be reached when accessing executor configs created in a
                    # prior version.
                    new_pod = ensure_pod_is_valid_after_unpickling(pod_override)
                    if new_pod:
                        value["pod_override"] = new_pod
            return value

        return process

    def compare_values(self, x, y):
        """
        Compare x and y using self.comparator if available. Else, use __eq__.

        The TaskInstance.executor_config attribute is a pickled object that may contain kubernetes objects.

        If the installed library version has changed since the object was originally pickled, then
        due to the underlying ``__eq__`` method on these objects (which converts them to JSON),
        we may encounter attribute errors. In this case we should replace the stored object.

        Since https://github.com/apache/airflow/pull/24356 we use our serializer to store
        k8s objects, but there could still be raw pickled k8s objects in the database,
        stored from an earlier version, so we still compare them defensively here.
        """
        if self.comparator:
            return self.comparator(x, y)
        try:
            return x == y
        except AttributeError:
            return False

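# A minimal sketch of where ExecutorConfigType fits: declaring a column of
# this type on a hypothetical table (not Airflow's actual schema).
def _example_executor_config_column():
    from sqlalchemy import Column, Integer, MetaData, Table

    return Table(
        "example_task_instance",  # hypothetical name
        MetaData(),
        Column("id", Integer, primary_key=True),
        # executor_config dicts are pickled; any "pod_override" entry is first
        # converted to Airflow's serialized form by bind_processor above.
        Column("executor_config", ExecutorConfigType()),
    )
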

def nulls_first(col: ColumnElement, session: Session) -> ColumnElement:
    """
    Specify *NULLS FIRST* for the column ordering.

    This is only done for Postgres, currently the only backend that supports it.
    Other databases do not need it since NULL values are considered lower than
    any other values, and appear first when the order is ASC (ascending).
    """
    if get_dialect_name(session) == "postgresql":
        return nullsfirst(col)
    return col

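# A minimal usage sketch for nulls_first, assuming an in-memory SQLite engine
# purely for illustration; a PostgreSQL session would get a NULLS FIRST clause.
def _example_nulls_first() -> None:
    from sqlalchemy import column, create_engine
    from sqlalchemy.orm import sessionmaker

    session = sessionmaker(bind=create_engine("sqlite://"))()
    ordered = nulls_first(column("priority").desc(), session=session)
    # On non-Postgres backends the ordering expression is returned unchanged.
    assert str(ordered) == "priority DESC"
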

USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)


def with_row_locks(
    query: Select,
    session: Session,
    *,
    nowait: bool = False,
    skip_locked: bool = False,
    key_share: bool = True,
    **kwargs,
) -> Select:
    """
    Apply with_for_update to the SQLAlchemy query if row-level locking is in use.

    This wrapper is needed so we don't use the syntax on unsupported database
    engines. In particular, MySQL (prior to 8.0) and MariaDB do not support
    row locking, and on those backends we neither support nor recommend running
    an HA scheduler. If a user ignores this and tries anyway, everything will
    still work, just slightly slower in some circumstances.

    See https://jira.mariadb.org/browse/MDEV-13115

    :param query: An SQLAlchemy Query object
    :param session: ORM Session
    :param nowait: If set to True, will pass NOWAIT to supported database backends.
    :param skip_locked: If set to True, will pass SKIP LOCKED to supported database backends.
    :param key_share: If True, will lock with FOR NO KEY UPDATE (at least on Postgres).
    :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc.)
    :return: updated query
    """
    try:
        dialect_name = get_dialect_name(session)
    except ValueError:
        return query
    if not dialect_name:
        return query

    if not USE_ROW_LEVEL_LOCKING:
        return query
    # Don't use row-level locks if the MySQL dialect (MariaDB & MySQL < 8) does not support them.
    if dialect_name == "mysql" and not getattr(
        session.bind.dialect if session.bind else None, "supports_for_update_of", False
    ):
        return query
    if nowait:
        kwargs["nowait"] = True
    if skip_locked:
        kwargs["skip_locked"] = True
    if key_share:
        kwargs["key_share"] = True
    return query.with_for_update(**kwargs)

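# A minimal usage sketch for with_row_locks, assuming a hypothetical table and
# an in-memory SQLite engine purely for illustration; on PostgreSQL the same
# call would render FOR NO KEY UPDATE SKIP LOCKED, while SQLite's compiler
# simply drops the locking clause.
def _example_with_row_locks() -> None:
    from sqlalchemy import Column, Integer, MetaData, Table, create_engine, select
    from sqlalchemy.orm import sessionmaker

    jobs = Table("example_jobs", MetaData(), Column("id", Integer, primary_key=True))
    session = sessionmaker(bind=create_engine("sqlite://"))()
    locked_query = with_row_locks(select(jobs), session=session, skip_locked=True)
    # `locked_query` is ready to run via session.execute(locked_query).
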

@contextlib.contextmanager
def lock_rows(query: Select, session: Session) -> Generator[None, None, None]:
    """
    Lock database rows during the context manager block.

    This is a convenience wrapper around ``with_row_locks`` for when we don't
    need the locked rows themselves.

    :meta private:
    """
    locked_rows = with_row_locks(query, session)
    yield
    del locked_rows

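# A minimal sketch of the intended call shape for lock_rows; `query` and
# `session` are assumed to come from the caller.
def _example_lock_rows(query: Select, session: Session) -> None:
    with lock_rows(query, session):
        pass  # do work while the lock guard is in scope
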

class CommitProhibitorGuard:
    """Context manager class that powers prohibit_commit."""

    expected_commit = False

    def __init__(self, session: Session):
        self.session = session

    def _validate_commit(self, _):
        if self.expected_commit:
            self.expected_commit = False
            return
        raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!")

    def __enter__(self) -> Self:
        event.listen(self.session, "before_commit", self._validate_commit)
        return self

    def __exit__(self, *exc_info):
        event.remove(self.session, "before_commit", self._validate_commit)

    def commit(self):
        """
        Commit the session.

        This is the required way to commit when the guard is in scope.
        """
        self.expected_commit = True
        self.session.commit()


def prohibit_commit(session):
    """
    Return a context manager that will disallow any commit that isn't done via the context manager.

    The aim of this is to ensure that transaction lifetime is strictly controlled, which is
    especially important in the core scheduler loop. Any commit on the session that is _not_ via
    this context manager will result in a RuntimeError.

    Example usage:

    .. code:: python

        with prohibit_commit(session) as guard:
            # ... do something with session
            guard.commit()

            # This would throw an error
            # session.commit()
    """
    return CommitProhibitorGuard(session)


def is_lock_not_available_error(error: OperationalError):
    """Check if the error is about not being able to acquire a lock."""
    # DB-specific error codes:
    # Postgres: 55P03
    # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT
    #        is set.'
    # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction'
    #        (when NOWAIT isn't available)
    db_err_code = getattr(error.orig, "pgcode", None) or (
        error.orig.args[0] if error.orig and error.orig.args else None
    )

    # We could test if error.orig is an instance of
    # psycopg2.errors.LockNotAvailable/_mysql_exceptions.OperationalError, but
    # that involves importing those libraries; checking the error code doesn't.
    if db_err_code in ("55P03", 1205, 3572):
        return True
    return False

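# A minimal sketch of the intended call pattern: catch OperationalError around
# a NOWAIT-locked statement and treat a pure lock conflict as retryable. The
# query and session here are assumed to come from the caller.
def _example_lock_conflict_retry(query: Select, session: Session) -> None:
    from sqlalchemy.exc import OperationalError

    try:
        session.execute(with_row_locks(query, session, nowait=True))
    except OperationalError as e:
        if is_lock_not_available_error(e):
            session.rollback()  # another worker holds the lock; retry later
        else:
            raise
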

def make_dialect_kwarg(dialect: str) -> dict[str, str | Iterable[str]]:
    """Create an SQLAlchemy-version-aware dialect keyword argument."""
    return {"dialect_names": (dialect,)}

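# A minimal sketch: the helper just builds a keyword mapping that can be
# splatted into a SQLAlchemy call accepting a ``dialect_names`` argument.
def _example_make_dialect_kwarg() -> None:
    assert make_dialect_kwarg("postgresql") == {"dialect_names": ("postgresql",)}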