Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/serialization/serde.py: 40%

163 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import dataclasses 

21import enum 

22import logging 

23import re 

24import sys 

25from importlib import import_module 

26from types import ModuleType 

27from typing import Any, TypeVar, Union 

28 

29import attr 

30 

31import airflow.serialization.serializers 

32from airflow.configuration import conf 

33from airflow.utils.module_loading import import_string, iter_namespace, qualname 

34 

# Module-level logger.
log = logging.getLogger(__name__)

# serialize() raises RecursionError once its ``depth`` argument reaches this value.
MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1

# Keys used in the serialized (new style) representation.
CLASSNAME = "__classname__"
VERSION = "__version__"
DATA = "__data__"
SCHEMA_ID = "__id__"
CACHE = "__cache__"

# Keys used by the old style representation, see _convert().
OLD_TYPE = "__type"
OLD_SOURCE = "__source"
OLD_DATA = "__var"

# Version assumed for classes that do not declare ``__version__``.
DEFAULT_VERSION = 0

# Typing helpers for the values this module produces and consumes.
T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
U = Union[bool, float, int, dict, list, str, tuple, set]
S = Union[list, tuple, set]

# Registries filled by _register(): qualified class name -> module providing
# the (de)serializer, plus class names that may always be deserialized.
_serializers: dict[str, ModuleType] = {}
_deserializers: dict[str, ModuleType] = {}
_extra_allowed: set[str] = set()

# Types that are passed through as is / iterated over element by element.
_primitives = (int, bool, float, str)
_iterables = (list, set, tuple)

# Compiled allow list patterns, filled by _compile_patterns().
_patterns: list[re.Pattern] = []

_reverse_cache: dict[int, tuple[ModuleType, str, int]] = {}

64 

65 

def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
    """
    Wraps already serialized data in the envelope understood by the deserializer.

    :param cls: qualified name of the class the data belongs to
    :param version: version of the serialized representation
    :param data: the serialized payload
    :return: dict holding class name, version and payload under the reserved keys
    """
    envelope: dict[str, str | int | T] = {}
    envelope[CLASSNAME] = cls
    envelope[VERSION] = version
    envelope[DATA] = data
    return envelope

69 

70 

def decode(d: dict[str, str | int | T]) -> tuple:
    """
    Splits an encoded dict into its parts.

    :param d: dict carrying the class name and version keys, and optionally the data key
    :return: tuple of (classname, version, data); data is None when absent
    :raises KeyError: if the class name or the version key is missing
    """
    classname = d[CLASSNAME]
    version = d[VERSION]
    payload = d.get(DATA, None)
    return classname, version, payload

73 

74 

def serialize(o: object, depth: int = 0) -> U | None:
    """
    Recursively serializes objects into a primitive. None and primitives (int, bool,
    float, str) are returned as is, except that enum members that are also primitives
    are replaced by their value. Lists, sets and tuples are serialized element by
    element and keep their container type. Dicts are serialized value by value, with
    their keys converted to str. Any other object is serialized, in this order, by
    1) a registered serializer from the namespace of airflow.serialization.serializers,
    2) a ``serialize`` method provided by the object, 3) its dataclass fields or
    4) its attr attributes.

    :param o: object to serialize
    :param depth: private
    :return: a primitive
    :raises RecursionError: if depth equals MAX_RECURSION_DEPTH
    :raises AttributeError: if a dict to serialize contains a reserved key
    :raises TypeError: if no serializer can be found for the object
    """
    if depth == MAX_RECURSION_DEPTH:
        raise RecursionError("maximum recursion depth reached for serialization")

    # None remains None
    if o is None:
        return o

    # primitives pass through; enum members that are primitives collapse to their value
    if isinstance(o, _primitives):
        return o.value if isinstance(o, enum.Enum) else o

    next_depth = depth + 1

    # lists, sets and tuples: serialize every element, keep the container type
    if isinstance(o, _iterables):
        elements = [serialize(element, next_depth) for element in o]
        if isinstance(o, tuple):
            return tuple(elements)
        if isinstance(o, set):
            return set(elements)
        return elements

    if isinstance(o, dict):
        if CLASSNAME in o or SCHEMA_ID in o:
            raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize")

        converted = {}
        for key, val in o.items():
            converted[str(key)] = serialize(val, next_depth)
        return converted

    cls = type(o)
    qn = qualname(o)

    # envelope shared by all non-registered serialization paths below
    dct = {
        CLASSNAME: qn,
        VERSION: getattr(cls, "__version__", DEFAULT_VERSION),
    }

    # a registered serializer takes precedence, unless it declines the object
    if qn in _serializers:
        data, classname, version, is_serialized = _serializers[qn].serialize(o)
        if is_serialized:
            return encode(classname, version, serialize(data, next_depth))

    # object / class brings their own
    if hasattr(o, "serialize"):
        data = o.serialize()

        # if we end up with a structure, ensure its values are serialized
        if isinstance(data, dict):
            data = serialize(data, next_depth)

        dct[DATA] = data
        return dct

    if dataclasses.is_dataclass(cls):
        fields = dataclasses.asdict(o)
    elif attr.has(cls):
        # Only include attributes which we can pass back to the classes constructor
        fields = attr.asdict(
            o, recurse=True, filter=lambda attribute, _value: attribute.init  # type: ignore[arg-type]
        )
    else:
        raise TypeError(f"cannot serialize object of type {cls}")

    dct[DATA] = serialize(fields, next_depth)
    return dct

