Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/serde/__init__.py: 33%


215 statements  

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import dataclasses
import enum
import functools
import logging
import re
import sys
from fnmatch import fnmatch
from importlib import import_module
from re import Pattern
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload

import attr

from airflow.sdk._shared.module_loading import import_string, iter_namespace, qualname
from airflow.sdk.configuration import conf
from airflow.sdk.observability.stats import Stats
from airflow.sdk.serde.typing import is_pydantic_model

if TYPE_CHECKING:
    from types import ModuleType

log = logging.getLogger(__name__)

MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1

CLASSNAME = "__classname__"
VERSION = "__version__"
DATA = "__data__"
SCHEMA_ID = "__id__"
CACHE = "__cache__"

OLD_TYPE = "__type"
OLD_SOURCE = "__source"
OLD_DATA = "__var"
OLD_DICT = "dict"
PYDANTIC_MODEL_QUALNAME = "pydantic.main.BaseModel"

DEFAULT_VERSION = 0

T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
U = bool | float | int | dict | list | str | tuple | set
S = list | tuple | set

_serializers: dict[str, ModuleType] = {}
_deserializers: dict[str, ModuleType] = {}
_stringifiers: dict[str, ModuleType] = {}
_extra_allowed: set[str] = set()

_primitives = (int, bool, float, str)
_builtin_collections = (frozenset, list, set, tuple)  # dict is treated specially.

def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
    """Encode an object so it can be understood by the deserializer."""
    return {CLASSNAME: cls, VERSION: version, DATA: data}


def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
    classname = d[CLASSNAME]
    version = d[VERSION]

    if not isinstance(classname, str) or not isinstance(version, int):
        raise ValueError(f"cannot decode {d!r}")

    data = d.get(DATA)

    return classname, version, data
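# Example (illustrative, not part of the original module): the envelope produced by
# encode() and unpacked by decode(). The keys come from the CLASSNAME, VERSION and
# DATA constants above; "myapp.MyClass" is a hypothetical classname.
#
#   >>> encode("myapp.MyClass", 1, {"field": 42})
#   {'__classname__': 'myapp.MyClass', '__version__': 1, '__data__': {'field': 42}}
#   >>> decode({'__classname__': 'myapp.MyClass', '__version__': 1, '__data__': {'field': 42}})
#   ('myapp.MyClass', 1, {'field': 42})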

@overload
def serialize(o: dict, depth: int = 0) -> dict: ...
@overload
def serialize(o: None, depth: int = 0) -> None: ...
@overload
def serialize(o: object, depth: int = 0) -> U | None: ...

def serialize(o: object, depth: int = 0) -> U | None:
    """
    Serialize an object into a representation consisting only of built-in types.

    Primitives (int, float, bool, str) are returned as-is. Built-in collections
    are iterated over, where it is assumed that keys in a dict can be represented
    as str.

    Values that are not of a built-in type are serialized if a serializer is
    found for them. The order in which serializers are tried is:

    1. A ``serialize`` function provided by the object.
    2. A registered serializer in the namespace of ``airflow.sdk.serde.serializers``.
    3. Annotations from attr or dataclass.

    Limitations: attr and dataclass objects can lose type information for nested objects,
    as ``asdict`` does not record it. This means that at deserialization such values
    come back as plain dicts instead of being reinstated as objects. Provide your own
    serializer to work around this.

    :param o: The object to serialize.
    :param depth: Private tracker for nested serialization.
    :raise TypeError: A serializer cannot be found.
    :raise RecursionError: The object is too nested for the function to handle.
    :return: A representation of ``o`` that consists of only built-in types.
    """
    if depth == MAX_RECURSION_DEPTH:
        raise RecursionError("maximum recursion depth reached for serialization")

    # None remains None
    if o is None:
        return o

    if isinstance(o, list):
        return [serialize(d, depth + 1) for d in o]

    if isinstance(o, dict):
        if CLASSNAME in o or SCHEMA_ID in o:
            raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize")

        return {str(k): serialize(v, depth + 1) for k, v in o.items()}

    cls = type(o)
    qn = qualname(o)
    classname = None

    # Serialize namedtuples like plain tuples.
    # We also override the classname returned by the builtin.py serializer: the classname
    # has to be "builtins.tuple" so that the deserializer can turn the object back into a tuple.
    if _is_namedtuple(o):
        qn = "builtins.tuple"
        classname = qn

    if is_pydantic_model(o):
        # to match the generic Pydantic serializer and deserializer in _serializers and _deserializers
        qn = PYDANTIC_MODEL_QUALNAME
        # the actual Pydantic model class to encode
        classname = qualname(o)

    # if there is a builtin serializer available, use that
    if qn in _serializers:
        data, serialized_classname, version, is_serialized = _serializers[qn].serialize(o)
        if is_serialized:
            return encode(classname or serialized_classname, version, serialize(data, depth + 1))

    # primitive types are returned as is
    if isinstance(o, _primitives):
        if isinstance(o, enum.Enum):
            return o.value

        return o

    # custom serializers
    dct = {
        CLASSNAME: qn,
        VERSION: getattr(cls, "__version__", DEFAULT_VERSION),
    }

    # object / class brings its own
    if hasattr(o, "serialize"):
        data = getattr(o, "serialize")()

        # if we end up with a structure, ensure its values are serialized
        if isinstance(data, dict):
            data = serialize(data, depth + 1)

        dct[DATA] = data
        return dct

    # dataclasses
    if dataclasses.is_dataclass(cls):
        # FIXME: unfortunately, asdict loses type information for nested dataclasses
        data = dataclasses.asdict(o)  # type: ignore[call-overload]
        dct[DATA] = serialize(data, depth + 1)
        return dct

    # attr annotated
    if attr.has(cls):
        # Only include attributes which we can pass back to the class's constructor
        data = attr.asdict(cast("attr.AttrsInstance", o), recurse=False, filter=lambda a, v: a.init)
        dct[DATA] = serialize(data, depth + 1)
        return dct

    raise TypeError(f"cannot serialize object of type {cls}")
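# Example (illustrative; the dataclass is hypothetical and assumed to live in a module
# named "example", with no serializer registered for it and no serialize() method).
# Such an object falls through to the dataclasses branch and is wrapped in the standard
# envelope, while built-ins pass through unchanged (assuming no registered serializer
# claims them):
#
#   >>> import dataclasses
#   >>> @dataclasses.dataclass
#   ... class Coords:
#   ...     x: int
#   ...     y: int
#   >>> serialize(Coords(1, 2))
#   {'__classname__': 'example.Coords', '__version__': 0, '__data__': {'x': 1, 'y': 2}}
#   >>> serialize([1, "a", None])
#   [1, 'a', None]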

def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
    """
    Deserialize an object of primitive type, using an allow list to determine if a class can be loaded.

    :param o: primitive to deserialize into an arbitrary object.
    :param full: if False, a stringified representation of the object is
        returned and no classes are loaded
    :param type_hint: if set, it is used to help determine what object to
        deserialize into. It does not override a classname specification
        found in the data itself
    :return: object
    """
    if o is None:
        return o

    if isinstance(o, _primitives):
        return o

    # tuples and sets are included here for backwards compatibility
    if isinstance(o, _builtin_collections):
        col = [deserialize(d) for d in o]
        if isinstance(o, tuple):
            return tuple(col)

        if isinstance(o, set):
            return set(col)

        return col

    if not isinstance(o, dict):
        # if o is not a dict, then it's already deserialized;
        # in this case we should return it as is
        return o

    o = _convert(o)

    # plain dict and no type hint
    if CLASSNAME not in o and not type_hint or VERSION not in o:
        return {str(k): deserialize(v, full) for k, v in o.items()}

    # custom deserialization starts here
    cls: Any
    version = 0
    value: Any = None
    classname = ""

    if type_hint:
        cls = type_hint
        classname = qualname(cls)
        version = 0  # type hinting always sets version to 0
        value = o

    if CLASSNAME in o and VERSION in o:
        classname, version, value = decode(o)

    if not classname:
        raise TypeError("classname cannot be empty")

    # only return string representation
    if not full:
        return _stringify(classname, version, value)
    if not _match(classname) and classname not in _extra_allowed:
        raise ImportError(
            f"{classname} was not found in allow list for deserialization imports. "
            f"To allow it, add it to allowed_deserialization_classes in the configuration"
        )

    cls = import_string(classname)

    # registered deserializer
    if classname in _deserializers:
        return _deserializers[classname].deserialize(cls, version, deserialize(value))
    if is_pydantic_model(cls):
        if PYDANTIC_MODEL_QUALNAME in _deserializers:
            return _deserializers[PYDANTIC_MODEL_QUALNAME].deserialize(cls, version, deserialize(value))

    # class has a deserialization function
    if hasattr(cls, "deserialize"):
        return getattr(cls, "deserialize")(deserialize(value), version)

    # attr or dataclass
    if attr.has(cls) or dataclasses.is_dataclass(cls):
        class_version = getattr(cls, "__version__", 0)
        if int(version) > class_version:
            raise TypeError(
                f"serialized version of {classname} is newer than module version ({version} > {class_version})"
            )

        deserialize_value = deserialize(value)
        if not isinstance(deserialize_value, dict):
            raise TypeError(
                f"deserialized value for {classname} is not a dict, got {type(deserialize_value)}"
            )
        return cls(**deserialize_value)  # type: ignore[operator]

    # no deserializer available
    raise TypeError(f"No deserializer found for {classname}")
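# Example (illustrative; "example.Coords" is the hypothetical dataclass from the sketch
# above). Loading the class requires it to match the allow list, so the round trip below
# assumes allowed_deserialization_classes contains a pattern such as "example.*"; a
# classname that matches nothing raises ImportError instead of importing anything.
#
#   >>> deserialize({'__classname__': 'example.Coords', '__version__': 0, '__data__': {'x': 1, 'y': 2}})
#   Coords(x=1, y=2)
#   >>> deserialize({'__classname__': 'example.Coords', '__version__': 0, '__data__': {'x': 1, 'y': 2}}, full=False)
#   'example.Coords@version=0(x=1,y=2)'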

def _convert(old: dict) -> dict:
    """Convert an old style serialization to new style."""
    if OLD_TYPE in old and OLD_DATA in old:
        # Return old style dicts directly as they do not need wrapping
        if old[OLD_TYPE] == OLD_DICT:
            return old[OLD_DATA]
        return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: old[OLD_DATA]}

    return old
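# Example (illustrative): an old style payload using __type/__var is rewritten into the
# new envelope, while an old style plain dict is unwrapped directly.
#
#   >>> _convert({"__type": "example.Coords", "__var": {"x": 1, "y": 2}})
#   {'__classname__': 'example.Coords', '__version__': 0, '__data__': {'x': 1, 'y': 2}}
#   >>> _convert({"__type": "dict", "__var": {"x": 1}})
#   {'x': 1}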

def _match(classname: str) -> bool:
    """Check if the given classname matches a path pattern either using glob format or regexp format."""
    return _match_glob(classname) or _match_regexp(classname)


@functools.cache
def _match_glob(classname: str):
    """Check if the given classname matches a pattern from allowed_deserialization_classes using glob syntax."""
    patterns = _get_patterns()
    return any(fnmatch(classname, p.pattern) for p in patterns)


@functools.cache
def _match_regexp(classname: str):
    """Check if the given classname matches a pattern from allowed_deserialization_classes_regexp using regexp."""
    patterns = _get_regexp_patterns()
    return any(p.match(classname) is not None for p in patterns)
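# Example (illustrative; the patterns are assumptions, not shipped defaults): with
# allowed_deserialization_classes = "example.*", _match_glob("example.Coords") is True
# because fnmatch("example.Coords", "example.*") matches, and with
# allowed_deserialization_classes_regexp = "myapp\.models\..+",
# _match_regexp("myapp.models.User") is True because the compiled pattern matches from
# the start of the classname. _match() passes if either check does.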

def _stringify(classname: str, version: int, value: T | None) -> str:
    """
    Convert a previously serialized object into a somewhat human-readable format.

    This function is not designed to be exact, and will not extensively traverse
    the whole tree of an object.
    """
    if classname in _stringifiers:
        return _stringifiers[classname].stringify(classname, version, value)

    s = f"{classname}@version={version}("
    if isinstance(value, _primitives):
        s += f"{value}"
    elif isinstance(value, _builtin_collections):
        # deserialized values can be != str, so stringify each element
        s += ",".join(str(deserialize(v, full=False)) for v in value)
    elif isinstance(value, dict):
        s += ",".join(f"{k}={deserialize(v, full=False)}" for k, v in value.items())
    s += ")"

    return s
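# Example (illustrative; assumes no stringifier is registered for the classname, so the
# generic fallback rendering is used):
#
#   >>> _stringify("example.Coords", 2, {"x": 1, "y": 2})
#   'example.Coords@version=2(x=1,y=2)'
#   >>> _stringify("builtins.int", 0, 42)
#   'builtins.int@version=0(42)'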

def _is_namedtuple(cls: Any) -> bool:
    """
    Return True if the class is a namedtuple.

    Checking is done by attributes as it is significantly faster than
    using isinstance.
    """
    return hasattr(cls, "_asdict") and hasattr(cls, "_fields") and hasattr(cls, "_field_defaults")
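# Example (illustrative): detection is duck-typed on the attributes generated by
# collections.namedtuple and typing.NamedTuple, so a plain tuple is not matched.
#
#   >>> from collections import namedtuple
#   >>> Point = namedtuple("Point", ["x", "y"])
#   >>> _is_namedtuple(Point(1, 2)), _is_namedtuple((1, 2))
#   (True, False)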

def _register():
    """Register builtin serializers and deserializers for types that don't have any themselves."""
    _serializers.clear()
    _deserializers.clear()
    _stringifiers.clear()

    with Stats.timer("serde.load_serializers") as timer:
        serializers_module = import_module("airflow.sdk.serde.serializers")
        for _, module_name, _ in iter_namespace(serializers_module):
            module = import_module(module_name)
            for serializers in getattr(module, "serializers", ()):
                s_qualname = serializers if isinstance(serializers, str) else qualname(serializers)
                if s_qualname in _serializers and _serializers[s_qualname] != module:
                    raise AttributeError(
                        f"duplicate {s_qualname} for serialization in {module} and {_serializers[s_qualname]}"
                    )
                log.debug("registering %s for serialization", s_qualname)
                _serializers[s_qualname] = module
            for deserializers in getattr(module, "deserializers", ()):
                d_qualname = deserializers if isinstance(deserializers, str) else qualname(deserializers)
                if d_qualname in _deserializers and _deserializers[d_qualname] != module:
                    raise AttributeError(
                        f"duplicate {d_qualname} for deserialization in {module} and {_deserializers[d_qualname]}"
                    )
                log.debug("registering %s for deserialization", d_qualname)
                _deserializers[d_qualname] = module
                _extra_allowed.add(d_qualname)
            for stringifiers in getattr(module, "stringifiers", ()):
                c_qualname = stringifiers if isinstance(stringifiers, str) else qualname(stringifiers)
                if c_qualname in _stringifiers and _stringifiers[c_qualname] != module:
                    raise AttributeError(
                        f"duplicate {c_qualname} for stringifiers in {module} and {_stringifiers[c_qualname]}"
                    )
                log.debug("registering %s for stringifying", c_qualname)
                _stringifiers[c_qualname] = module

    log.debug("loading serializers took %.3f seconds", timer.duration)
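# Example (illustrative sketch of the module-level contract that _register() discovers,
# inferred from how serialize(), deserialize() and _stringify() call into the registered
# modules; "acme.Money" and the module below are hypothetical):
#
#   # airflow/sdk/serde/serializers/money.py
#   __version__ = 1
#
#   serializers = ["acme.Money"]   # classnames handled by this module
#   deserializers = serializers    # also added to the deserialization allow list
#   stringifiers = serializers
#
#   def serialize(o):
#       # must return (data, classname, version, is_serialized)
#       return {"amount": str(o.amount), "currency": o.currency}, "acme.Money", __version__, True
#
#   def deserialize(cls, version, data):
#       return cls(amount=data["amount"], currency=data["currency"])
#
#   def stringify(classname, version, value):
#       return f"{value['amount']} {value['currency']}"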

@functools.cache
def _get_patterns() -> list[Pattern]:
    return [re.compile(p) for p in conf.get("core", "allowed_deserialization_classes").split()]


@functools.cache
def _get_regexp_patterns() -> list[Pattern]:
    return [re.compile(p) for p in conf.get("core", "allowed_deserialization_classes_regexp").split()]
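# Example (illustrative; the values are assumptions, not shipped defaults): both options
# are whitespace-separated lists read from the [core] section, e.g.
#
#   [core]
#   allowed_deserialization_classes = airflow.* example.Coords
#   allowed_deserialization_classes_regexp = ^myapp\.models\..+$
#
# which yields two glob patterns (matched via fnmatch against p.pattern) and one
# compiled regular expression (matched via p.match).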

_register()