Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/serialization/serde.py: 38%

185 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import dataclasses 

21import enum 

22import functools 

23import logging 

24import re 

25import sys 

26from importlib import import_module 

27from types import ModuleType 

28from typing import Any, TypeVar, Union, cast 

29 

30import attr 

31 

32import airflow.serialization.serializers 

33from airflow.configuration import conf 

34from airflow.stats import Stats 

35from airflow.utils.module_loading import import_string, iter_namespace, qualname 

36 

# Module-level logger.
log = logging.getLogger(__name__)

# Nesting limit checked by ``serialize``: one below the interpreter's recursion limit.
MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1

# Reserved keys of the serialized representation.
CLASSNAME = "__classname__"
VERSION = "__version__"
DATA = "__data__"
SCHEMA_ID = "__id__"
CACHE = "__cache__"

# Keys of the old-style representation; ``_convert`` reads OLD_TYPE and OLD_DATA.
OLD_TYPE = "__type"
OLD_SOURCE = "__source"
OLD_DATA = "__var"

# Version recorded when a class does not define ``__version__``.
DEFAULT_VERSION = 0

# Type aliases used in the annotations of this module.
T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
U = Union[bool, float, int, dict, list, str, tuple, set]
S = Union[list, tuple, set]

# Registries populated by ``_register``: qualified class name -> module providing the hook.
_serializers: dict[str, ModuleType] = {}
_deserializers: dict[str, ModuleType] = {}
_stringifiers: dict[str, ModuleType] = {}

# Class names with a registered deserializer; accepted by ``deserialize`` in
# addition to the configured allow list.
_extra_allowed: set[str] = set()

# Types that ``serialize`` and ``deserialize`` return unchanged.
_primitives = (int, bool, float, str)
_builtin_collections = (frozenset, list, set, tuple)  # dict is treated specially.

64 

65 

def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
    """Wrap data in the envelope understood by the deserializer.

    :param cls: class name to record in the envelope.
    :param version: version number to record in the envelope.
    :param data: payload, stored as given.
    :return: a dict holding the three values under their reserved keys.
    """
    envelope: dict[str, Any] = {}
    envelope[CLASSNAME] = cls
    envelope[VERSION] = version
    envelope[DATA] = data
    return envelope

69 

70 

def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
    """Split a serialized envelope into class name, version and data.

    :param d: dict holding the class name and version under their reserved
        keys; the data key is optional.
    :return: ``(classname, version, data)``; ``data`` is None when the data
        key is absent.
    :raise KeyError: the class name or version key is missing.
    :raise ValueError: the class name is not a str or the version is not an int.
    """
    classname, version = d[CLASSNAME], d[VERSION]

    well_formed = isinstance(classname, str) and isinstance(version, int)
    if not well_formed:
        raise ValueError(f"cannot decode {d!r}")

    return classname, version, d.get(DATA)

81 

82 

def serialize(o: object, depth: int = 0) -> U | None:
    """Turn an object into a structure made of built-in types only.

    ``None`` and primitives (int, float, bool, str) pass through unchanged,
    except that a primitive which is also an ``enum.Enum`` is replaced by its
    value. Lists and dicts are walked recursively; dict keys are converted
    with ``str``.

    Any other object is wrapped in a dict carrying its qualified class name,
    a version and the serialized data. The data is obtained from the first
    of the following that applies:

    1. A serializer registered for the class name, if it reports that it
       handled the object.
    2. A ``serialize`` method on the object itself.
    3. A pydantic model's ``dict()``.
    4. ``dataclasses.asdict``.
    5. ``attr.asdict``, restricted to attributes that are ``init`` arguments.

    Limitations: ``asdict`` does not keep type information for nested attr or
    dataclass objects, so on deserialization those come back as plain dicts.
    Provide a dedicated serializer to work around this.

    :param o: The object to serialize.
    :param depth: Private tracker for nested serialization.
    :raise TypeError: A serializer cannot be found.
    :raise RecursionError: The object is too nested for the function to handle.
    :raise AttributeError: A dict to serialize contains a reserved key.
    :return: A representation of ``o`` that consists of only built-in types.
    """
    if depth == MAX_RECURSION_DEPTH:
        raise RecursionError("maximum recursion depth reached for serialization")

    next_depth = depth + 1

    # None stays None
    if o is None:
        return o

    # primitives are returned untouched, enum members collapse to their value
    if isinstance(o, _primitives):
        return o.value if isinstance(o, enum.Enum) else o

    if isinstance(o, list):
        return [serialize(item, next_depth) for item in o]

    if isinstance(o, dict):
        if CLASSNAME in o or SCHEMA_ID in o:
            raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize")

        return {str(key): serialize(val, next_depth) for key, val in o.items()}

    obj_type = type(o)
    name = qualname(o)

    # envelope shared by every branch below that does not use a registered serializer
    envelope = {
        CLASSNAME: name,
        VERSION: getattr(obj_type, "__version__", DEFAULT_VERSION),
    }

    # a registered serializer takes precedence, unless it declines the object
    registered = _serializers.get(name)
    if registered is not None:
        payload, classname, version, is_serialized = registered.serialize(o)
        if is_serialized:
            return encode(classname, version, serialize(payload, next_depth))

    if hasattr(o, "serialize"):
        # the object brings its own serializer; only a dict result is walked further
        payload = o.serialize()
        if isinstance(payload, dict):
            payload = serialize(payload, next_depth)
    elif _is_pydantic(obj_type):
        # pydantic models are recursive
        payload = serialize(o.dict(), next_depth)  # type: ignore[attr-defined]
    elif dataclasses.is_dataclass(obj_type):
        # fixme: asdict loses type information for nested dataclasses
        payload = serialize(dataclasses.asdict(o), next_depth)  # type: ignore[call-overload]
    elif attr.has(obj_type):
        # only keep attributes that can be passed back to the class constructor
        as_dict = attr.asdict(cast(attr.AttrsInstance, o), recurse=True, filter=lambda a, v: a.init)
        payload = serialize(as_dict, next_depth)
    else:
        raise TypeError(f"cannot serialize object of type {obj_type}")

    envelope[DATA] = payload
    return envelope