158 

159 

def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
    """
    Deserializes an object of primitive type T into an object. Uses an allow
    list to determine if a class can be loaded.

    None and primitives are returned unchanged. Lists, sets and tuples are
    returned as a list of deserialized elements. Dicts that carry no class
    information are returned as plain dicts with str keys.

    :param o: primitive to deserialize into an arbitrary object.
    :param full: if False it will return a stringified representation
        of an object and will not load any classes
    :param type_hint: if set it will be used to help determine what
        object to deserialize in. It does not override if another
        specification is found
    :return: object
    :raises TypeError: if o is of an unsupported type, if the serialized version
        is newer than the version of the class or if no deserializer is available
    :raises ImportError: if the class is not on the allow list
    """
    if o is None:
        return o

    if isinstance(o, _primitives):
        return o

    if isinstance(o, _iterables):
        # hand ``full`` down so nested objects are not loaded when only a
        # string representation was requested
        return [deserialize(d, full) for d in o]

    if not isinstance(o, dict):
        raise TypeError(f"cannot deserialize object of type {type(o)}")

    o = _convert(o)

    # plain dict and no type hint
    # NOTE(review): ``and`` binds tighter than ``or``, so a dict without a version
    # key is returned as a plain dict even when a type hint is given. The grouping
    # is only made explicit here, behaviour is unchanged - confirm this is intended.
    if (CLASSNAME not in o and not type_hint) or VERSION not in o:
        return {str(k): deserialize(v, full) for k, v in o.items()}

    # custom deserialization starts here
    cls: Any
    version = 0
    value: Any
    classname: str

    if type_hint:
        cls = type_hint
        classname = qualname(cls)
        version = 0  # type hinting always sets version to 0
        value = o

    # class information inside the data takes precedence over the type hint
    if CLASSNAME in o and VERSION in o:
        classname, version, value = decode(o)
        if not _match(classname) and classname not in _extra_allowed:
            raise ImportError(
                f"{classname} was not found in allow list for deserialization imports. "
                f"To allow it, add it to allowed_deserialization_classes in the configuration"
            )

        if full:
            cls = import_string(classname)

    # only return string representation
    if not full:
        return _stringify(classname, version, value)

    # registered deserializer
    if classname in _deserializers:
        return _deserializers[classname].deserialize(classname, version, deserialize(value))

    # class has deserialization function
    if hasattr(cls, "deserialize"):
        return cls.deserialize(deserialize(value), version)

    # attr or dataclass
    if attr.has(cls) or dataclasses.is_dataclass(cls):
        class_version = getattr(cls, "__version__", 0)
        if int(version) > class_version:
            # exceptions do not interpolate %-style arguments, format explicitly
            raise TypeError(
                f"serialized version of {classname} is newer than module version "
                f"({version} > {class_version})"
            )

        return cls(**deserialize(value))

    # no deserializer available
    raise TypeError(f"No deserializer found for {classname}")

