Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/utils/sqlalchemy.py: 35%

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import copy
import datetime
import json
import logging
from typing import Any, Iterable

import pendulum
from dateutil import relativedelta
from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, or_, true, tuple_
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.expression import ColumnOperators
from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText

from airflow import settings
from airflow.configuration import conf
from airflow.serialization.enums import Encoding

log = logging.getLogger(__name__)

utc = pendulum.tz.timezone("UTC")

using_mysql = conf.get_mandatory_value("database", "sql_alchemy_conn").lower().startswith("mysql")


class UtcDateTime(TypeDecorator):
    """
    Almost equivalent to :class:`~sqlalchemy.types.TIMESTAMP` with the
    ``timezone=True`` option, but it differs from that by:

    - It never silently accepts a naive :class:`~datetime.datetime`;
      instead it always raises (a :exc:`TypeError` for non-datetime
      values, a :exc:`ValueError` for naive datetimes) unless the value
      is time zone aware.
    - A bound :class:`~datetime.datetime` value is always converted
      to UTC.
    - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`,
      it never returns a naive :class:`~datetime.datetime`, but a time
      zone aware value, even with SQLite or MySQL.
    - It always returns TIMESTAMP in UTC.
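
    Example usage (a minimal sketch; ``MyModel`` and its ``created_at``
    column are hypothetical):

    .. code:: python

        import datetime

        import pendulum
        from sqlalchemy import Column, Integer
        from sqlalchemy.orm import declarative_base

        Base = declarative_base()


        class MyModel(Base):
            __tablename__ = "my_model"

            id = Column(Integer, primary_key=True)
            created_at = Column(UtcDateTime())


        # OK: the aware value is converted to UTC when bound.
        ok = MyModel(created_at=pendulum.datetime(2022, 1, 1, tz="Europe/Paris"))

        # Fails on flush with ValueError: naive datetimes are disallowed.
        bad = MyModel(created_at=datetime.datetime(2022, 1, 1))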

61 """ 

62 

63 impl = TIMESTAMP(timezone=True) 

64 

65 cache_ok = True 

66 

67 def process_bind_param(self, value, dialect): 

68 if value is not None: 

69 if not isinstance(value, datetime.datetime): 

70 raise TypeError("expected datetime.datetime, not " + repr(value)) 

71 elif value.tzinfo is None: 

72 raise ValueError("naive datetime is disallowed") 

73 # For mysql we should store timestamps as naive values 

74 # Timestamp in MYSQL is not timezone aware. In MySQL 5.6 

75 # timezone added at the end is ignored but in MySQL 5.7 

76 # inserting timezone value fails with 'invalid-date' 

77 # See https://issues.apache.org/jira/browse/AIRFLOW-7001 

78 if using_mysql: 

79 from airflow.utils.timezone import make_naive 

80 

81 return make_naive(value, timezone=utc) 

82 return value.astimezone(utc) 

83 return None 

84 

85 def process_result_value(self, value, dialect): 

86 """ 

87 Processes DateTimes from the DB making sure it is always 

88 returning UTC. Not using timezone.convert_to_utc as that 

89 converts to configured TIMEZONE while the DB might be 

90 running with some other setting. We assume UTC datetimes 

91 in the database. 

92 """ 

93 if value is not None: 

94 if value.tzinfo is None: 

95 value = value.replace(tzinfo=utc) 

96 else: 

97 value = value.astimezone(utc) 

98 

99 return value 

100 

101 def load_dialect_impl(self, dialect): 

102 if dialect.name == "mssql": 

103 return mssql.DATETIME2(precision=6) 

104 elif dialect.name == "mysql": 

105 return mysql.TIMESTAMP(fsp=6) 

106 return super().load_dialect_impl(dialect) 

107 

108 

109class ExtendedJSON(TypeDecorator): 

110 """ 

111 A version of the JSON column that uses the Airflow extended JSON 

112 serialization provided by airflow.serialization. 
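
    For example (a sketch; ``data`` is a hypothetical model column):

    .. code:: python

        from sqlalchemy import Column

        # Values are round-tripped through BaseSerialization, with an extra
        # json.dumps/json.loads step on databases without native JSON support.
        data = Column(ExtendedJSON())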

113 """ 

114 

115 impl = Text 

116 

117 cache_ok = True 

118 

119 def db_supports_json(self): 

120 """Checks if the database supports JSON (i.e. is NOT MSSQL)""" 

121 return not conf.get("database", "sql_alchemy_conn").startswith("mssql") 

122 

123 def load_dialect_impl(self, dialect) -> TypeEngine: 

124 if self.db_supports_json(): 

125 return dialect.type_descriptor(JSON) 

126 return dialect.type_descriptor(UnicodeText) 

127 

128 def process_bind_param(self, value, dialect): 

129 from airflow.serialization.serialized_objects import BaseSerialization 

130 

131 if value is None: 

132 return None 

133 

134 # First, encode it into our custom JSON-targeted dict format 

135 value = BaseSerialization.serialize(value) 

136 

137 # Then, if the database does not have native JSON support, encode it again as a string 

138 if not self.db_supports_json(): 

139 value = json.dumps(value) 

140 

141 return value 

142 

143 def process_result_value(self, value, dialect): 

144 from airflow.serialization.serialized_objects import BaseSerialization 

145 

146 if value is None: 

147 return None 

148 

149 # Deserialize from a string first if needed 

150 if not self.db_supports_json(): 

151 value = json.loads(value) 

152 

153 return BaseSerialization.deserialize(value) 

154 

155 

156class ExecutorConfigType(PickleType): 

157 """ 

158 Adds special handling for K8s executor config. If we unpickle a k8s object that was 

159 pickled under an earlier k8s library version, then the unpickled object may throw an error 

160 when to_dict is called. To be more tolerant of version changes we convert to JSON using 

161 Airflow's serializer before pickling. 
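
    For example, a ``pod_override`` entry is serialized before pickling and
    deserialized again after unpickling (a sketch, assuming the kubernetes
    client library is installed):

    .. code:: python

        from kubernetes.client import models as k8s

        executor_config = {
            "pod_override": k8s.V1Pod(
                spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]),
            ),
        }
        # On bind, "pod_override" is replaced by its BaseSerialization.serialize
        # form, and only then is the whole dict pickled.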

162 """ 

163 

164 cache_ok = True 

165 

166 def bind_processor(self, dialect): 

167 

168 from airflow.serialization.serialized_objects import BaseSerialization 

169 

170 super_process = super().bind_processor(dialect) 

171 

172 def process(value): 

173 val_copy = copy.copy(value) 

174 if isinstance(val_copy, dict) and "pod_override" in val_copy: 

175 val_copy["pod_override"] = BaseSerialization.serialize(val_copy["pod_override"]) 

176 return super_process(val_copy) 

177 

178 return process 

179 

180 def result_processor(self, dialect, coltype): 

181 from airflow.serialization.serialized_objects import BaseSerialization 

182 

183 super_process = super().result_processor(dialect, coltype) 

184 

185 def process(value): 

186 value = super_process(value) # unpickle 

187 

188 if isinstance(value, dict) and "pod_override" in value: 

189 pod_override = value["pod_override"] 

190 

191 # If pod_override was serialized with Airflow's BaseSerialization, deserialize it 

192 if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE): 

193 value["pod_override"] = BaseSerialization.deserialize(pod_override) 

194 return value 

195 

196 return process 

197 

198 def compare_values(self, x, y): 

199 """ 

200 The TaskInstance.executor_config attribute is a pickled object that may contain 

201 kubernetes objects. If the installed library version has changed since the 

202 object was originally pickled, due to the underlying ``__eq__`` method on these 

203 objects (which converts them to JSON), we may encounter attribute errors. In this 

204 case we should replace the stored object. 

205 

206 From https://github.com/apache/airflow/pull/24356 we use our serializer to store 

207 k8s objects, but there could still be raw pickled k8s objects in the database, 

208 stored from earlier version, so we still compare them defensively here. 

209 """ 

210 if self.comparator: 

211 return self.comparator(x, y) 

212 else: 

213 try: 

214 return x == y 

215 except AttributeError: 

216 return False 

217 

218 

219class Interval(TypeDecorator): 

220 """Base class representing a time interval.""" 

    impl = Text

    cache_ok = True

    attr_keys = {
        datetime.timedelta: ("days", "seconds", "microseconds"),
        relativedelta.relativedelta: (
            "years",
            "months",
            "days",
            "leapdays",
            "hours",
            "minutes",
            "seconds",
            "microseconds",
            "year",
            "month",
            "day",
            "hour",
            "minute",
            "second",
            "microsecond",
        ),
    }

    def process_bind_param(self, value, dialect):
        if isinstance(value, tuple(self.attr_keys)):
            attrs = {key: getattr(value, key) for key in self.attr_keys[type(value)]}
            return json.dumps({"type": type(value).__name__, "attrs": attrs})
        return json.dumps(value)

    def process_result_value(self, value, dialect):
        if not value:
            return value
        data = json.loads(value)
        if isinstance(data, dict):
            type_map = {key.__name__: key for key in self.attr_keys}
            return type_map[data["type"]](**data["attrs"])
        return data


def skip_locked(session: Session) -> dict[str, Any]:
    """
    Return kwargs for passing to ``with_for_update()`` suitable for the current DB engine version.

    We do this because we document that on DB engines that don't support this construct, we do
    not support/recommend running an HA scheduler. If a user ignores this and tries anyway,
    everything will still work, just slightly slower in some circumstances.

    Specifically, don't emit SKIP LOCKED for MySQL < 8 or MariaDB, neither of which support this
    construct.

    See https://jira.mariadb.org/browse/MDEV-13115
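
    Example usage (a sketch; ``TaskInstance`` stands in for any ORM model):

    .. code:: python

        # FOR UPDATE SKIP LOCKED where supported; plain FOR UPDATE otherwise.
        tis = (
            session.query(TaskInstance)
            .with_for_update(**skip_locked(session=session))
            .all()
        )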

274 """ 

275 dialect = session.bind.dialect 

276 

277 if dialect.name != "mysql" or dialect.supports_for_update_of: 

278 return {"skip_locked": True} 

279 else: 

280 return {} 

281 

282 

283def nowait(session: Session) -> dict[str, Any]: 

284 """ 

285 Return kwargs for passing to `with_for_update()` suitable for the current DB engine version. 

286 

287 We do this as we document the fact that on DB engines that don't support this construct, we do not 

288 support/recommend running HA scheduler. If a user ignores this and tries anyway everything will still 

289 work, just slightly slower in some circumstances. 

290 

291 Specifically don't emit NOWAIT for MySQL < 8, or MariaDB, neither of which support this construct 

292 

293 See https://jira.mariadb.org/browse/MDEV-13115 

294 """ 

295 dialect = session.bind.dialect 

296 

297 if dialect.name != "mysql" or dialect.supports_for_update_of: 

298 return {"nowait": True} 

299 else: 

300 return {} 

301 

302 

303def nulls_first(col, session: Session) -> dict[str, Any]: 

304 """ 

305 Adds a nullsfirst construct to the column ordering. Currently only Postgres supports it. 

306 In MySQL & Sqlite NULL values are considered lower than any non-NULL value, therefore, NULL values 

307 appear first when the order is ASC (ascending) 
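
    For example (a sketch; ``TaskInstance.priority_weight`` stands in for any column):

    .. code:: python

        query = session.query(TaskInstance).order_by(
            nulls_first(TaskInstance.priority_weight, session=session)
        )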

308 """ 

309 if session.bind.dialect.name == "postgresql": 

310 return nullsfirst(col) 

311 else: 

312 return col 

313 

314 

315USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True) 

316 

317 

318def with_row_locks(query, session: Session, **kwargs): 

319 """ 

320 Apply with_for_update to an SQLAlchemy query, if row level locking is in use. 

321 

322 :param query: An SQLAlchemy Query object 

323 :param session: ORM Session 

324 :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc) 

325 :return: updated query 
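
    Example usage (a sketch; ``TaskInstance`` stands in for any ORM model):

    .. code:: python

        query = with_row_locks(
            session.query(TaskInstance),
            session=session,
            **skip_locked(session=session),
        )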

326 """ 

327 dialect = session.bind.dialect 

328 

329 # Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8) does not support it. 

330 if USE_ROW_LEVEL_LOCKING and (dialect.name != "mysql" or dialect.supports_for_update_of): 

331 return query.with_for_update(**kwargs) 

332 else: 

333 return query 

334 

335 

336class CommitProhibitorGuard: 

337 """Context manager class that powers prohibit_commit""" 

338 

339 expected_commit = False 

340 

341 def __init__(self, session: Session): 

342 self.session = session 

343 

344 def _validate_commit(self, _): 

345 if self.expected_commit: 

346 self.expected_commit = False 

347 return 

348 raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!") 

349 

350 def __enter__(self): 

351 event.listen(self.session, "before_commit", self._validate_commit) 

352 return self 

353 

354 def __exit__(self, *exc_info): 

355 event.remove(self.session, "before_commit", self._validate_commit) 

356 

357 def commit(self): 

358 """ 

359 Commit the session. 

360 

361 This is the required way to commit when the guard is in scope 

362 """ 

363 self.expected_commit = True 

364 self.session.commit() 

365 

366 

367def prohibit_commit(session): 

368 """ 

369 Return a context manager that will disallow any commit that isn't done via the context manager. 

370 

371 The aim of this is to ensure that transaction lifetime is strictly controlled which is especially 

372 important in the core scheduler loop. Any commit on the session that is _not_ via this context manager 

373 will result in RuntimeError 

374 

375 Example usage: 

376 

377 .. code:: python 

378 

379 with prohibit_commit(session) as guard: 

380 # ... do something with session 

381 guard.commit() 

382 

383 # This would throw an error 

384 # session.commit() 

385 """ 

386 return CommitProhibitorGuard(session) 

387 

388 

389def is_lock_not_available_error(error: OperationalError): 

390 """Check if the Error is about not being able to acquire lock""" 

    # DB-specific error codes:
    # Postgres: 55P03
    # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT
    #   is set.'
    # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction'
    #   (when NOWAIT isn't available)
    db_err_code = getattr(error.orig, "pgcode", None) or error.orig.args[0]

    # We could test if error.orig is an instance of
    # psycopg2.errors.LockNotAvailable / _mysql_exceptions.OperationalError, but that involves
    # importing them; checking the error code doesn't require the import.
    if db_err_code in ("55P03", 1205, 3572):
        return True
    return False


def tuple_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators:
    """Generate a tuple-in-collection operator to use in ``.filter()``.

    For most SQL backends, this generates a simple ``(col1, col2, ...) IN (values)``
    clause. This, however, does not work with MSSQL, where we need to expand it to
    ``(c1 = v1a AND c2 = v2a ...) OR (c1 = v1b AND c2 = v2b ...) ...`` manually.
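
    For example (a sketch; ``TaskInstance`` and ``session`` are assumed to be
    available):

    .. code:: python

        cond = tuple_in_condition(
            (TaskInstance.dag_id, TaskInstance.task_id),
            [("dag_a", "task_1"), ("dag_b", "task_2")],
        )
        tis = session.query(TaskInstance).filter(cond).all()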

    :meta private:
    """
    if settings.engine.dialect.name != "mssql":
        return tuple_(*columns).in_(collection)
    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in collection]
    if not clauses:
        return false()
    return or_(*clauses)


def tuple_not_in_condition(
    columns: tuple[ColumnElement, ...],
    collection: Iterable[Any],
) -> ColumnOperators:
    """Generate a tuple-not-in-collection operator to use in ``.filter()``.

    This is similar to ``tuple_in_condition`` except it generates ``NOT IN``.

    :meta private:
    """
    if settings.engine.dialect.name != "mssql":
        return tuple_(*columns).not_in(collection)
    clauses = [or_(*(c != v for c, v in zip(columns, values))) for values in collection]
    if not clauses:
        return true()
    return and_(*clauses)