178 

179 

def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
    """
    Deserializes an object of primitive type T into an object. Uses an allow
    list to determine if a class can be loaded.

    ``None`` and primitives are returned as-is. Built-in collections are
    rebuilt element by element, keeping the container type for tuples, sets
    and frozensets. Anything else that is not a dict is returned untouched.
    Dicts are first passed through ``_convert`` and are then either treated as
    plain dicts or as a serialized object, depending on the reserved keys
    they carry.

    :param o: primitive to deserialize into an arbitrary object.
    :param full: if False it will return a stringified representation
        of an object and will not load any classes
    :param type_hint: if set it will be used to help determine what
        object to deserialize in. It does not override if another
        specification is found
    :raise TypeError: the class name is empty, the serialized version is newer
        than the class version, or no deserializer is available.
    :raise ImportError: the class is neither matched by the allow list nor
        covered by a registered deserializer.
    :return: object
    """
    if o is None:
        return o

    if isinstance(o, _primitives):
        return o

    # tuples, sets are included here for backwards compatibility
    if isinstance(o, _builtin_collections):
        # pass ``full`` on so nested objects are not loaded when only a
        # stringified representation was requested
        col = [deserialize(d, full) for d in o]
        if isinstance(o, tuple):
            return tuple(col)

        if isinstance(o, frozenset):
            return frozenset(col)

        if isinstance(o, set):
            return set(col)

        return col

    if not isinstance(o, dict):
        # if o is not a dict, then it's already deserialized
        # in this case we should return it as is
        return o

    o = _convert(o)

    # plain dict and no type hint
    # NOTE(review): ``and`` binds tighter than ``or``, so a dict without a
    # version key is treated as a plain dict even when a type hint is given.
    # The parentheses only make the existing evaluation explicit - confirm
    # that this is the intended rule.
    if (CLASSNAME not in o and not type_hint) or VERSION not in o:
        return {str(k): deserialize(v, full) for k, v in o.items()}

    # custom deserialization starts here
    cls: Any
    version = 0
    value: Any = None
    classname = ""

    if type_hint:
        cls = type_hint
        classname = qualname(cls)
        version = 0  # type hinting always sets version to 0
        value = o

    # an explicit envelope overrides whatever the type hint suggested
    if CLASSNAME in o and VERSION in o:
        classname, version, value = decode(o)

    if not classname:
        raise TypeError("classname cannot be empty")

    # only return string representation
    if not full:
        return _stringify(classname, version, value)

    if not _match(classname) and classname not in _extra_allowed:
        raise ImportError(
            f"{classname} was not found in allow list for deserialization imports. "
            f"To allow it, add it to allowed_deserialization_classes in the configuration"
        )

    cls = import_string(classname)

    # registered deserializer
    if classname in _deserializers:
        return _deserializers[classname].deserialize(classname, version, deserialize(value))

    # class has deserialization function
    if hasattr(cls, "deserialize"):
        return cls.deserialize(deserialize(value), version)

    # attr or dataclass or pydantic
    if attr.has(cls) or dataclasses.is_dataclass(cls) or _is_pydantic(cls):
        class_version = getattr(cls, "__version__", 0)
        if int(version) > class_version:
            # exceptions do not interpolate %-style arguments, so format the message here
            raise TypeError(
                f"serialized version of {classname} is newer than module version "
                f"({version} > {class_version})"
            )

        return cls(**deserialize(value))

    # no deserializer available
    raise TypeError(f"No deserializer found for {classname}")

274 

275 

def _convert(old: dict) -> dict:
    """Translate an old-style serialized dict into the current envelope.

    A dict counts as old style when it carries both the old type key and the
    old data key. The payload is taken from the old data key nested inside
    the old data entry, and the version is set to the default version.
    Any other dict is returned unchanged.
    """
    is_old_style = OLD_TYPE in old and OLD_DATA in old
    if not is_old_style:
        return old

    return {
        CLASSNAME: old[OLD_TYPE],
        VERSION: DEFAULT_VERSION,
        DATA: old[OLD_DATA][OLD_DATA],
    }