241 

242 

def _convert(old: dict) -> dict:
    """
    Converts an old style serialization to new style.

    A dict is considered old style when it has both the old type and the old
    data key. Any other dict is returned unchanged.

    :param old: dict that may be in the old style format
    :return: dict in the new style format, with version set to DEFAULT_VERSION
    """
    if OLD_TYPE in old and OLD_DATA in old:
        payload = old[OLD_DATA]
        # The payload used to be unwrapped a second time unconditionally, which
        # raised for payloads that are not wrapped twice. Unwrap only when the
        # nested key is actually present.
        # NOTE(review): assumes both shapes occur in stored data - confirm.
        if isinstance(payload, dict) and OLD_DATA in payload:
            payload = payload[OLD_DATA]

        return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: payload}

    return old

249 

250 

def _match(classname: str) -> bool:
    """
    Checks a class name against the compiled allow list patterns.

    :param classname: qualified name of the class to check
    :return: True if at least one pattern matches at the start of the class name
    """
    return any(pattern.match(classname) for pattern in _patterns)

257 

258 

def _stringify(classname: str, version: int, value: T | None) -> str:
    """
    Builds a human readable representation of a serialized object without
    loading its class.

    The result has the form ``classname@version=N(...)``. A primitive value is
    shown as is, the elements of a list, set or tuple are separated by commas
    and dict entries are rendered as ``key=value`` pairs separated by commas.
    Any other value, including None, gives empty parentheses.

    :param classname: qualified name of the class of the serialized object
    :param version: version of the serialized representation
    :param value: the serialized data of the object
    :return: the string representation
    """
    # NOTE(review): serialize() rejects dicts that contain the reserved keys, so
    # values holding nested encoded objects raise AttributeError here - confirm
    # whether deserialize(..., full=False) is what is wanted for those.
    if isinstance(value, _primitives):
        body = f"{value}"
    elif isinstance(value, _iterables):
        # join the elements, not the characters of the str() of the whole container
        body = ",".join(str(serialize(item)) for item in value)
    elif isinstance(value, dict):
        body = ",".join(f"{k}={serialize(v)}" for k, v in value.items())
    else:
        body = ""

    # always close the parenthesis, also for empty containers and None
    return f"{classname}@version={version}({body})"

271 

272 

def _register():
    """
    Register builtin serializers and deserializers for types that don't have any themselves.

    Every module in the namespace of airflow.serialization.serializers is imported.
    The classes listed in its ``serializers`` and ``deserializers`` attributes, given
    as class or as qualified name, are mapped to that module. Classes that have a
    registered deserializer are also added to the extra allowed classes.

    :raises AttributeError: if two different modules register the same class
    """
    _serializers.clear()
    _deserializers.clear()

    for _, name, _ in iter_namespace(airflow.serialization.serializers):
        module = import_module(name)

        for s in getattr(module, "serializers", ()):
            if not isinstance(s, str):
                s = qualname(s)
            if s in _serializers and _serializers[s] != module:
                raise AttributeError(f"duplicate {s} for serialization in {module} and {_serializers[s]}")
            log.debug("registering %s for serialization", s)
            _serializers[s] = module

        for d in getattr(module, "deserializers", ()):
            if not isinstance(d, str):
                d = qualname(d)
            if d in _deserializers and _deserializers[d] != module:
                # report the module that already holds the deserializer registration
                raise AttributeError(f"duplicate {d} for deserialization in {module} and {_deserializers[d]}")
            log.debug("registering %s for deserialization", d)
            _deserializers[d] = module
            _extra_allowed.add(d)

295 

296 

def _compile_patterns():
    """
    (Re)builds the list of compiled allow list patterns from the
    ``allowed_deserialization_classes`` option in the ``core`` section of the
    configuration. The option is split on whitespace and every part is
    compiled as a regular expression.
    """
    expressions = conf.get("core", "allowed_deserialization_classes").split()

    # start from scratch so repeated calls do not accumulate patterns
    _patterns.clear()
    _patterns.extend(map(re.compile, expressions))

303 

304 

# Fill the (de)serializer registries and compile the allow list when the module is imported.
_register()
_compile_patterns()