282 

283 

def _match(classname: str) -> bool:
    """Return True if any pattern from ``_get_patterns`` matches at the start of ``classname``."""
    for pattern in _get_patterns():
        if pattern.match(classname) is not None:
            return True
    return False

286 

287 

def _stringify(classname: str, version: int, value: T | None) -> str:
    """Convert a previously serialized object in a somewhat human-readable format.

    This function is not designed to be exact, and will not extensively traverse
    the whole tree of an object.

    A stringifier registered for ``classname`` takes precedence. Otherwise the
    result has the form ``<classname>@version=<version>(<body>)`` where the
    body is the value itself for primitives, the comma-separated elements for
    built-in collections, comma-separated ``key=value`` pairs for dicts and
    empty for anything else (including None).

    :param classname: class name of the serialized object.
    :param version: version of the serialized object.
    :param value: the serialized data.
    :return: the string representation.
    """
    if classname in _stringifiers:
        return _stringifiers[classname].stringify(classname, version, value)

    if isinstance(value, _primitives):
        body = f"{value}"
    elif isinstance(value, _builtin_collections):
        # stringify element by element; deserialized values can be != str
        body = ",".join(str(deserialize(v, full=False)) for v in value)
    elif isinstance(value, dict):
        body = ",".join(f"{k}={deserialize(v, full=False)}" for k, v in value.items())
    else:
        body = ""

    # close the parenthesis in one place so every branch yields a balanced string
    return f"{classname}@version={version}({body})"

309 

310 

def _is_pydantic(cls: Any) -> bool:
    """Tell whether ``cls`` looks like a pydantic model.

    The decision is based on the presence of three attributes rather than on
    ``isinstance``, as the attribute check is significantly faster.
    """
    markers = ("__validators__", "__fields__", "dict")
    return all(hasattr(cls, marker) for marker in markers)

318 

319 

def _register():
    """Register builtin serializers and deserializers for types that don't have any themselves.

    Every module found in the ``airflow.serialization.serializers`` namespace
    is imported and its optional ``serializers``, ``deserializers`` and
    ``stringifiers`` attributes are read. Entries may be given as strings or
    as objects, in which case their qualified name is used. Each entry is
    mapped to the module that declared it; class names with a deserializer
    are also added to the extra allow list.

    :raise AttributeError: two different modules declare the same class name
        for the same kind of hook.
    """
    _serializers.clear()
    _deserializers.clear()
    _stringifiers.clear()

    with Stats.timer("serde.load_serializers") as timer:
        for _, module_name, _ in iter_namespace(airflow.serialization.serializers):
            module = import_module(module_name)

            for s in getattr(module, "serializers", ()):
                if not isinstance(s, str):
                    s = qualname(s)
                if s in _serializers and _serializers[s] != module:
                    raise AttributeError(f"duplicate {s} for serialization in {module} and {_serializers[s]}")
                log.debug("registering %s for serialization", s)
                _serializers[s] = module

            for d in getattr(module, "deserializers", ()):
                if not isinstance(d, str):
                    d = qualname(d)
                if d in _deserializers and _deserializers[d] != module:
                    # report the conflicting *deserializer* module
                    raise AttributeError(
                        f"duplicate {d} for deserialization in {module} and {_deserializers[d]}"
                    )
                log.debug("registering %s for deserialization", d)
                _deserializers[d] = module
                _extra_allowed.add(d)

            for c in getattr(module, "stringifiers", ()):
                if not isinstance(c, str):
                    c = qualname(c)
                # duplicates must be looked up in the stringifier registry itself
                if c in _stringifiers and _stringifiers[c] != module:
                    raise AttributeError(f"duplicate {c} for stringifiers in {module} and {_stringifiers[c]}")
                log.debug("registering %s for stringifying", c)
                _stringifiers[c] = module

    log.debug("loading serializers took %.3f seconds", timer.duration)

353 

354 

@functools.lru_cache(maxsize=None)
def _get_patterns() -> list[re.Pattern]:
    """Compile the allow list from ``[core] allowed_deserialization_classes``.

    The setting is split on whitespace and each entry is compiled to a
    regular expression. The result is cached for the life of the process.
    """
    raw_patterns = conf.get("core", "allowed_deserialization_classes").split()

    compiled = []
    for raw in raw_patterns:
        # Each "<word character>." is rewritten to "<word character>\..", i.e. a
        # literal dot followed by "any character".
        # NOTE(review): an entry such as "a.b" therefore becomes "a\..b", which
        # needs an extra character after the dot - confirm entries are expected
        # to be written like "a.*".
        compiled.append(re.compile(re.sub(r"(\w)\.", r"\1\..", raw)))
    return compiled

359 

360 

# Populate the serializer, deserializer and stringifier registries when the module is imported.
_register()